diff --git a/consul/fsm.go b/consul/fsm.go index ddd6335fb0..a89dca2452 100644 --- a/consul/fsm.go +++ b/consul/fsm.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/hashicorp/consul/rpc" "github.com/hashicorp/raft" + "github.com/ugorji/go/codec" "io" ) @@ -18,7 +19,7 @@ type consulFSM struct { // state in a way that can be accessed concurrently with operations // that may modify the live state. type consulSnapshot struct { - fsm *consulFSM + state *StateStore } // NewFSM is used to construct a new FSM with a blank state @@ -82,8 +83,12 @@ func (c *consulFSM) applyDeregister(buf []byte) interface{} { } func (c *consulFSM) Snapshot() (raft.FSMSnapshot, error) { - snap := &consulSnapshot{fsm: c} - return snap, nil + // Create a new snapshot + snap, err := c.state.Snapshot() + if err != nil { + return nil, err + } + return &consulSnapshot{snap}, nil } func (c *consulFSM) Restore(old io.ReadCloser) error { @@ -95,7 +100,41 @@ func (c *consulFSM) Restore(old io.ReadCloser) error { return err } - // TODO: Populate the new state + // Create a decoder + var handle codec.MsgpackHandle + dec := codec.NewDecoder(old, &handle) + + // Populate the new state + msgType := make([]byte, 1) + for { + // Read the message type + _, err := old.Read(msgType) + if err == io.EOF { + break + } else if err != nil { + return err + } + + // Decode + switch rpc.MessageType(msgType[0]) { + case rpc.RegisterRequestType: + var req rpc.RegisterRequest + if err := dec.Decode(&req); err != nil { + return err + } + + // Register the service or the node + if req.ServiceName != "" { + state.EnsureService(req.Node, req.ServiceName, + req.ServiceTag, req.ServicePort) + } else { + state.EnsureNode(req.Node, req.Address) + } + + default: + return fmt.Errorf("Unrecognized msg type: %v", msgType) + } + } // Do an atomic flip, safe since Apply is not called concurrently c.state = state @@ -103,8 +142,45 @@ func (c *consulFSM) Restore(old io.ReadCloser) error { } func (s *consulSnapshot) Persist(sink raft.SnapshotSink) error { + // Get all the nodes + nodes := s.state.Nodes() + + // Register the nodes + handle := codec.MsgpackHandle{} + encoder := codec.NewEncoder(sink, &handle) + + // Register each node + var req rpc.RegisterRequest + for i := 0; i < len(nodes); i += 2 { + req = rpc.RegisterRequest{ + Node: nodes[i], + Address: nodes[i+1], + } + + // Register the node itself + sink.Write([]byte{byte(rpc.RegisterRequestType)}) + if err := encoder.Encode(&req); err != nil { + sink.Cancel() + return err + } + + // Register each service this node has + services := s.state.NodeServices(nodes[i]) + for serv, props := range services { + req.ServiceName = serv + req.ServiceTag = props.Tag + req.ServicePort = props.Port + + sink.Write([]byte{byte(rpc.RegisterRequestType)}) + if err := encoder.Encode(&req); err != nil { + sink.Cancel() + return err + } + } + } return nil } func (s *consulSnapshot) Release() { + s.state.Close() } diff --git a/consul/fsm_test.go b/consul/fsm_test.go index 080479a0fb..bce056bc8d 100644 --- a/consul/fsm_test.go +++ b/consul/fsm_test.go @@ -1,10 +1,29 @@ package consul import ( + "bytes" "github.com/hashicorp/consul/rpc" "testing" ) +type MockSink struct { + *bytes.Buffer + cancel bool +} + +func (m *MockSink) ID() string { + return "Mock" +} + +func (m *MockSink) Cancel() error { + m.cancel = true + return nil +} + +func (m *MockSink) Close() error { + return nil +} + func TestFSM_RegisterNode(t *testing.T) { fsm, err := NewFSM() if err != nil { @@ -174,3 +193,60 @@ func TestFSM_DeregisterNode(t *testing.T) { t.Fatalf("Services: %v", services) } } + +func TestFSM_SnapshotRestore(t *testing.T) { + fsm, err := NewFSM() + if err != nil { + t.Fatalf("err: %v", err) + } + + // Add some state + fsm.state.EnsureNode("foo", "127.0.0.1") + fsm.state.EnsureNode("baz", "127.0.0.2") + fsm.state.EnsureService("foo", "web", "", 80) + fsm.state.EnsureService("foo", "db", "primary", 5000) + fsm.state.EnsureService("baz", "web", "", 80) + fsm.state.EnsureService("baz", "db", "secondary", 5000) + + // Snapshot + snap, err := fsm.Snapshot() + if err != nil { + t.Fatalf("err: %v", err) + } + defer snap.Release() + + // Persist + buf := bytes.NewBuffer(nil) + sink := &MockSink{buf, false} + if err := snap.Persist(sink); err != nil { + t.Fatalf("err: %v", err) + } + + // Try to restore on a new FSM + fsm2, err := NewFSM() + if err != nil { + t.Fatalf("err: %v", err) + } + + // Do a restore + if err := fsm2.Restore(sink); err != nil { + t.Fatalf("err: %v", err) + } + + // Verify the contents + nodes := fsm2.state.Nodes() + if len(nodes) != 4 { + t.Fatalf("Bad: %v", nodes) + } + + fooSrv := fsm2.state.NodeServices("foo") + if len(fooSrv) != 2 { + t.Fatalf("Bad: %v", fooSrv) + } + if fooSrv["db"].Tag != "primary" { + t.Fatalf("Bad: %v", fooSrv) + } + if fooSrv["db"].Port != 5000 { + t.Fatalf("Bad: %v", fooSrv) + } +}