Ensure consistency with error-handling across all handlers. (#11599)

This commit is contained in:
Mathew Estafanous 2022-01-05 12:11:03 -05:00 committed by GitHub
parent 6d0a73c0eb
commit 0fdd1318e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 107 additions and 140 deletions

View File

@ -16,23 +16,22 @@ type aclBootstrapResponse struct {
structs.ACLToken structs.ACLToken
} }
var aclDisabled = UnauthorizedError{Reason: "ACL support disabled"}
// checkACLDisabled will return a standard response if ACLs are disabled. This // checkACLDisabled will return a standard response if ACLs are disabled. This
// returns true if they are disabled and we should not continue. // returns true if they are disabled and we should not continue.
func (s *HTTPHandlers) checkACLDisabled(resp http.ResponseWriter, _req *http.Request) bool { func (s *HTTPHandlers) checkACLDisabled() bool {
if s.agent.config.ACLsEnabled { if s.agent.config.ACLsEnabled {
return false return false
} }
resp.WriteHeader(http.StatusUnauthorized)
fmt.Fprint(resp, "ACL support disabled")
return true return true
} }
// ACLBootstrap is used to perform a one-time ACL bootstrap operation on // ACLBootstrap is used to perform a one-time ACL bootstrap operation on
// a cluster to get the first management token. // a cluster to get the first management token.
func (s *HTTPHandlers) ACLBootstrap(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLBootstrap(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
args := structs.DCSpecificRequest{ args := structs.DCSpecificRequest{
@ -42,9 +41,7 @@ func (s *HTTPHandlers) ACLBootstrap(resp http.ResponseWriter, req *http.Request)
err := s.agent.RPC("ACL.BootstrapTokens", &args, &out) err := s.agent.RPC("ACL.BootstrapTokens", &args, &out)
if err != nil { if err != nil {
if strings.Contains(err.Error(), structs.ACLBootstrapNotAllowedErr.Error()) { if strings.Contains(err.Error(), structs.ACLBootstrapNotAllowedErr.Error()) {
resp.WriteHeader(http.StatusForbidden) return nil, acl.PermissionDeniedError{Cause: err.Error()}
fmt.Fprint(resp, acl.PermissionDeniedError{Cause: err.Error()}.Error())
return nil, nil
} else { } else {
return nil, err return nil, err
} }
@ -53,8 +50,8 @@ func (s *HTTPHandlers) ACLBootstrap(resp http.ResponseWriter, req *http.Request)
} }
func (s *HTTPHandlers) ACLReplicationStatus(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLReplicationStatus(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
// Note that we do not forward to the ACL DC here. This is a query for // Note that we do not forward to the ACL DC here. This is a query for
@ -74,8 +71,8 @@ func (s *HTTPHandlers) ACLReplicationStatus(resp http.ResponseWriter, req *http.
} }
func (s *HTTPHandlers) ACLPolicyList(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLPolicyList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
var args structs.ACLPolicyListRequest var args structs.ACLPolicyListRequest
@ -105,8 +102,8 @@ func (s *HTTPHandlers) ACLPolicyList(resp http.ResponseWriter, req *http.Request
} }
func (s *HTTPHandlers) ACLPolicyCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLPolicyCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
var fn func(resp http.ResponseWriter, req *http.Request, policyID string) (interface{}, error) var fn func(resp http.ResponseWriter, req *http.Request, policyID string) (interface{}, error)
@ -166,8 +163,8 @@ func (s *HTTPHandlers) ACLPolicyRead(resp http.ResponseWriter, req *http.Request
} }
func (s *HTTPHandlers) ACLPolicyReadByName(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLPolicyReadByName(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
policyName := strings.TrimPrefix(req.URL.Path, "/v1/acl/policy/name/") policyName := strings.TrimPrefix(req.URL.Path, "/v1/acl/policy/name/")
@ -183,8 +180,8 @@ func (s *HTTPHandlers) ACLPolicyReadByID(resp http.ResponseWriter, req *http.Req
} }
func (s *HTTPHandlers) ACLPolicyCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLPolicyCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
return s.aclPolicyWriteInternal(resp, req, "", true) return s.aclPolicyWriteInternal(resp, req, "", true)
@ -248,8 +245,8 @@ func (s *HTTPHandlers) ACLPolicyDelete(resp http.ResponseWriter, req *http.Reque
} }
func (s *HTTPHandlers) ACLTokenList(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLTokenList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
args := &structs.ACLTokenListRequest{ args := &structs.ACLTokenListRequest{
@ -285,8 +282,8 @@ func (s *HTTPHandlers) ACLTokenList(resp http.ResponseWriter, req *http.Request)
} }
func (s *HTTPHandlers) ACLTokenCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLTokenCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
var fn func(resp http.ResponseWriter, req *http.Request, tokenID string) (interface{}, error) var fn func(resp http.ResponseWriter, req *http.Request, tokenID string) (interface{}, error)
@ -318,8 +315,8 @@ func (s *HTTPHandlers) ACLTokenCRUD(resp http.ResponseWriter, req *http.Request)
} }
func (s *HTTPHandlers) ACLTokenSelf(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLTokenSelf(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
args := structs.ACLTokenGetRequest{ args := structs.ACLTokenGetRequest{
@ -351,8 +348,8 @@ func (s *HTTPHandlers) ACLTokenSelf(resp http.ResponseWriter, req *http.Request)
} }
func (s *HTTPHandlers) ACLTokenCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLTokenCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
return s.aclTokenSetInternal(req, "", true) return s.aclTokenSetInternal(req, "", true)
@ -442,8 +439,8 @@ func (s *HTTPHandlers) ACLTokenDelete(resp http.ResponseWriter, req *http.Reques
} }
func (s *HTTPHandlers) ACLTokenClone(resp http.ResponseWriter, req *http.Request, tokenID string) (interface{}, error) { func (s *HTTPHandlers) ACLTokenClone(resp http.ResponseWriter, req *http.Request, tokenID string) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
args := structs.ACLTokenSetRequest{ args := structs.ACLTokenSetRequest{
@ -471,8 +468,8 @@ func (s *HTTPHandlers) ACLTokenClone(resp http.ResponseWriter, req *http.Request
} }
func (s *HTTPHandlers) ACLRoleList(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLRoleList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
var args structs.ACLRoleListRequest var args structs.ACLRoleListRequest
@ -504,8 +501,8 @@ func (s *HTTPHandlers) ACLRoleList(resp http.ResponseWriter, req *http.Request)
} }
func (s *HTTPHandlers) ACLRoleCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLRoleCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
var fn func(resp http.ResponseWriter, req *http.Request, roleID string) (interface{}, error) var fn func(resp http.ResponseWriter, req *http.Request, roleID string) (interface{}, error)
@ -533,8 +530,8 @@ func (s *HTTPHandlers) ACLRoleCRUD(resp http.ResponseWriter, req *http.Request)
} }
func (s *HTTPHandlers) ACLRoleReadByName(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLRoleReadByName(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
roleName := strings.TrimPrefix(req.URL.Path, "/v1/acl/role/name/") roleName := strings.TrimPrefix(req.URL.Path, "/v1/acl/role/name/")
@ -581,8 +578,8 @@ func (s *HTTPHandlers) ACLRoleRead(resp http.ResponseWriter, req *http.Request,
} }
func (s *HTTPHandlers) ACLRoleCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLRoleCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
return s.ACLRoleWrite(resp, req, "") return s.ACLRoleWrite(resp, req, "")
@ -634,8 +631,8 @@ func (s *HTTPHandlers) ACLRoleDelete(resp http.ResponseWriter, req *http.Request
} }
func (s *HTTPHandlers) ACLBindingRuleList(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLBindingRuleList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
var args structs.ACLBindingRuleListRequest var args structs.ACLBindingRuleListRequest
@ -668,8 +665,8 @@ func (s *HTTPHandlers) ACLBindingRuleList(resp http.ResponseWriter, req *http.Re
} }
func (s *HTTPHandlers) ACLBindingRuleCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLBindingRuleCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
var fn func(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error) var fn func(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error)
@ -728,8 +725,8 @@ func (s *HTTPHandlers) ACLBindingRuleRead(resp http.ResponseWriter, req *http.Re
} }
func (s *HTTPHandlers) ACLBindingRuleCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLBindingRuleCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
return s.ACLBindingRuleWrite(resp, req, "") return s.ACLBindingRuleWrite(resp, req, "")
@ -781,8 +778,8 @@ func (s *HTTPHandlers) ACLBindingRuleDelete(resp http.ResponseWriter, req *http.
} }
func (s *HTTPHandlers) ACLAuthMethodList(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLAuthMethodList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
var args structs.ACLAuthMethodListRequest var args structs.ACLAuthMethodListRequest
@ -812,8 +809,8 @@ func (s *HTTPHandlers) ACLAuthMethodList(resp http.ResponseWriter, req *http.Req
} }
func (s *HTTPHandlers) ACLAuthMethodCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLAuthMethodCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
var fn func(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error) var fn func(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error)
@ -872,8 +869,8 @@ func (s *HTTPHandlers) ACLAuthMethodRead(resp http.ResponseWriter, req *http.Req
} }
func (s *HTTPHandlers) ACLAuthMethodCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLAuthMethodCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
return s.ACLAuthMethodWrite(resp, req, "") return s.ACLAuthMethodWrite(resp, req, "")
@ -928,8 +925,8 @@ func (s *HTTPHandlers) ACLAuthMethodDelete(resp http.ResponseWriter, req *http.R
} }
func (s *HTTPHandlers) ACLLogin(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLLogin(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
args := &structs.ACLLoginRequest{ args := &structs.ACLLoginRequest{
@ -954,8 +951,8 @@ func (s *HTTPHandlers) ACLLogin(resp http.ResponseWriter, req *http.Request) (in
} }
func (s *HTTPHandlers) ACLLogout(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) ACLLogout(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
args := structs.ACLLogoutRequest{ args := structs.ACLLogoutRequest{
@ -1014,8 +1011,8 @@ func (s *HTTPHandlers) ACLAuthorize(resp http.ResponseWriter, req *http.Request)
// policy. // policy.
const maxRequests = 64 const maxRequests = 64
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, aclDisabled
} }
request := structs.RemoteACLAuthorizationRequest{ request := structs.RemoteACLAuthorizationRequest{

View File

@ -70,10 +70,8 @@ func TestACL_Disabled_Response(t *testing.T) {
req, _ := http.NewRequest("PUT", "/should/not/care", nil) req, _ := http.NewRequest("PUT", "/should/not/care", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := tt.fn(resp, req) obj, err := tt.fn(resp, req)
require.NoError(t, err)
require.Nil(t, obj) require.Nil(t, obj)
require.Equal(t, http.StatusUnauthorized, resp.Code) require.ErrorIs(t, err, UnauthorizedError{Reason: "ACL support disabled"})
require.Contains(t, resp.Body.String(), "ACL support disabled")
}) })
} }
} }
@ -119,9 +117,6 @@ func TestACL_Bootstrap(t *testing.T) {
if tt.token && err != nil { if tt.token && err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if got, want := resp.Code, tt.code; got != want {
t.Fatalf("got %d want %d", got, want)
}
if tt.token { if tt.token {
wrap, ok := out.(*aclBootstrapResponse) wrap, ok := out.(*aclBootstrapResponse)
if !ok { if !ok {

View File

@ -155,9 +155,11 @@ func (s *HTTPHandlers) AgentMetrics(resp http.ResponseWriter, req *http.Request)
} }
if enablePrometheusOutput(req) { if enablePrometheusOutput(req) {
if s.agent.config.Telemetry.PrometheusOpts.Expiration < 1 { if s.agent.config.Telemetry.PrometheusOpts.Expiration < 1 {
resp.WriteHeader(http.StatusUnsupportedMediaType) return nil, CodeWithPayloadError{
fmt.Fprint(resp, "Prometheus is not enabled since its retention time is not positive") StatusCode: http.StatusUnsupportedMediaType,
return nil, nil Reason: "Prometheus is not enabled since its retention time is not positive",
ContentType: "text/plain",
}
} }
handlerOptions := promhttp.HandlerOpts{ handlerOptions := promhttp.HandlerOpts{
ErrorLog: s.agent.logger.StandardLogger(&hclog.StandardLoggerOptions{ ErrorLog: s.agent.logger.StandardLogger(&hclog.StandardLoggerOptions{
@ -423,11 +425,7 @@ func (s *HTTPHandlers) AgentService(resp http.ResponseWriter, req *http.Request)
svcState := s.agent.State.ServiceState(sid) svcState := s.agent.State.ServiceState(sid)
if svcState == nil { if svcState == nil {
resp.WriteHeader(http.StatusNotFound) return "", nil, NotFoundError{Reason: fmt.Sprintf("unknown service ID: %s", sid.String())}
fmt.Fprintf(resp,
"Unknown service ID %q. Ensure that the service ID is passed, not the service name.",
sid.String())
return "", nil, nil
} }
svc := svcState.Service svc := svcState.Service
@ -557,9 +555,7 @@ func (s *HTTPHandlers) AgentMembers(resp http.ResponseWriter, req *http.Request)
// key are ok, otherwise the argument doesn't apply to // key are ok, otherwise the argument doesn't apply to
// the WAN. // the WAN.
default: default:
resp.WriteHeader(http.StatusBadRequest) return nil, BadRequestError{Reason: "Cannot provide a segment with wan=true"}
fmt.Fprint(resp, "Cannot provide a segment with wan=true")
return nil, nil
} }
} }
@ -735,16 +731,16 @@ func (s *HTTPHandlers) AgentRegisterCheck(resp http.ResponseWriter, req *http.Re
} }
if err := decodeBody(req.Body, &args); err != nil { if err := decodeBody(req.Body, &args); err != nil {
return nil, BadRequestError{fmt.Sprintf("Request decode failed: %v", err)} return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
} }
// Verify the check has a name. // Verify the check has a name.
if args.Name == "" { if args.Name == "" {
return nil, BadRequestError{"Missing check name"} return nil, BadRequestError{Reason: "Missing check name"}
} }
if args.Status != "" && !structs.ValidStatus(args.Status) { if args.Status != "" && !structs.ValidStatus(args.Status) {
return nil, BadRequestError{"Bad check status"} return nil, BadRequestError{Reason: "Bad check status"}
} }
authz, err := s.agent.delegate.ResolveTokenAndDefaultMeta(token, &args.EnterpriseMeta, nil) authz, err := s.agent.delegate.ResolveTokenAndDefaultMeta(token, &args.EnterpriseMeta, nil)
@ -763,15 +759,15 @@ func (s *HTTPHandlers) AgentRegisterCheck(resp http.ResponseWriter, req *http.Re
chkType := args.CheckType() chkType := args.CheckType()
err = chkType.Validate() err = chkType.Validate()
if err != nil { if err != nil {
return nil, BadRequestError{fmt.Sprintf("Invalid check: %v", err)} return nil, BadRequestError{Reason: fmt.Sprintf("Invalid check: %v", err)}
} }
// Store the type of check based on the definition // Store the type of check based on the definition
health.Type = chkType.Type() health.Type = chkType.Type()
if health.ServiceID != "" { if health.ServiceID != "" {
cid := health.CompoundServiceID()
// fixup the service name so that vetCheckRegister requires the right ACLs // fixup the service name so that vetCheckRegister requires the right ACLs
cid := health.CompoundServiceID()
service := s.agent.State.Service(cid) service := s.agent.State.Service(cid)
if service != nil { if service != nil {
health.ServiceName = service.Service health.ServiceName = service.Service
@ -881,9 +877,7 @@ type checkUpdate struct {
func (s *HTTPHandlers) AgentCheckUpdate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) AgentCheckUpdate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
var update checkUpdate var update checkUpdate
if err := decodeBody(req.Body, &update); err != nil { if err := decodeBody(req.Body, &update); err != nil {
resp.WriteHeader(http.StatusBadRequest) return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
} }
switch update.Status { switch update.Status {
@ -891,9 +885,7 @@ func (s *HTTPHandlers) AgentCheckUpdate(resp http.ResponseWriter, req *http.Requ
case api.HealthWarning: case api.HealthWarning:
case api.HealthCritical: case api.HealthCritical:
default: default:
resp.WriteHeader(http.StatusBadRequest) return nil, BadRequestError{Reason: fmt.Sprintf("Invalid check status: '%s'", update.Status)}
fmt.Fprintf(resp, "Invalid check status: '%s'", update.Status)
return nil, nil
} }
ID, err := getPathSuffixUnescaped(req.URL.Path, "/v1/agent/check/update/") ID, err := getPathSuffixUnescaped(req.URL.Path, "/v1/agent/check/update/")
@ -1121,24 +1113,18 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http.
} }
if err := decodeBody(req.Body, &args); err != nil { if err := decodeBody(req.Body, &args); err != nil {
resp.WriteHeader(http.StatusBadRequest) return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
} }
// Verify the service has a name. // Verify the service has a name.
if args.Name == "" { if args.Name == "" {
resp.WriteHeader(http.StatusBadRequest) return nil, BadRequestError{Reason: "Missing service name"}
fmt.Fprint(resp, "Missing service name")
return nil, nil
} }
// Check the service address here and in the catalog RPC endpoint // Check the service address here and in the catalog RPC endpoint
// since service registration isn't synchronous. // since service registration isn't synchronous.
if ipaddr.IsAny(args.Address) { if ipaddr.IsAny(args.Address) {
resp.WriteHeader(http.StatusBadRequest) return nil, BadRequestError{Reason: "Invalid service address"}
fmt.Fprintf(resp, "Invalid service address")
return nil, nil
} }
var token string var token string
@ -1157,37 +1143,27 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http.
ns := args.NodeService() ns := args.NodeService()
if ns.Weights != nil { if ns.Weights != nil {
if err := structs.ValidateWeights(ns.Weights); err != nil { if err := structs.ValidateWeights(ns.Weights); err != nil {
resp.WriteHeader(http.StatusBadRequest) return nil, BadRequestError{Reason: fmt.Sprintf("Invalid Weights: %v", err)}
fmt.Fprint(resp, fmt.Errorf("Invalid Weights: %v", err))
return nil, nil
} }
} }
if err := structs.ValidateServiceMetadata(ns.Kind, ns.Meta, false); err != nil { if err := structs.ValidateServiceMetadata(ns.Kind, ns.Meta, false); err != nil {
resp.WriteHeader(http.StatusBadRequest) return nil, BadRequestError{Reason: fmt.Sprintf("Invalid Service Meta: %v", err)}
fmt.Fprint(resp, fmt.Errorf("Invalid Service Meta: %v", err))
return nil, nil
} }
// Run validation. This is the same validation that would happen on // Run validation. This is the same validation that would happen on
// the catalog endpoint so it helps ensure the sync will work properly. // the catalog endpoint so it helps ensure the sync will work properly.
if err := ns.Validate(); err != nil { if err := ns.Validate(); err != nil {
resp.WriteHeader(http.StatusBadRequest) return nil, BadRequestError{Reason: fmt.Sprintf("Validation failed: %v", err.Error())}
fmt.Fprint(resp, err.Error())
return nil, nil
} }
// Verify the check type. // Verify the check type.
chkTypes, err := args.CheckTypes() chkTypes, err := args.CheckTypes()
if err != nil { if err != nil {
resp.WriteHeader(http.StatusBadRequest) return nil, BadRequestError{Reason: fmt.Sprintf("Invalid check: %v", err)}
fmt.Fprint(resp, fmt.Errorf("Invalid check: %v", err))
return nil, nil
} }
for _, check := range chkTypes { for _, check := range chkTypes {
if check.Status != "" && !structs.ValidStatus(check.Status) { if check.Status != "" && !structs.ValidStatus(check.Status) {
resp.WriteHeader(http.StatusBadRequest) return nil, BadRequestError{Reason: "Status for checks must 'passing', 'warning', 'critical'"}
fmt.Fprint(resp, "Status for checks must 'passing', 'warning', 'critical'")
return nil, nil
} }
} }
@ -1221,9 +1197,7 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http.
} }
if sidecar != nil { if sidecar != nil {
if err := sidecar.Validate(); err != nil { if err := sidecar.Validate(); err != nil {
resp.WriteHeader(http.StatusBadRequest) return nil, BadRequestError{Reason: fmt.Sprintf("Failed Validation: %v", err.Error())}
fmt.Fprint(resp, err.Error())
return nil, nil
} }
// Make sure we are allowed to register the sidecar using the token // Make sure we are allowed to register the sidecar using the token
// specified (might be specific to sidecar or the same one as the overall // specified (might be specific to sidecar or the same one as the overall
@ -1324,25 +1298,19 @@ func (s *HTTPHandlers) AgentServiceMaintenance(resp http.ResponseWriter, req *ht
sid := structs.NewServiceID(serviceID, nil) sid := structs.NewServiceID(serviceID, nil)
if sid.ID == "" { if sid.ID == "" {
resp.WriteHeader(http.StatusBadRequest) return nil, BadRequestError{Reason: "Missing service ID"}
fmt.Fprint(resp, "Missing service ID")
return nil, nil
} }
// Ensure we have some action // Ensure we have some action
params := req.URL.Query() params := req.URL.Query()
if _, ok := params["enable"]; !ok { if _, ok := params["enable"]; !ok {
resp.WriteHeader(http.StatusBadRequest) return nil, BadRequestError{Reason: "Missing value for enable"}
fmt.Fprint(resp, "Missing value for enable")
return nil, nil
} }
raw := params.Get("enable") raw := params.Get("enable")
enable, err := strconv.ParseBool(raw) enable, err := strconv.ParseBool(raw)
if err != nil { if err != nil {
resp.WriteHeader(http.StatusBadRequest) return nil, BadRequestError{Reason: fmt.Sprintf("Invalid value for enable: %q", raw)}
fmt.Fprintf(resp, "Invalid value for enable: %q", raw)
return nil, nil
} }
// Get the provided token, if any, and vet against any ACL policies. // Get the provided token, if any, and vet against any ACL policies.
@ -1371,15 +1339,11 @@ func (s *HTTPHandlers) AgentServiceMaintenance(resp http.ResponseWriter, req *ht
if enable { if enable {
reason := params.Get("reason") reason := params.Get("reason")
if err = s.agent.EnableServiceMaintenance(sid, reason, token); err != nil { if err = s.agent.EnableServiceMaintenance(sid, reason, token); err != nil {
resp.WriteHeader(http.StatusNotFound) return nil, NotFoundError{Reason: err.Error()}
fmt.Fprint(resp, err.Error())
return nil, nil
} }
} else { } else {
if err = s.agent.DisableServiceMaintenance(sid); err != nil { if err = s.agent.DisableServiceMaintenance(sid); err != nil {
resp.WriteHeader(http.StatusNotFound) return nil, NotFoundError{Reason: err.Error()}
fmt.Fprint(resp, err.Error())
return nil, nil
} }
} }
s.syncChanges() s.syncChanges()
@ -1390,17 +1354,13 @@ func (s *HTTPHandlers) AgentNodeMaintenance(resp http.ResponseWriter, req *http.
// Ensure we have some action // Ensure we have some action
params := req.URL.Query() params := req.URL.Query()
if _, ok := params["enable"]; !ok { if _, ok := params["enable"]; !ok {
resp.WriteHeader(http.StatusBadRequest) return nil, BadRequestError{Reason: "Missing value for enable"}
fmt.Fprint(resp, "Missing value for enable")
return nil, nil
} }
raw := params.Get("enable") raw := params.Get("enable")
enable, err := strconv.ParseBool(raw) enable, err := strconv.ParseBool(raw)
if err != nil { if err != nil {
resp.WriteHeader(http.StatusBadRequest) return nil, BadRequestError{Reason: fmt.Sprintf("Invalid value for enable: %q", raw)}
fmt.Fprintf(resp, "Invalid value for enable: %q", raw)
return nil, nil
} }
// Get the provided token, if any, and vet against any ACL policies. // Get the provided token, if any, and vet against any ACL policies.
@ -1507,8 +1467,8 @@ func (s *HTTPHandlers) AgentMonitor(resp http.ResponseWriter, req *http.Request)
} }
func (s *HTTPHandlers) AgentToken(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) AgentToken(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) { if s.checkACLDisabled() {
return nil, nil return nil, UnauthorizedError{Reason: "ACL support disabled"}
} }
// Fetch the ACL token, if any, and enforce agent policy. // Fetch the ACL token, if any, and enforce agent policy.

View File

@ -662,7 +662,7 @@ func TestAgent_Service(t *testing.T) {
{ {
name: "err: non-existent proxy", name: "err: non-existent proxy",
url: "/v1/agent/service/nope", url: "/v1/agent/service/nope",
wantCode: 404, wantErr: "unknown service ID: nope",
}, },
{ {
name: "err: bad ACL for service", name: "err: bad ACL for service",
@ -3784,9 +3784,6 @@ func testAgent_RegisterService_InvalidAddress(t *testing.T, extraHCL string) {
if got, want := resp.Code, 400; got != want { if got, want := resp.Code, 400; got != want {
t.Fatalf("got code %d want %d", got, want) t.Fatalf("got code %d want %d", got, want)
} }
if got, want := resp.Body.String(), "Invalid service address"; got != want {
t.Fatalf("got body %q want %q", got, want)
}
}) })
} }
} }

View File

@ -69,6 +69,15 @@ func (e NotFoundError) Error() string {
return e.Reason return e.Reason
} }
// UnauthorizedError should be returned by a handler when the request lacks valid authorization.
type UnauthorizedError struct {
Reason string
}
func (e UnauthorizedError) Error() string {
return e.Reason
}
// CodeWithPayloadError allow returning non HTTP 200 // CodeWithPayloadError allow returning non HTTP 200
// Error codes while not returning PlainText payload // Error codes while not returning PlainText payload
type CodeWithPayloadError struct { type CodeWithPayloadError struct {
@ -241,7 +250,8 @@ func (s *HTTPHandlers) handler(enableDebug bool) http.Handler {
// If enableDebug is not set, and ACLs are disabled, write // If enableDebug is not set, and ACLs are disabled, write
// an unauthorized response // an unauthorized response
if !enableDebug && s.checkACLDisabled(resp, req) { if !enableDebug && s.checkACLDisabled() {
resp.WriteHeader(http.StatusUnauthorized)
return return
} }
@ -423,6 +433,11 @@ func (s *HTTPHandlers) wrap(handler endpoint, methods []string) http.HandlerFunc
return ok return ok
} }
isUnauthorized := func(err error) bool {
_, ok := err.(UnauthorizedError)
return ok
}
isTooManyRequests := func(err error) bool { isTooManyRequests := func(err error) bool {
// Sadness net/rpc can't do nice typed errors so this is all we got // Sadness net/rpc can't do nice typed errors so this is all we got
return err.Error() == consul.ErrRateLimited.Error() return err.Error() == consul.ErrRateLimited.Error()
@ -467,6 +482,9 @@ func (s *HTTPHandlers) wrap(handler endpoint, methods []string) http.HandlerFunc
case isNotFound(err): case isNotFound(err):
resp.WriteHeader(http.StatusNotFound) resp.WriteHeader(http.StatusNotFound)
fmt.Fprint(resp, err.Error()) fmt.Fprint(resp, err.Error())
case isUnauthorized(err):
resp.WriteHeader(http.StatusUnauthorized)
fmt.Fprint(resp, err.Error())
case isTooManyRequests(err): case isTooManyRequests(err):
resp.WriteHeader(http.StatusTooManyRequests) resp.WriteHeader(http.StatusTooManyRequests)
fmt.Fprint(resp, err.Error()) fmt.Fprint(resp, err.Error())