mirror of https://github.com/status-im/consul.git
Ensure consistency with error-handling across all handlers. (#11599)
This commit is contained in:
parent
6d0a73c0eb
commit
0fdd1318e9
|
@ -16,23 +16,22 @@ type aclBootstrapResponse struct {
|
|||
structs.ACLToken
|
||||
}
|
||||
|
||||
var aclDisabled = UnauthorizedError{Reason: "ACL support disabled"}
|
||||
|
||||
// checkACLDisabled will return a standard response if ACLs are disabled. This
|
||||
// returns true if they are disabled and we should not continue.
|
||||
func (s *HTTPHandlers) checkACLDisabled(resp http.ResponseWriter, _req *http.Request) bool {
|
||||
func (s *HTTPHandlers) checkACLDisabled() bool {
|
||||
if s.agent.config.ACLsEnabled {
|
||||
return false
|
||||
}
|
||||
|
||||
resp.WriteHeader(http.StatusUnauthorized)
|
||||
fmt.Fprint(resp, "ACL support disabled")
|
||||
return true
|
||||
}
|
||||
|
||||
// ACLBootstrap is used to perform a one-time ACL bootstrap operation on
|
||||
// a cluster to get the first management token.
|
||||
func (s *HTTPHandlers) ACLBootstrap(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
args := structs.DCSpecificRequest{
|
||||
|
@ -42,9 +41,7 @@ func (s *HTTPHandlers) ACLBootstrap(resp http.ResponseWriter, req *http.Request)
|
|||
err := s.agent.RPC("ACL.BootstrapTokens", &args, &out)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), structs.ACLBootstrapNotAllowedErr.Error()) {
|
||||
resp.WriteHeader(http.StatusForbidden)
|
||||
fmt.Fprint(resp, acl.PermissionDeniedError{Cause: err.Error()}.Error())
|
||||
return nil, nil
|
||||
return nil, acl.PermissionDeniedError{Cause: err.Error()}
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -53,8 +50,8 @@ func (s *HTTPHandlers) ACLBootstrap(resp http.ResponseWriter, req *http.Request)
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLReplicationStatus(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
// Note that we do not forward to the ACL DC here. This is a query for
|
||||
|
@ -74,8 +71,8 @@ func (s *HTTPHandlers) ACLReplicationStatus(resp http.ResponseWriter, req *http.
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLPolicyList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
var args structs.ACLPolicyListRequest
|
||||
|
@ -105,8 +102,8 @@ func (s *HTTPHandlers) ACLPolicyList(resp http.ResponseWriter, req *http.Request
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLPolicyCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
var fn func(resp http.ResponseWriter, req *http.Request, policyID string) (interface{}, error)
|
||||
|
@ -166,8 +163,8 @@ func (s *HTTPHandlers) ACLPolicyRead(resp http.ResponseWriter, req *http.Request
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLPolicyReadByName(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
policyName := strings.TrimPrefix(req.URL.Path, "/v1/acl/policy/name/")
|
||||
|
@ -183,8 +180,8 @@ func (s *HTTPHandlers) ACLPolicyReadByID(resp http.ResponseWriter, req *http.Req
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLPolicyCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
return s.aclPolicyWriteInternal(resp, req, "", true)
|
||||
|
@ -248,8 +245,8 @@ func (s *HTTPHandlers) ACLPolicyDelete(resp http.ResponseWriter, req *http.Reque
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLTokenList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
args := &structs.ACLTokenListRequest{
|
||||
|
@ -285,8 +282,8 @@ func (s *HTTPHandlers) ACLTokenList(resp http.ResponseWriter, req *http.Request)
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLTokenCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
var fn func(resp http.ResponseWriter, req *http.Request, tokenID string) (interface{}, error)
|
||||
|
@ -318,8 +315,8 @@ func (s *HTTPHandlers) ACLTokenCRUD(resp http.ResponseWriter, req *http.Request)
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLTokenSelf(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
args := structs.ACLTokenGetRequest{
|
||||
|
@ -351,8 +348,8 @@ func (s *HTTPHandlers) ACLTokenSelf(resp http.ResponseWriter, req *http.Request)
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLTokenCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
return s.aclTokenSetInternal(req, "", true)
|
||||
|
@ -442,8 +439,8 @@ func (s *HTTPHandlers) ACLTokenDelete(resp http.ResponseWriter, req *http.Reques
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLTokenClone(resp http.ResponseWriter, req *http.Request, tokenID string) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
args := structs.ACLTokenSetRequest{
|
||||
|
@ -471,8 +468,8 @@ func (s *HTTPHandlers) ACLTokenClone(resp http.ResponseWriter, req *http.Request
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLRoleList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
var args structs.ACLRoleListRequest
|
||||
|
@ -504,8 +501,8 @@ func (s *HTTPHandlers) ACLRoleList(resp http.ResponseWriter, req *http.Request)
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLRoleCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
var fn func(resp http.ResponseWriter, req *http.Request, roleID string) (interface{}, error)
|
||||
|
@ -533,8 +530,8 @@ func (s *HTTPHandlers) ACLRoleCRUD(resp http.ResponseWriter, req *http.Request)
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLRoleReadByName(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
roleName := strings.TrimPrefix(req.URL.Path, "/v1/acl/role/name/")
|
||||
|
@ -581,8 +578,8 @@ func (s *HTTPHandlers) ACLRoleRead(resp http.ResponseWriter, req *http.Request,
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLRoleCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
return s.ACLRoleWrite(resp, req, "")
|
||||
|
@ -634,8 +631,8 @@ func (s *HTTPHandlers) ACLRoleDelete(resp http.ResponseWriter, req *http.Request
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLBindingRuleList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
var args structs.ACLBindingRuleListRequest
|
||||
|
@ -668,8 +665,8 @@ func (s *HTTPHandlers) ACLBindingRuleList(resp http.ResponseWriter, req *http.Re
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLBindingRuleCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
var fn func(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error)
|
||||
|
@ -728,8 +725,8 @@ func (s *HTTPHandlers) ACLBindingRuleRead(resp http.ResponseWriter, req *http.Re
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLBindingRuleCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
return s.ACLBindingRuleWrite(resp, req, "")
|
||||
|
@ -781,8 +778,8 @@ func (s *HTTPHandlers) ACLBindingRuleDelete(resp http.ResponseWriter, req *http.
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLAuthMethodList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
var args structs.ACLAuthMethodListRequest
|
||||
|
@ -812,8 +809,8 @@ func (s *HTTPHandlers) ACLAuthMethodList(resp http.ResponseWriter, req *http.Req
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLAuthMethodCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
var fn func(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error)
|
||||
|
@ -872,8 +869,8 @@ func (s *HTTPHandlers) ACLAuthMethodRead(resp http.ResponseWriter, req *http.Req
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLAuthMethodCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
return s.ACLAuthMethodWrite(resp, req, "")
|
||||
|
@ -928,8 +925,8 @@ func (s *HTTPHandlers) ACLAuthMethodDelete(resp http.ResponseWriter, req *http.R
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLLogin(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
args := &structs.ACLLoginRequest{
|
||||
|
@ -954,8 +951,8 @@ func (s *HTTPHandlers) ACLLogin(resp http.ResponseWriter, req *http.Request) (in
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) ACLLogout(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
args := structs.ACLLogoutRequest{
|
||||
|
@ -1014,8 +1011,8 @@ func (s *HTTPHandlers) ACLAuthorize(resp http.ResponseWriter, req *http.Request)
|
|||
// policy.
|
||||
const maxRequests = 64
|
||||
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, aclDisabled
|
||||
}
|
||||
|
||||
request := structs.RemoteACLAuthorizationRequest{
|
||||
|
|
|
@ -70,10 +70,8 @@ func TestACL_Disabled_Response(t *testing.T) {
|
|||
req, _ := http.NewRequest("PUT", "/should/not/care", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := tt.fn(resp, req)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, obj)
|
||||
require.Equal(t, http.StatusUnauthorized, resp.Code)
|
||||
require.Contains(t, resp.Body.String(), "ACL support disabled")
|
||||
require.ErrorIs(t, err, UnauthorizedError{Reason: "ACL support disabled"})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -119,9 +117,6 @@ func TestACL_Bootstrap(t *testing.T) {
|
|||
if tt.token && err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if got, want := resp.Code, tt.code; got != want {
|
||||
t.Fatalf("got %d want %d", got, want)
|
||||
}
|
||||
if tt.token {
|
||||
wrap, ok := out.(*aclBootstrapResponse)
|
||||
if !ok {
|
||||
|
|
|
@ -155,9 +155,11 @@ func (s *HTTPHandlers) AgentMetrics(resp http.ResponseWriter, req *http.Request)
|
|||
}
|
||||
if enablePrometheusOutput(req) {
|
||||
if s.agent.config.Telemetry.PrometheusOpts.Expiration < 1 {
|
||||
resp.WriteHeader(http.StatusUnsupportedMediaType)
|
||||
fmt.Fprint(resp, "Prometheus is not enabled since its retention time is not positive")
|
||||
return nil, nil
|
||||
return nil, CodeWithPayloadError{
|
||||
StatusCode: http.StatusUnsupportedMediaType,
|
||||
Reason: "Prometheus is not enabled since its retention time is not positive",
|
||||
ContentType: "text/plain",
|
||||
}
|
||||
}
|
||||
handlerOptions := promhttp.HandlerOpts{
|
||||
ErrorLog: s.agent.logger.StandardLogger(&hclog.StandardLoggerOptions{
|
||||
|
@ -423,11 +425,7 @@ func (s *HTTPHandlers) AgentService(resp http.ResponseWriter, req *http.Request)
|
|||
|
||||
svcState := s.agent.State.ServiceState(sid)
|
||||
if svcState == nil {
|
||||
resp.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(resp,
|
||||
"Unknown service ID %q. Ensure that the service ID is passed, not the service name.",
|
||||
sid.String())
|
||||
return "", nil, nil
|
||||
return "", nil, NotFoundError{Reason: fmt.Sprintf("unknown service ID: %s", sid.String())}
|
||||
}
|
||||
|
||||
svc := svcState.Service
|
||||
|
@ -557,9 +555,7 @@ func (s *HTTPHandlers) AgentMembers(resp http.ResponseWriter, req *http.Request)
|
|||
// key are ok, otherwise the argument doesn't apply to
|
||||
// the WAN.
|
||||
default:
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(resp, "Cannot provide a segment with wan=true")
|
||||
return nil, nil
|
||||
return nil, BadRequestError{Reason: "Cannot provide a segment with wan=true"}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -735,16 +731,16 @@ func (s *HTTPHandlers) AgentRegisterCheck(resp http.ResponseWriter, req *http.Re
|
|||
}
|
||||
|
||||
if err := decodeBody(req.Body, &args); err != nil {
|
||||
return nil, BadRequestError{fmt.Sprintf("Request decode failed: %v", err)}
|
||||
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
|
||||
}
|
||||
|
||||
// Verify the check has a name.
|
||||
if args.Name == "" {
|
||||
return nil, BadRequestError{"Missing check name"}
|
||||
return nil, BadRequestError{Reason: "Missing check name"}
|
||||
}
|
||||
|
||||
if args.Status != "" && !structs.ValidStatus(args.Status) {
|
||||
return nil, BadRequestError{"Bad check status"}
|
||||
return nil, BadRequestError{Reason: "Bad check status"}
|
||||
}
|
||||
|
||||
authz, err := s.agent.delegate.ResolveTokenAndDefaultMeta(token, &args.EnterpriseMeta, nil)
|
||||
|
@ -763,15 +759,15 @@ func (s *HTTPHandlers) AgentRegisterCheck(resp http.ResponseWriter, req *http.Re
|
|||
chkType := args.CheckType()
|
||||
err = chkType.Validate()
|
||||
if err != nil {
|
||||
return nil, BadRequestError{fmt.Sprintf("Invalid check: %v", err)}
|
||||
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid check: %v", err)}
|
||||
}
|
||||
|
||||
// Store the type of check based on the definition
|
||||
health.Type = chkType.Type()
|
||||
|
||||
if health.ServiceID != "" {
|
||||
cid := health.CompoundServiceID()
|
||||
// fixup the service name so that vetCheckRegister requires the right ACLs
|
||||
cid := health.CompoundServiceID()
|
||||
service := s.agent.State.Service(cid)
|
||||
if service != nil {
|
||||
health.ServiceName = service.Service
|
||||
|
@ -881,9 +877,7 @@ type checkUpdate struct {
|
|||
func (s *HTTPHandlers) AgentCheckUpdate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
var update checkUpdate
|
||||
if err := decodeBody(req.Body, &update); err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(resp, "Request decode failed: %v", err)
|
||||
return nil, nil
|
||||
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
|
||||
}
|
||||
|
||||
switch update.Status {
|
||||
|
@ -891,9 +885,7 @@ func (s *HTTPHandlers) AgentCheckUpdate(resp http.ResponseWriter, req *http.Requ
|
|||
case api.HealthWarning:
|
||||
case api.HealthCritical:
|
||||
default:
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(resp, "Invalid check status: '%s'", update.Status)
|
||||
return nil, nil
|
||||
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid check status: '%s'", update.Status)}
|
||||
}
|
||||
|
||||
ID, err := getPathSuffixUnescaped(req.URL.Path, "/v1/agent/check/update/")
|
||||
|
@ -1121,24 +1113,18 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http.
|
|||
}
|
||||
|
||||
if err := decodeBody(req.Body, &args); err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(resp, "Request decode failed: %v", err)
|
||||
return nil, nil
|
||||
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
|
||||
}
|
||||
|
||||
// Verify the service has a name.
|
||||
if args.Name == "" {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(resp, "Missing service name")
|
||||
return nil, nil
|
||||
return nil, BadRequestError{Reason: "Missing service name"}
|
||||
}
|
||||
|
||||
// Check the service address here and in the catalog RPC endpoint
|
||||
// since service registration isn't synchronous.
|
||||
if ipaddr.IsAny(args.Address) {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(resp, "Invalid service address")
|
||||
return nil, nil
|
||||
return nil, BadRequestError{Reason: "Invalid service address"}
|
||||
}
|
||||
|
||||
var token string
|
||||
|
@ -1157,37 +1143,27 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http.
|
|||
ns := args.NodeService()
|
||||
if ns.Weights != nil {
|
||||
if err := structs.ValidateWeights(ns.Weights); err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(resp, fmt.Errorf("Invalid Weights: %v", err))
|
||||
return nil, nil
|
||||
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid Weights: %v", err)}
|
||||
}
|
||||
}
|
||||
if err := structs.ValidateServiceMetadata(ns.Kind, ns.Meta, false); err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(resp, fmt.Errorf("Invalid Service Meta: %v", err))
|
||||
return nil, nil
|
||||
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid Service Meta: %v", err)}
|
||||
}
|
||||
|
||||
// Run validation. This is the same validation that would happen on
|
||||
// the catalog endpoint so it helps ensure the sync will work properly.
|
||||
if err := ns.Validate(); err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(resp, err.Error())
|
||||
return nil, nil
|
||||
return nil, BadRequestError{Reason: fmt.Sprintf("Validation failed: %v", err.Error())}
|
||||
}
|
||||
|
||||
// Verify the check type.
|
||||
chkTypes, err := args.CheckTypes()
|
||||
if err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(resp, fmt.Errorf("Invalid check: %v", err))
|
||||
return nil, nil
|
||||
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid check: %v", err)}
|
||||
}
|
||||
for _, check := range chkTypes {
|
||||
if check.Status != "" && !structs.ValidStatus(check.Status) {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(resp, "Status for checks must 'passing', 'warning', 'critical'")
|
||||
return nil, nil
|
||||
return nil, BadRequestError{Reason: "Status for checks must 'passing', 'warning', 'critical'"}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1221,9 +1197,7 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http.
|
|||
}
|
||||
if sidecar != nil {
|
||||
if err := sidecar.Validate(); err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(resp, err.Error())
|
||||
return nil, nil
|
||||
return nil, BadRequestError{Reason: fmt.Sprintf("Failed Validation: %v", err.Error())}
|
||||
}
|
||||
// Make sure we are allowed to register the sidecar using the token
|
||||
// specified (might be specific to sidecar or the same one as the overall
|
||||
|
@ -1324,25 +1298,19 @@ func (s *HTTPHandlers) AgentServiceMaintenance(resp http.ResponseWriter, req *ht
|
|||
sid := structs.NewServiceID(serviceID, nil)
|
||||
|
||||
if sid.ID == "" {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(resp, "Missing service ID")
|
||||
return nil, nil
|
||||
return nil, BadRequestError{Reason: "Missing service ID"}
|
||||
}
|
||||
|
||||
// Ensure we have some action
|
||||
params := req.URL.Query()
|
||||
if _, ok := params["enable"]; !ok {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(resp, "Missing value for enable")
|
||||
return nil, nil
|
||||
return nil, BadRequestError{Reason: "Missing value for enable"}
|
||||
}
|
||||
|
||||
raw := params.Get("enable")
|
||||
enable, err := strconv.ParseBool(raw)
|
||||
if err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(resp, "Invalid value for enable: %q", raw)
|
||||
return nil, nil
|
||||
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid value for enable: %q", raw)}
|
||||
}
|
||||
|
||||
// Get the provided token, if any, and vet against any ACL policies.
|
||||
|
@ -1371,15 +1339,11 @@ func (s *HTTPHandlers) AgentServiceMaintenance(resp http.ResponseWriter, req *ht
|
|||
if enable {
|
||||
reason := params.Get("reason")
|
||||
if err = s.agent.EnableServiceMaintenance(sid, reason, token); err != nil {
|
||||
resp.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(resp, err.Error())
|
||||
return nil, nil
|
||||
return nil, NotFoundError{Reason: err.Error()}
|
||||
}
|
||||
} else {
|
||||
if err = s.agent.DisableServiceMaintenance(sid); err != nil {
|
||||
resp.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(resp, err.Error())
|
||||
return nil, nil
|
||||
return nil, NotFoundError{Reason: err.Error()}
|
||||
}
|
||||
}
|
||||
s.syncChanges()
|
||||
|
@ -1390,17 +1354,13 @@ func (s *HTTPHandlers) AgentNodeMaintenance(resp http.ResponseWriter, req *http.
|
|||
// Ensure we have some action
|
||||
params := req.URL.Query()
|
||||
if _, ok := params["enable"]; !ok {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(resp, "Missing value for enable")
|
||||
return nil, nil
|
||||
return nil, BadRequestError{Reason: "Missing value for enable"}
|
||||
}
|
||||
|
||||
raw := params.Get("enable")
|
||||
enable, err := strconv.ParseBool(raw)
|
||||
if err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(resp, "Invalid value for enable: %q", raw)
|
||||
return nil, nil
|
||||
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid value for enable: %q", raw)}
|
||||
}
|
||||
|
||||
// Get the provided token, if any, and vet against any ACL policies.
|
||||
|
@ -1507,8 +1467,8 @@ func (s *HTTPHandlers) AgentMonitor(resp http.ResponseWriter, req *http.Request)
|
|||
}
|
||||
|
||||
func (s *HTTPHandlers) AgentToken(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
if s.checkACLDisabled() {
|
||||
return nil, UnauthorizedError{Reason: "ACL support disabled"}
|
||||
}
|
||||
|
||||
// Fetch the ACL token, if any, and enforce agent policy.
|
||||
|
|
|
@ -660,9 +660,9 @@ func TestAgent_Service(t *testing.T) {
|
|||
wantResp: &updatedResponse,
|
||||
},
|
||||
{
|
||||
name: "err: non-existent proxy",
|
||||
url: "/v1/agent/service/nope",
|
||||
wantCode: 404,
|
||||
name: "err: non-existent proxy",
|
||||
url: "/v1/agent/service/nope",
|
||||
wantErr: "unknown service ID: nope",
|
||||
},
|
||||
{
|
||||
name: "err: bad ACL for service",
|
||||
|
@ -3784,9 +3784,6 @@ func testAgent_RegisterService_InvalidAddress(t *testing.T, extraHCL string) {
|
|||
if got, want := resp.Code, 400; got != want {
|
||||
t.Fatalf("got code %d want %d", got, want)
|
||||
}
|
||||
if got, want := resp.Body.String(), "Invalid service address"; got != want {
|
||||
t.Fatalf("got body %q want %q", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -69,6 +69,15 @@ func (e NotFoundError) Error() string {
|
|||
return e.Reason
|
||||
}
|
||||
|
||||
// UnauthorizedError should be returned by a handler when the request lacks valid authorization.
|
||||
type UnauthorizedError struct {
|
||||
Reason string
|
||||
}
|
||||
|
||||
func (e UnauthorizedError) Error() string {
|
||||
return e.Reason
|
||||
}
|
||||
|
||||
// CodeWithPayloadError allow returning non HTTP 200
|
||||
// Error codes while not returning PlainText payload
|
||||
type CodeWithPayloadError struct {
|
||||
|
@ -241,7 +250,8 @@ func (s *HTTPHandlers) handler(enableDebug bool) http.Handler {
|
|||
|
||||
// If enableDebug is not set, and ACLs are disabled, write
|
||||
// an unauthorized response
|
||||
if !enableDebug && s.checkACLDisabled(resp, req) {
|
||||
if !enableDebug && s.checkACLDisabled() {
|
||||
resp.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -423,6 +433,11 @@ func (s *HTTPHandlers) wrap(handler endpoint, methods []string) http.HandlerFunc
|
|||
return ok
|
||||
}
|
||||
|
||||
isUnauthorized := func(err error) bool {
|
||||
_, ok := err.(UnauthorizedError)
|
||||
return ok
|
||||
}
|
||||
|
||||
isTooManyRequests := func(err error) bool {
|
||||
// Sadness net/rpc can't do nice typed errors so this is all we got
|
||||
return err.Error() == consul.ErrRateLimited.Error()
|
||||
|
@ -467,6 +482,9 @@ func (s *HTTPHandlers) wrap(handler endpoint, methods []string) http.HandlerFunc
|
|||
case isNotFound(err):
|
||||
resp.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(resp, err.Error())
|
||||
case isUnauthorized(err):
|
||||
resp.WriteHeader(http.StatusUnauthorized)
|
||||
fmt.Fprint(resp, err.Error())
|
||||
case isTooManyRequests(err):
|
||||
resp.WriteHeader(http.StatusTooManyRequests)
|
||||
fmt.Fprint(resp, err.Error())
|
||||
|
|
Loading…
Reference in New Issue