diff --git a/command/agent/agent_endpoint.go b/command/agent/agent_endpoint.go index 3e6426f803..6757179371 100644 --- a/command/agent/agent_endpoint.go +++ b/command/agent/agent_endpoint.go @@ -178,6 +178,12 @@ func (s *HTTPServer) AgentDeregisterService(resp http.ResponseWriter, req *http. } func (s *HTTPServer) AgentServiceMaintenance(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + // Only PUT supported + if req.Method != "PUT" { + resp.WriteHeader(405) + return nil, nil + } + // Ensure we have a service ID serviceID := strings.TrimPrefix(req.URL.Path, "/v1/agent/service/maintenance/") if serviceID == "" { @@ -207,9 +213,17 @@ func (s *HTTPServer) AgentServiceMaintenance(resp http.ResponseWriter, req *http return nil, nil } + var err error if enable { - return nil, s.agent.EnableServiceMaintenance(serviceID) + if err = s.agent.EnableServiceMaintenance(serviceID); err != nil { + resp.WriteHeader(409) + resp.Write([]byte(err.Error())) + } } else { - return nil, s.agent.DisableServiceMaintenance(serviceID) + if err = s.agent.DisableServiceMaintenance(serviceID); err != nil { + resp.WriteHeader(409) + resp.Write([]byte(err.Error())) + } } + return nil, err } diff --git a/command/agent/agent_endpoint_test.go b/command/agent/agent_endpoint_test.go index 1b266347e7..8881d18c31 100644 --- a/command/agent/agent_endpoint_test.go +++ b/command/agent/agent_endpoint_test.go @@ -7,6 +7,7 @@ import ( "github.com/hashicorp/consul/testutil" "github.com/hashicorp/serf/serf" "net/http" + "net/http/httptest" "os" "testing" "time" @@ -492,3 +493,132 @@ func TestHTTPAgentDeregisterService(t *testing.T) { t.Fatalf("have test check") } } + +func TestHTTPAgent_ServiceMaintenanceEndpoint_BadRequest(t *testing.T) { + dir, srv := makeHTTPServer(t) + defer os.RemoveAll(dir) + defer srv.Shutdown() + defer srv.agent.Shutdown() + + // Fails on non-PUT + req, _ := http.NewRequest("GET", "/v1/agent/service/maintenance/test?enable=true", nil) + resp := httptest.NewRecorder() + if _, err := srv.AgentServiceMaintenance(resp, req); err != nil { + t.Fatalf("err: %s", err) + } + if resp.Code != 405 { + t.Fatalf("expected 405 for non-PUT request") + } + + // Fails when no enable flag provided + req, _ = http.NewRequest("PUT", "/v1/agent/service/maintenance/test", nil) + resp = httptest.NewRecorder() + if _, err := srv.AgentServiceMaintenance(resp, req); err != nil { + t.Fatalf("err: %s", err) + } + if resp.Code != 400 { + t.Fatalf("expected 400 for missing enable flag") + } + + // Fails when no service ID provided + req, _ = http.NewRequest("PUT", "/v1/agent/service/maintenance/?enable=true", nil) + resp = httptest.NewRecorder() + if _, err := srv.AgentServiceMaintenance(resp, req); err != nil { + t.Fatalf("err: %s", err) + } + if resp.Code != 400 { + t.Fatalf("expected 400 for missing service ID") + } +} + +func TestHTTPAgent_EnableServiceMaintenance(t *testing.T) { + dir, srv := makeHTTPServer(t) + defer os.RemoveAll(dir) + defer srv.Shutdown() + defer srv.agent.Shutdown() + + // Register the service + service := &structs.NodeService{ + ID: "test", + Service: "test", + } + if err := srv.agent.AddService(service, nil, false); err != nil { + t.Fatalf("err: %v", err) + } + + // Force into maintenance mode + if err := srv.agent.EnableServiceMaintenance("test"); err != nil { + t.Fatalf("err: %s", err) + } + + // Fails when service is already in maintenance mode + req, _ := http.NewRequest("PUT", "/v1/agent/service/maintenance/test?enable=true", nil) + resp := httptest.NewRecorder() + if _, err := srv.AgentServiceMaintenance(resp, req); err == nil { + t.Fatalf("should have errored") + } + if resp.Code != 409 { + t.Fatalf("expected 409, got %d", resp.Code) + } + + // Remove maintenance mode + if err := srv.agent.DisableServiceMaintenance("test"); err != nil { + t.Fatalf("err: %s", err) + } + + // Force the service into maintenance mode + req, _ = http.NewRequest("PUT", "/v1/agent/service/maintenance/test?enable=true", nil) + resp = httptest.NewRecorder() + if _, err := srv.AgentServiceMaintenance(resp, req); err != nil { + t.Fatalf("err: %s", err) + } + if resp.Code != 200 { + t.Fatalf("expected 200, got %d", resp.Code) + } + + // Ensure the maintenance check was registered + if _, ok := srv.agent.state.Checks()[maintCheckID]; !ok { + t.Fatalf("should have registered maintenance check") + } +} + +func TestHTTPAgent_DisableServiceMaintenance(t *testing.T) { + dir, srv := makeHTTPServer(t) + defer os.RemoveAll(dir) + defer srv.Shutdown() + defer srv.agent.Shutdown() + + // Register the service + service := &structs.NodeService{ + ID: "test", + Service: "test", + } + if err := srv.agent.AddService(service, nil, false); err != nil { + t.Fatalf("err: %v", err) + } + + // Fails when the service is not in maintenance mode + req, _ := http.NewRequest("PUT", "/v1/agent/service/maintenance/test?enable=false", nil) + resp := httptest.NewRecorder() + if _, err := srv.AgentServiceMaintenance(resp, req); err == nil { + t.Fatalf("should have failed") + } + if resp.Code != 409 { + t.Fatalf("expected 409, got %d", resp.Code) + } + + // Force the service into maintenance mode + if err := srv.agent.EnableServiceMaintenance("test"); err != nil { + t.Fatalf("err: %s", err) + } + + // Leave maintenance mode + req, _ = http.NewRequest("PUT", "/v1/agent/service/maintenance/test?enable=false", nil) + resp = httptest.NewRecorder() + if _, err := srv.AgentServiceMaintenance(resp, req); err != nil { + t.Fatalf("err: %s", err) + } + if resp.Code != 200 { + t.Fatalf("expected 200, got %d", resp.Code) + } +}