diff --git a/consul/fsm.go b/consul/fsm.go index 6bfb102973..41fc87acdc 100644 --- a/consul/fsm.go +++ b/consul/fsm.go @@ -318,29 +318,19 @@ func (c *consulFSM) Snapshot() (raft.FSMSnapshot, error) { return &consulSnapshot{c.state.Snapshot()}, nil } +// Restore streams in the snapshot and replaces the current state store with a +// new one based on the snapshot if all goes OK during the restore. func (c *consulFSM) Restore(old io.ReadCloser) error { defer old.Close() - // Create a new state store + // Create a new state store. stateNew, err := state.NewStateStore(c.gc) if err != nil { return err } - // External code might be calling State(), so we need to synchronize - // here to make sure we swap in the new state store atomically. - c.stateLock.Lock() - stateOld := c.state - c.state = stateNew - c.stateLock.Unlock() - - // The old state store has been abandoned already since we've replaced - // it with an empty one, but we defer telling watchers about it until - // the restore is done, so they wake up once we have the latest data. - defer stateOld.Abandon() - // Set up a new restore transaction - restore := c.state.Restore() + restore := stateNew.Restore() defer restore.Abort() // Create a decoder @@ -443,6 +433,18 @@ func (c *consulFSM) Restore(old io.ReadCloser) error { } restore.Commit() + + // External code might be calling State(), so we need to synchronize + // here to make sure we swap in the new state store atomically. + c.stateLock.Lock() + stateOld := c.state + c.state = stateNew + c.stateLock.Unlock() + + // Signal that the old state store has been abandoned. This is required + // because we don't operate on it any more, we just throw it away, so + // blocking queries won't see any changes and need to be woken up. + stateOld.Abandon() return nil } diff --git a/consul/fsm_test.go b/consul/fsm_test.go index 46608aa075..fa0d3d1f8f 100644 --- a/consul/fsm_test.go +++ b/consul/fsm_test.go @@ -592,37 +592,41 @@ func TestFSM_SnapshotRestore(t *testing.T) { } -func TestFSM_KVSSet(t *testing.T) { +func TestFSM_BadRestore(t *testing.T) { + // Create an FSM with some state. fsm, err := NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } + fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) + abandonCh := fsm.state.AbandonCh() - req := structs.KVSRequest{ - Datacenter: "dc1", - Op: structs.KVSSet, - DirEnt: structs.DirEntry{ - Key: "/test/path", - Flags: 0, - Value: []byte("test"), - }, - } - buf, err := structs.Encode(structs.KVSRequestType, req) - if err != nil { + // Do a bad restore. + buf := bytes.NewBuffer([]byte("bad snapshot")) + sink := &MockSink{buf, false} + if err := fsm.Restore(sink); err == nil { t.Fatalf("err: %v", err) } - resp := fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) + + // Verify the contents didn't get corrupted. + _, nodes, err := fsm.state.Nodes(nil) + if err != nil { + t.Fatalf("err: %s", err) + } + if len(nodes) != 1 { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].Node != "foo" || + nodes[0].Address != "127.0.0.1" || + len(nodes[0].TaggedAddresses) != 0 { + t.Fatalf("bad: %v", nodes[0]) } - // Verify key is set - _, d, err := fsm.state.KVSGet(nil, "/test/path") - if err != nil { - t.Fatalf("err: %v", err) - } - if d == nil { - t.Fatalf("missing") + // Verify the old state store didn't get abandoned. + select { + case <-abandonCh: + t.Fatalf("bad") + default: } } diff --git a/consul/rpc.go b/consul/rpc.go index cf7b558902..43ff3b4a48 100644 --- a/consul/rpc.go +++ b/consul/rpc.go @@ -417,7 +417,16 @@ RUN_QUERY: err := fn(ws, state) if err == nil && queryMeta.Index > 0 && queryMeta.Index <= queryOpts.MinQueryIndex { if expired := ws.Watch(timeout.C); !expired { - goto RUN_QUERY + // If a restore may have woken us up then bail out from + // the query immediately. This is slightly race-ey since + // this might have been interrupted for other reasons, + // but it's OK to kick it back to the caller in either + // case. + select { + case <-state.AbandonCh(): + default: + goto RUN_QUERY + } } } return err diff --git a/consul/rpc_test.go b/consul/rpc_test.go index e77f6ea10f..8b8d99f525 100644 --- a/consul/rpc_test.go +++ b/consul/rpc_test.go @@ -1,12 +1,15 @@ package consul import ( + "bytes" "os" "testing" "time" + "github.com/hashicorp/consul/consul/state" "github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/consul/testutil" + "github.com/hashicorp/go-memdb" "github.com/hashicorp/net-rpc-msgpackrpc" ) @@ -70,3 +73,93 @@ func TestRPC_NoLeader_Retry(t *testing.T) { t.Fatalf("bad: %v", err) } } + +func TestRPC_blockingQuery(t *testing.T) { + dir, s := testServer(t) + defer os.RemoveAll(dir) + defer s.Shutdown() + + // Perform a non-blocking query. + { + var opts structs.QueryOptions + var meta structs.QueryMeta + var calls int + fn := func(ws memdb.WatchSet, state *state.StateStore) error { + calls++ + return nil + } + if err := s.blockingQuery(&opts, &meta, fn); err != nil { + t.Fatalf("err: %v", err) + } + if calls != 1 { + t.Fatalf("bad: %d", calls) + } + } + + // Perform a blocking query that gets woken up and loops around once. + { + opts := structs.QueryOptions{ + MinQueryIndex: 3, + } + var meta structs.QueryMeta + var calls int + fn := func(ws memdb.WatchSet, state *state.StateStore) error { + if calls == 0 { + meta.Index = 3 + + fakeCh := make(chan struct{}) + close(fakeCh) + ws.Add(fakeCh) + } else { + meta.Index = 4 + } + calls++ + return nil + } + if err := s.blockingQuery(&opts, &meta, fn); err != nil { + t.Fatalf("err: %v", err) + } + if calls != 2 { + t.Fatalf("bad: %d", calls) + } + } + + // Perform a query that blocks and gets interrupted when the state store + // is abandoned. + { + opts := structs.QueryOptions{ + MinQueryIndex: 3, + } + var meta structs.QueryMeta + var calls int + fn := func(ws memdb.WatchSet, state *state.StateStore) error { + if calls == 0 { + meta.Index = 3 + + snap, err := s.fsm.Snapshot() + if err != nil { + t.Fatalf("err: %v", err) + } + defer snap.Release() + + buf := bytes.NewBuffer(nil) + sink := &MockSink{buf, false} + if err := snap.Persist(sink); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.fsm.Restore(sink); err != nil { + t.Fatalf("err: %v", err) + } + } + calls++ + return nil + } + if err := s.blockingQuery(&opts, &meta, fn); err != nil { + t.Fatalf("err: %v", err) + } + if calls != 1 { + t.Fatalf("bad: %d", calls) + } + } +}