diff --git a/command/agent/agent.go b/command/agent/agent.go index c2394b9a30..88e26488c5 100644 --- a/command/agent/agent.go +++ b/command/agent/agent.go @@ -1052,6 +1052,22 @@ func (a *Agent) unloadChecks() error { return nil } +// snapshotCheckState is used to snapshot the current state of the health +// checks. This is done before we reload our checks, so that we can properly +// restore into the same state. +func (a *Agent) snapshotCheckState() map[string]*structs.HealthCheck { + return a.state.Checks() +} + +// restoreCheckState is used to reset the health state based on a snapshot. +// This is done after we finish the reload to avoid any unnecessary flaps +// in health state and potential session invalidations. +func (a *Agent) restoreCheckState(snap map[string]*structs.HealthCheck) { + for id, check := range snap { + a.state.UpdateCheck(id, check.Status, check.Output) + } +} + // serviceMaintCheckID returns the ID of a given service's maintenance check func serviceMaintCheckID(serviceID string) string { return fmt.Sprintf("%s:%s", serviceMaintCheckPrefix, serviceID) diff --git a/command/agent/agent_test.go b/command/agent/agent_test.go index da098d6eb9..b718674811 100644 --- a/command/agent/agent_test.go +++ b/command/agent/agent_test.go @@ -1007,3 +1007,61 @@ func TestAgent_NodeMaintenanceMode(t *testing.T) { t.Fatalf("bad: %#v", check) } } + +func TestAgent_checkStateSnapshot(t *testing.T) { + config := nextConfig() + dir, agent := makeAgent(t, config) + defer os.RemoveAll(dir) + defer agent.Shutdown() + + // First register a service + svc := &structs.NodeService{ + ID: "redis", + Service: "redis", + Tags: []string{"foo"}, + Port: 8000, + } + if err := agent.AddService(svc, nil, false); err != nil { + t.Fatalf("err: %v", err) + } + + // Register a check + check1 := &structs.HealthCheck{ + Node: config.NodeName, + CheckID: "service:redis", + Name: "redischeck", + Status: structs.HealthPassing, + ServiceID: "redis", + ServiceName: "redis", + } + if err := agent.AddCheck(check1, nil, true); err != nil { + t.Fatalf("err: %s", err) + } + + // Snapshot the state + snap := agent.snapshotCheckState() + + // Unload all of the checks + if err := agent.unloadChecks(); err != nil { + t.Fatalf("err: %s", err) + } + + // Reload the checks + if err := agent.loadChecks(config); err != nil { + t.Fatalf("err: %s", err) + } + + // Restore the state + agent.restoreCheckState(snap) + + // Search for the check + out, ok := agent.state.Checks()[check1.CheckID] + if !ok { + t.Fatalf("check should have been registered") + } + + // Make sure state was restored + if out.Status != structs.HealthPassing { + t.Fatalf("should have restored check state") + } +} diff --git a/command/agent/command.go b/command/agent/command.go index 6c16dd6d31..e0ebfbfc7f 100644 --- a/command/agent/command.go +++ b/command/agent/command.go @@ -751,6 +751,10 @@ func (c *Command) handleReload(config *Config) *Config { c.agent.PauseSync() defer c.agent.ResumeSync() + // Snapshot the current state, and restore it afterwards + snap := c.agent.snapshotCheckState() + defer c.agent.restoreCheckState(snap) + // First unload all checks and services. This lets us begin the reload // with a clean slate. if err := c.agent.unloadServices(); err != nil {