diff --git a/consul/state/state_store.go b/consul/state/state_store.go index d8b4cc45bc..7e4e5fe53a 100644 --- a/consul/state/state_store.go +++ b/consul/state/state_store.go @@ -466,6 +466,21 @@ func (s *StateStore) ServiceChecks(serviceName string) (uint64, structs.HealthCh return s.parseChecks(tx.Get("checks", "service", serviceName)) } +// ChecksInState is used to query the state store for all checks +// which are in the provided state. +func (s *StateStore) ChecksInState(state string) (uint64, structs.HealthChecks, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Query all checks if HealthAny is passed + if state == structs.HealthAny { + return s.parseChecks(tx.Get("checks", "status")) + } + + // Any other state we need to query for explicitly + return s.parseChecks(tx.Get("checks", "status", state)) +} + // parseChecks is a helper function used to deduplicate some // repetitive code for returning health checks. func (s *StateStore) parseChecks(iter memdb.ResultIterator, err error) (uint64, structs.HealthChecks, error) { diff --git a/consul/state/state_store_test.go b/consul/state/state_store_test.go index 390b5a1df2..239cd31b28 100644 --- a/consul/state/state_store_test.go +++ b/consul/state/state_store_test.go @@ -59,11 +59,13 @@ func testRegisterService(t *testing.T, s *StateStore, idx uint64, nodeID, servic } } -func testRegisterCheck(t *testing.T, s *StateStore, idx uint64, nodeID, serviceID, checkID string) { +func testRegisterCheck(t *testing.T, s *StateStore, idx uint64, + nodeID, serviceID, checkID, state string) { chk := &structs.HealthCheck{ Node: nodeID, CheckID: checkID, ServiceID: serviceID, + Status: state, } if err := s.EnsureCheck(idx, chk); err != nil { t.Fatalf("err: %s", err) @@ -192,7 +194,7 @@ func TestStateStore_DeleteNode(t *testing.T) { // Create a node and register a service and health check with it. testRegisterNode(t, s, 0, "node1") testRegisterService(t, s, 1, "node1", "service1") - testRegisterCheck(t, s, 2, "node1", "", "check1") + testRegisterCheck(t, s, 2, "node1", "", "check1", structs.HealthPassing) // Delete the node if err := s.DeleteNode(3, "node1"); err != nil { @@ -316,7 +318,7 @@ func TestStateStore_DeleteService(t *testing.T) { // Register a node with one service and a check testRegisterNode(t, s, 1, "node1") testRegisterService(t, s, 2, "node1", "service1") - testRegisterCheck(t, s, 3, "node1", "service1", "check1") + testRegisterCheck(t, s, 3, "node1", "service1", "check1", structs.HealthPassing) // Delete the service if err := s.DeleteService(4, "node1", "service1"); err != nil { @@ -434,13 +436,13 @@ func TestStateStore_ServiceChecks(t *testing.T) { // Create the first node and service with some checks testRegisterNode(t, s, 0, "node1") testRegisterService(t, s, 1, "node1", "service1") - testRegisterCheck(t, s, 2, "node1", "service1", "check1") - testRegisterCheck(t, s, 3, "node1", "service1", "check2") + testRegisterCheck(t, s, 2, "node1", "service1", "check1", structs.HealthPassing) + testRegisterCheck(t, s, 3, "node1", "service1", "check2", structs.HealthPassing) // Create a second node/service with a different set of checks testRegisterNode(t, s, 4, "node2") testRegisterService(t, s, 5, "node2", "service2") - testRegisterCheck(t, s, 6, "node2", "service2", "check3") + testRegisterCheck(t, s, 6, "node2", "service2", "check3", structs.HealthPassing) // Try querying for all checks associated with service1 idx, checks, err := s.ServiceChecks("service1") @@ -460,7 +462,7 @@ func TestStateStore_DeleteCheck(t *testing.T) { // Register a node and a node-level health check testRegisterNode(t, s, 1, "node1") - testRegisterCheck(t, s, 2, "node1", "", "check1") + testRegisterCheck(t, s, 2, "node1", "", "check1", structs.HealthPassing) // Delete the check if err := s.DeleteCheck(3, "node1", "check1"); err != nil { @@ -481,3 +483,36 @@ func TestStateStore_DeleteCheck(t *testing.T) { t.Fatalf("bad index: %d", idx) } } + +func TestStateStore_ChecksInState(t *testing.T) { + s := testStateStore(t) + + // Register a node with checks in varied states + testRegisterNode(t, s, 0, "node1") + testRegisterCheck(t, s, 1, "node1", "", "check1", structs.HealthPassing) + testRegisterCheck(t, s, 2, "node1", "", "check2", structs.HealthCritical) + testRegisterCheck(t, s, 3, "node1", "", "check3", structs.HealthPassing) + + // Query the state store for passing checks. + _, results, err := s.ChecksInState(structs.HealthPassing) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Make sure we only get the checks which match the state + if n := len(results); n != 2 { + t.Fatalf("expected 2 checks, got: %d", n) + } + if results[0].CheckID != "check1" || results[1].CheckID != "check3" { + t.Fatalf("bad: %#v", results) + } + + // HealthAny just returns everything. + _, results, err = s.ChecksInState(structs.HealthAny) + if err != nil { + t.Fatalf("err: %s", err) + } + if n := len(results); n != 3 { + t.Fatalf("expected 3 checks, got: %d", n) + } +}