From a00244e77f226c519521f12d3d447d0744162409 Mon Sep 17 00:00:00 2001 From: James Phillips Date: Mon, 27 Nov 2017 17:44:19 -0800 Subject: [PATCH 01/14] Gets rid of obsolete configtest exception. --- main.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/main.go b/main.go index bb696b8178..f91694c94e 100644 --- a/main.go +++ b/main.go @@ -36,9 +36,7 @@ func realMain() int { var cmds []string for c := range command.Commands { - if c != "configtest" { - cmds = append(cmds, c) - } + cmds = append(cmds, c) } cli := &cli.CLI{ From 8bf1f5773735df4b60809bcd87cbc2263f7ef52e Mon Sep 17 00:00:00 2001 From: James Phillips Date: Tue, 28 Nov 2017 11:45:47 -0800 Subject: [PATCH 02/14] Renames stubs to be more consistent. --- agent/consul/{segment_stub.go => segment_oss.go} | 0 sentinel/{sentinel_stub.go => sentinel_oss.go} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename agent/consul/{segment_stub.go => segment_oss.go} (100%) rename sentinel/{sentinel_stub.go => sentinel_oss.go} (100%) diff --git a/agent/consul/segment_stub.go b/agent/consul/segment_oss.go similarity index 100% rename from agent/consul/segment_stub.go rename to agent/consul/segment_oss.go diff --git a/sentinel/sentinel_stub.go b/sentinel/sentinel_oss.go similarity index 100% rename from sentinel/sentinel_stub.go rename to sentinel/sentinel_oss.go From 44d824a58f2f600c81db97ae872c17acfde1c75d Mon Sep 17 00:00:00 2001 From: James Phillips Date: Tue, 28 Nov 2017 11:48:15 -0800 Subject: [PATCH 03/14] Renames "segments" to "segment" to be consistent with other files. --- agent/config/{segments_oss.go => segment_oss.go} | 0 agent/config/{segments_oss_test.go => segment_oss_test.go} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename agent/config/{segments_oss.go => segment_oss.go} (100%) rename agent/config/{segments_oss_test.go => segment_oss_test.go} (100%) diff --git a/agent/config/segments_oss.go b/agent/config/segment_oss.go similarity index 100% rename from agent/config/segments_oss.go rename to agent/config/segment_oss.go diff --git a/agent/config/segments_oss_test.go b/agent/config/segment_oss_test.go similarity index 100% rename from agent/config/segments_oss_test.go rename to agent/config/segment_oss_test.go From 521e46ce91ff93564686fea06124103444927096 Mon Sep 17 00:00:00 2001 From: James Phillips Date: Tue, 28 Nov 2017 12:44:01 -0800 Subject: [PATCH 04/14] Adds a registry mechanism for CLI commands. --- command/commands.go | 125 ---------------------------------------- command/commands_oss.go | 95 ++++++++++++++++++++++++++++++ command/registry.go | 60 +++++++++++++++++++ main.go | 12 ++-- 4 files changed, 162 insertions(+), 130 deletions(-) delete mode 100644 command/commands.go create mode 100644 command/commands_oss.go create mode 100644 command/registry.go diff --git a/command/commands.go b/command/commands.go deleted file mode 100644 index 1044c7478e..0000000000 --- a/command/commands.go +++ /dev/null @@ -1,125 +0,0 @@ -package command - -import ( - "os" - "os/signal" - "syscall" - - "github.com/hashicorp/consul/command/agent" - "github.com/hashicorp/consul/command/catalog" - catlistdc "github.com/hashicorp/consul/command/catalog/list/dc" - catlistnodes "github.com/hashicorp/consul/command/catalog/list/nodes" - catlistsvc "github.com/hashicorp/consul/command/catalog/list/services" - "github.com/hashicorp/consul/command/event" - "github.com/hashicorp/consul/command/exec" - "github.com/hashicorp/consul/command/forceleave" - "github.com/hashicorp/consul/command/info" - "github.com/hashicorp/consul/command/join" - "github.com/hashicorp/consul/command/keygen" - "github.com/hashicorp/consul/command/keyring" - "github.com/hashicorp/consul/command/kv" - kvdel "github.com/hashicorp/consul/command/kv/del" - kvexp "github.com/hashicorp/consul/command/kv/exp" - kvget "github.com/hashicorp/consul/command/kv/get" - kvimp "github.com/hashicorp/consul/command/kv/imp" - kvput "github.com/hashicorp/consul/command/kv/put" - "github.com/hashicorp/consul/command/leave" - "github.com/hashicorp/consul/command/lock" - "github.com/hashicorp/consul/command/maint" - "github.com/hashicorp/consul/command/members" - "github.com/hashicorp/consul/command/monitor" - "github.com/hashicorp/consul/command/operator" - operauto "github.com/hashicorp/consul/command/operator/autopilot" - operautoget "github.com/hashicorp/consul/command/operator/autopilot/get" - operautoset "github.com/hashicorp/consul/command/operator/autopilot/set" - operraft "github.com/hashicorp/consul/command/operator/raft" - operraftlist "github.com/hashicorp/consul/command/operator/raft/listpeers" - operraftremove "github.com/hashicorp/consul/command/operator/raft/removepeer" - "github.com/hashicorp/consul/command/reload" - "github.com/hashicorp/consul/command/rtt" - "github.com/hashicorp/consul/command/snapshot" - snapinspect "github.com/hashicorp/consul/command/snapshot/inspect" - snaprestore "github.com/hashicorp/consul/command/snapshot/restore" - snapsave "github.com/hashicorp/consul/command/snapshot/save" - "github.com/hashicorp/consul/command/validate" - "github.com/hashicorp/consul/command/version" - "github.com/hashicorp/consul/command/watch" - consulversion "github.com/hashicorp/consul/version" - - "github.com/mitchellh/cli" -) - -// Commands is the mapping of all the available Consul commands. -var Commands map[string]cli.CommandFactory - -func init() { - rev := consulversion.GitCommit - ver := consulversion.Version - verPre := consulversion.VersionPrerelease - verHuman := consulversion.GetHumanVersion() - - ui := &cli.BasicUi{Writer: os.Stdout, ErrorWriter: os.Stderr} - - Commands = map[string]cli.CommandFactory{ - "agent": func() (cli.Command, error) { - return agent.New(ui, rev, ver, verPre, verHuman, make(chan struct{})), nil - }, - - "catalog": func() (cli.Command, error) { return catalog.New(), nil }, - "catalog datacenters": func() (cli.Command, error) { return catlistdc.New(ui), nil }, - "catalog nodes": func() (cli.Command, error) { return catlistnodes.New(ui), nil }, - "catalog services": func() (cli.Command, error) { return catlistsvc.New(ui), nil }, - "event": func() (cli.Command, error) { return event.New(ui), nil }, - "exec": func() (cli.Command, error) { return exec.New(ui, makeShutdownCh()), nil }, - "force-leave": func() (cli.Command, error) { return forceleave.New(ui), nil }, - "info": func() (cli.Command, error) { return info.New(ui), nil }, - "join": func() (cli.Command, error) { return join.New(ui), nil }, - "keygen": func() (cli.Command, error) { return keygen.New(ui), nil }, - "keyring": func() (cli.Command, error) { return keyring.New(ui), nil }, - "kv": func() (cli.Command, error) { return kv.New(), nil }, - "kv delete": func() (cli.Command, error) { return kvdel.New(ui), nil }, - "kv export": func() (cli.Command, error) { return kvexp.New(ui), nil }, - "kv get": func() (cli.Command, error) { return kvget.New(ui), nil }, - "kv import": func() (cli.Command, error) { return kvimp.New(ui), nil }, - "kv put": func() (cli.Command, error) { return kvput.New(ui), nil }, - "leave": func() (cli.Command, error) { return leave.New(ui), nil }, - "lock": func() (cli.Command, error) { return lock.New(ui), nil }, - "maint": func() (cli.Command, error) { return maint.New(ui), nil }, - "members": func() (cli.Command, error) { return members.New(ui), nil }, - "monitor": func() (cli.Command, error) { return monitor.New(ui, makeShutdownCh()), nil }, - "operator": func() (cli.Command, error) { return operator.New(), nil }, - "operator autopilot": func() (cli.Command, error) { return operauto.New(), nil }, - "operator autopilot get-config": func() (cli.Command, error) { return operautoget.New(ui), nil }, - "operator autopilot set-config": func() (cli.Command, error) { return operautoset.New(ui), nil }, - "operator raft": func() (cli.Command, error) { return operraft.New(), nil }, - "operator raft list-peers": func() (cli.Command, error) { return operraftlist.New(ui), nil }, - "operator raft remove-peer": func() (cli.Command, error) { return operraftremove.New(ui), nil }, - "reload": func() (cli.Command, error) { return reload.New(ui), nil }, - "rtt": func() (cli.Command, error) { return rtt.New(ui), nil }, - "snapshot": func() (cli.Command, error) { return snapshot.New(), nil }, - "snapshot inspect": func() (cli.Command, error) { return snapinspect.New(ui), nil }, - "snapshot restore": func() (cli.Command, error) { return snaprestore.New(ui), nil }, - "snapshot save": func() (cli.Command, error) { return snapsave.New(ui), nil }, - "validate": func() (cli.Command, error) { return validate.New(ui), nil }, - "version": func() (cli.Command, error) { return version.New(ui, verHuman), nil }, - "watch": func() (cli.Command, error) { return watch.New(ui, makeShutdownCh()), nil }, - } -} - -// makeShutdownCh returns a channel that can be used for shutdown -// notifications for commands. This channel will send a message for every -// interrupt or SIGTERM received. -func makeShutdownCh() <-chan struct{} { - resultCh := make(chan struct{}) - - signalCh := make(chan os.Signal, 4) - signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM) - go func() { - for { - <-signalCh - resultCh <- struct{}{} - } - }() - - return resultCh -} diff --git a/command/commands_oss.go b/command/commands_oss.go new file mode 100644 index 0000000000..43fbeb29c9 --- /dev/null +++ b/command/commands_oss.go @@ -0,0 +1,95 @@ +package command + +import ( + "github.com/hashicorp/consul/command/agent" + "github.com/hashicorp/consul/command/catalog" + catlistdc "github.com/hashicorp/consul/command/catalog/list/dc" + catlistnodes "github.com/hashicorp/consul/command/catalog/list/nodes" + catlistsvc "github.com/hashicorp/consul/command/catalog/list/services" + "github.com/hashicorp/consul/command/event" + "github.com/hashicorp/consul/command/exec" + "github.com/hashicorp/consul/command/forceleave" + "github.com/hashicorp/consul/command/info" + "github.com/hashicorp/consul/command/join" + "github.com/hashicorp/consul/command/keygen" + "github.com/hashicorp/consul/command/keyring" + "github.com/hashicorp/consul/command/kv" + kvdel "github.com/hashicorp/consul/command/kv/del" + kvexp "github.com/hashicorp/consul/command/kv/exp" + kvget "github.com/hashicorp/consul/command/kv/get" + kvimp "github.com/hashicorp/consul/command/kv/imp" + kvput "github.com/hashicorp/consul/command/kv/put" + "github.com/hashicorp/consul/command/leave" + "github.com/hashicorp/consul/command/lock" + "github.com/hashicorp/consul/command/maint" + "github.com/hashicorp/consul/command/members" + "github.com/hashicorp/consul/command/monitor" + "github.com/hashicorp/consul/command/operator" + operauto "github.com/hashicorp/consul/command/operator/autopilot" + operautoget "github.com/hashicorp/consul/command/operator/autopilot/get" + operautoset "github.com/hashicorp/consul/command/operator/autopilot/set" + operraft "github.com/hashicorp/consul/command/operator/raft" + operraftlist "github.com/hashicorp/consul/command/operator/raft/listpeers" + operraftremove "github.com/hashicorp/consul/command/operator/raft/removepeer" + "github.com/hashicorp/consul/command/reload" + "github.com/hashicorp/consul/command/rtt" + "github.com/hashicorp/consul/command/snapshot" + snapinspect "github.com/hashicorp/consul/command/snapshot/inspect" + snaprestore "github.com/hashicorp/consul/command/snapshot/restore" + snapsave "github.com/hashicorp/consul/command/snapshot/save" + "github.com/hashicorp/consul/command/validate" + "github.com/hashicorp/consul/command/version" + "github.com/hashicorp/consul/command/watch" + consulversion "github.com/hashicorp/consul/version" + + "github.com/mitchellh/cli" +) + +func init() { + rev := consulversion.GitCommit + ver := consulversion.Version + verPre := consulversion.VersionPrerelease + verHuman := consulversion.GetHumanVersion() + + Register("agent", func(ui cli.Ui) (cli.Command, error) { + return agent.New(ui, rev, ver, verPre, verHuman, make(chan struct{})), nil + }) + Register("catalog", func(cli.Ui) (cli.Command, error) { return catalog.New(), nil }) + Register("catalog datacenters", func(ui cli.Ui) (cli.Command, error) { return catlistdc.New(ui), nil }) + Register("catalog nodes", func(ui cli.Ui) (cli.Command, error) { return catlistnodes.New(ui), nil }) + Register("catalog services", func(ui cli.Ui) (cli.Command, error) { return catlistsvc.New(ui), nil }) + Register("event", func(ui cli.Ui) (cli.Command, error) { return event.New(ui), nil }) + Register("exec", func(ui cli.Ui) (cli.Command, error) { return exec.New(ui, MakeShutdownCh()), nil }) + Register("force-leave", func(ui cli.Ui) (cli.Command, error) { return forceleave.New(ui), nil }) + Register("info", func(ui cli.Ui) (cli.Command, error) { return info.New(ui), nil }) + Register("join", func(ui cli.Ui) (cli.Command, error) { return join.New(ui), nil }) + Register("keygen", func(ui cli.Ui) (cli.Command, error) { return keygen.New(ui), nil }) + Register("keyring", func(ui cli.Ui) (cli.Command, error) { return keyring.New(ui), nil }) + Register("kv", func(cli.Ui) (cli.Command, error) { return kv.New(), nil }) + Register("kv delete", func(ui cli.Ui) (cli.Command, error) { return kvdel.New(ui), nil }) + Register("kv export", func(ui cli.Ui) (cli.Command, error) { return kvexp.New(ui), nil }) + Register("kv get", func(ui cli.Ui) (cli.Command, error) { return kvget.New(ui), nil }) + Register("kv import", func(ui cli.Ui) (cli.Command, error) { return kvimp.New(ui), nil }) + Register("kv put", func(ui cli.Ui) (cli.Command, error) { return kvput.New(ui), nil }) + Register("leave", func(ui cli.Ui) (cli.Command, error) { return leave.New(ui), nil }) + Register("lock", func(ui cli.Ui) (cli.Command, error) { return lock.New(ui), nil }) + Register("maint", func(ui cli.Ui) (cli.Command, error) { return maint.New(ui), nil }) + Register("members", func(ui cli.Ui) (cli.Command, error) { return members.New(ui), nil }) + Register("monitor", func(ui cli.Ui) (cli.Command, error) { return monitor.New(ui, MakeShutdownCh()), nil }) + Register("operator", func(cli.Ui) (cli.Command, error) { return operator.New(), nil }) + Register("operator autopilot", func(cli.Ui) (cli.Command, error) { return operauto.New(), nil }) + Register("operator autopilot get-config", func(ui cli.Ui) (cli.Command, error) { return operautoget.New(ui), nil }) + Register("operator autopilot set-config", func(ui cli.Ui) (cli.Command, error) { return operautoset.New(ui), nil }) + Register("operator raft", func(cli.Ui) (cli.Command, error) { return operraft.New(), nil }) + Register("operator raft list-peers", func(ui cli.Ui) (cli.Command, error) { return operraftlist.New(ui), nil }) + Register("operator raft remove-peer", func(ui cli.Ui) (cli.Command, error) { return operraftremove.New(ui), nil }) + Register("reload", func(ui cli.Ui) (cli.Command, error) { return reload.New(ui), nil }) + Register("rtt", func(ui cli.Ui) (cli.Command, error) { return rtt.New(ui), nil }) + Register("snapshot", func(cli.Ui) (cli.Command, error) { return snapshot.New(), nil }) + Register("snapshot inspect", func(ui cli.Ui) (cli.Command, error) { return snapinspect.New(ui), nil }) + Register("snapshot restore", func(ui cli.Ui) (cli.Command, error) { return snaprestore.New(ui), nil }) + Register("snapshot save", func(ui cli.Ui) (cli.Command, error) { return snapsave.New(ui), nil }) + Register("validate", func(ui cli.Ui) (cli.Command, error) { return validate.New(ui), nil }) + Register("version", func(ui cli.Ui) (cli.Command, error) { return version.New(ui, verHuman), nil }) + Register("watch", func(ui cli.Ui) (cli.Command, error) { return watch.New(ui, MakeShutdownCh()), nil }) +} diff --git a/command/registry.go b/command/registry.go new file mode 100644 index 0000000000..2b092ae722 --- /dev/null +++ b/command/registry.go @@ -0,0 +1,60 @@ +package command + +import ( + "fmt" + "os" + "os/signal" + "syscall" + + "github.com/mitchellh/cli" +) + +// Factory is a function that returns a new instance of a CLI-sub command. +type Factory func(cli.Ui) (cli.Command, error) + +// Register adds a new CLI sub-command to the registry. +func Register(name string, fn Factory) { + if registry == nil { + registry = make(map[string]Factory) + } + + if registry[name] != nil { + panic(fmt.Errorf("Command %q is already registered", name)) + } + registry[name] = fn +} + +// Map returns a realized mapping of available CLI commands in a format that +// the CLI class can consume. This should be called after all registration is +// complete. +func Map(ui cli.Ui) map[string]cli.CommandFactory { + m := make(map[string]cli.CommandFactory) + for name, fn := range registry { + thisFn := fn + m[name] = func() (cli.Command, error) { + return thisFn(ui) + } + } + return m +} + +// registry has an entry for each available CLI sub-command, indexed by sub +// command name. This should be populated at package init() time via Register(). +var registry map[string]Factory + +// MakeShutdownCh returns a channel that can be used for shutdown notifications +// for commands. This channel will send a message for every interrupt or SIGTERM +// received. +func MakeShutdownCh() <-chan struct{} { + resultCh := make(chan struct{}) + signalCh := make(chan os.Signal, 4) + signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM) + go func() { + for { + <-signalCh + resultCh <- struct{}{} + } + }() + + return resultCh +} diff --git a/main.go b/main.go index f91694c94e..855967af47 100644 --- a/main.go +++ b/main.go @@ -34,17 +34,19 @@ func realMain() int { } } - var cmds []string - for c := range command.Commands { - cmds = append(cmds, c) + ui := &cli.BasicUi{Writer: os.Stdout, ErrorWriter: os.Stderr} + cmds := command.Map(ui) + var names []string + for c := range cmds { + names = append(names, c) } cli := &cli.CLI{ Args: args, - Commands: command.Commands, + Commands: cmds, Autocomplete: true, Name: "consul", - HelpFunc: cli.FilteredHelpFunc(cmds, cli.BasicHelpFunc("consul")), + HelpFunc: cli.FilteredHelpFunc(names, cli.BasicHelpFunc("consul")), } exitCode, err := cli.Run() From 29367cd5aeff9293da0b79c3d86a0e76dcf7c27d Mon Sep 17 00:00:00 2001 From: James Phillips Date: Tue, 28 Nov 2017 13:47:30 -0800 Subject: [PATCH 05/14] Moves ACL disabled response logic down into endpoints. This lets us make the registration of endpoints less fancy, on the road to adding a registration mechanism. --- agent/acl_endpoint.go | 35 +++++++++++++++++++++++++++++--- agent/acl_endpoint_test.go | 39 ++++++++++++++++++++++++++++++++++++ agent/agent_endpoint.go | 3 +++ agent/agent_endpoint_test.go | 4 ++++ agent/http.go | 30 +++++++++------------------ 5 files changed, 87 insertions(+), 24 deletions(-) diff --git a/agent/acl_endpoint.go b/agent/acl_endpoint.go index c30f5f55ef..9a40a6596c 100644 --- a/agent/acl_endpoint.go +++ b/agent/acl_endpoint.go @@ -14,16 +14,24 @@ type aclCreateResponse struct { ID string } -// ACLDisabled handles if ACL datacenter is not configured -func ACLDisabled(resp http.ResponseWriter, req *http.Request) (interface{}, error) { +// checkACLDisabled will return a standard response if ACLs are disabled. This +// returns true if they are disabled and we should not continue. +func (s *HTTPServer) checkACLDisabled(resp http.ResponseWriter, req *http.Request) bool { + if s.agent.config.ACLDatacenter != "" { + return false + } + resp.WriteHeader(http.StatusUnauthorized) fmt.Fprint(resp, "ACL support disabled") - return nil, nil + return true } // ACLBootstrap is used to perform a one-time ACL bootstrap operation on // a cluster to get the first management token. func (s *HTTPServer) ACLBootstrap(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } if req.Method != "PUT" { return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} } @@ -48,6 +56,9 @@ func (s *HTTPServer) ACLBootstrap(resp http.ResponseWriter, req *http.Request) ( } func (s *HTTPServer) ACLDestroy(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } if req.Method != "PUT" { return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} } @@ -74,6 +85,9 @@ func (s *HTTPServer) ACLDestroy(resp http.ResponseWriter, req *http.Request) (in } func (s *HTTPServer) ACLCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } if req.Method != "PUT" { return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} } @@ -81,6 +95,9 @@ func (s *HTTPServer) ACLCreate(resp http.ResponseWriter, req *http.Request) (int } func (s *HTTPServer) ACLUpdate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } if req.Method != "PUT" { return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} } @@ -129,6 +146,9 @@ func (s *HTTPServer) aclSet(resp http.ResponseWriter, req *http.Request, update } func (s *HTTPServer) ACLClone(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } if req.Method != "PUT" { return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} } @@ -181,6 +201,9 @@ func (s *HTTPServer) ACLClone(resp http.ResponseWriter, req *http.Request) (inte } func (s *HTTPServer) ACLGet(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } if req.Method != "GET" { return nil, MethodNotAllowedError{req.Method, []string{"GET"}} } @@ -215,6 +238,9 @@ func (s *HTTPServer) ACLGet(resp http.ResponseWriter, req *http.Request) (interf } func (s *HTTPServer) ACLList(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } if req.Method != "GET" { return nil, MethodNotAllowedError{req.Method, []string{"GET"}} } @@ -241,6 +267,9 @@ func (s *HTTPServer) ACLList(resp http.ResponseWriter, req *http.Request) (inter } func (s *HTTPServer) ACLReplicationStatus(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } if req.Method != "GET" { return nil, MethodNotAllowedError{req.Method, []string{"GET"}} } diff --git a/agent/acl_endpoint_test.go b/agent/acl_endpoint_test.go index 9aba5154c3..aa15c48686 100644 --- a/agent/acl_endpoint_test.go +++ b/agent/acl_endpoint_test.go @@ -3,14 +3,53 @@ package agent import ( "bytes" "encoding/json" + "fmt" "net/http" "net/http/httptest" + "strings" "testing" "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/structs" ) +func TestACL_Disabled_Response(t *testing.T) { + t.Parallel() + a := NewTestAgent(t.Name(), "") + defer a.Shutdown() + + tests := []func(resp http.ResponseWriter, req *http.Request) (interface{}, error){ + a.srv.ACLBootstrap, + a.srv.ACLDestroy, + a.srv.ACLCreate, + a.srv.ACLUpdate, + a.srv.ACLClone, + a.srv.ACLGet, + a.srv.ACLList, + a.srv.ACLReplicationStatus, + a.srv.AgentToken, // See TestAgent_Token. + } + for i, tt := range tests { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + req, _ := http.NewRequest("PUT", "/should/not/care", nil) + resp := httptest.NewRecorder() + obj, err := tt(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if obj != nil { + t.Fatalf("bad: %#v", obj) + } + if got, want := resp.Code, http.StatusUnauthorized; got != want { + t.Fatalf("got %d want %d", got, want) + } + if !strings.Contains(resp.Body.String(), "ACL support disabled") { + t.Fatalf("bad: %#v", resp) + } + }) + } +} + func makeTestACL(t *testing.T, srv *HTTPServer) string { body := bytes.NewBuffer(nil) enc := json.NewEncoder(body) diff --git a/agent/agent_endpoint.go b/agent/agent_endpoint.go index bf74827234..39b38294dd 100644 --- a/agent/agent_endpoint.go +++ b/agent/agent_endpoint.go @@ -809,6 +809,9 @@ func (h *httpLogHandler) HandleLog(log string) { } func (s *HTTPServer) AgentToken(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } if req.Method != "PUT" { return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} } diff --git a/agent/agent_endpoint_test.go b/agent/agent_endpoint_test.go index 3426b3bac1..427154f9e1 100644 --- a/agent/agent_endpoint_test.go +++ b/agent/agent_endpoint_test.go @@ -1771,6 +1771,10 @@ func TestAgent_Monitor_ACLDeny(t *testing.T) { func TestAgent_Token(t *testing.T) { t.Parallel() + + // The behavior of this handler when ACLs are disabled is vetted over + // in TestACL_Disabled_Response since there's already good infra set + // up over there to test this, and it calls the common function. a := NewTestAgent(t.Name(), TestACLConfig()+` acl_token = "" acl_agent_token = "" diff --git a/agent/http.go b/agent/http.go index a51e2c944d..ca3085fd42 100644 --- a/agent/http.go +++ b/agent/http.go @@ -77,27 +77,15 @@ func (s *HTTPServer) handler(enableDebug bool) http.Handler { mux.HandleFunc("/", s.Index) // API V1. - if s.agent.config.ACLDatacenter != "" { - handleFuncMetrics("/v1/acl/bootstrap", s.wrap(s.ACLBootstrap)) - handleFuncMetrics("/v1/acl/create", s.wrap(s.ACLCreate)) - handleFuncMetrics("/v1/acl/update", s.wrap(s.ACLUpdate)) - handleFuncMetrics("/v1/acl/destroy/", s.wrap(s.ACLDestroy)) - handleFuncMetrics("/v1/acl/info/", s.wrap(s.ACLGet)) - handleFuncMetrics("/v1/acl/clone/", s.wrap(s.ACLClone)) - handleFuncMetrics("/v1/acl/list", s.wrap(s.ACLList)) - handleFuncMetrics("/v1/acl/replication", s.wrap(s.ACLReplicationStatus)) - handleFuncMetrics("/v1/agent/token/", s.wrap(s.AgentToken)) - } else { - handleFuncMetrics("/v1/acl/bootstrap", s.wrap(ACLDisabled)) - handleFuncMetrics("/v1/acl/create", s.wrap(ACLDisabled)) - handleFuncMetrics("/v1/acl/update", s.wrap(ACLDisabled)) - handleFuncMetrics("/v1/acl/destroy/", s.wrap(ACLDisabled)) - handleFuncMetrics("/v1/acl/info/", s.wrap(ACLDisabled)) - handleFuncMetrics("/v1/acl/clone/", s.wrap(ACLDisabled)) - handleFuncMetrics("/v1/acl/list", s.wrap(ACLDisabled)) - handleFuncMetrics("/v1/acl/replication", s.wrap(ACLDisabled)) - handleFuncMetrics("/v1/agent/token/", s.wrap(ACLDisabled)) - } + handleFuncMetrics("/v1/acl/bootstrap", s.wrap(s.ACLBootstrap)) + handleFuncMetrics("/v1/acl/create", s.wrap(s.ACLCreate)) + handleFuncMetrics("/v1/acl/update", s.wrap(s.ACLUpdate)) + handleFuncMetrics("/v1/acl/destroy/", s.wrap(s.ACLDestroy)) + handleFuncMetrics("/v1/acl/info/", s.wrap(s.ACLGet)) + handleFuncMetrics("/v1/acl/clone/", s.wrap(s.ACLClone)) + handleFuncMetrics("/v1/acl/list", s.wrap(s.ACLList)) + handleFuncMetrics("/v1/acl/replication", s.wrap(s.ACLReplicationStatus)) + handleFuncMetrics("/v1/agent/token/", s.wrap(s.AgentToken)) handleFuncMetrics("/v1/agent/self", s.wrap(s.AgentSelf)) handleFuncMetrics("/v1/agent/maintenance", s.wrap(s.AgentNodeMaintenance)) handleFuncMetrics("/v1/agent/reload", s.wrap(s.AgentReload)) From 679775418f8afd61e18a9ccd5598dfd6c357166c Mon Sep 17 00:00:00 2001 From: James Phillips Date: Tue, 28 Nov 2017 13:57:45 -0800 Subject: [PATCH 06/14] Moves coordinate disabled logic down into endpoints. Similar rationale to the previous change for ACLs. --- agent/coordinate_endpoint.go | 24 +++++++++++++++++---- agent/coordinate_endpoint_test.go | 36 +++++++++++++++++++++++++++++++ agent/http.go | 15 ++++--------- 3 files changed, 60 insertions(+), 15 deletions(-) diff --git a/agent/coordinate_endpoint.go b/agent/coordinate_endpoint.go index 1a16964fe2..4d51a9b0fb 100644 --- a/agent/coordinate_endpoint.go +++ b/agent/coordinate_endpoint.go @@ -9,12 +9,16 @@ import ( "github.com/hashicorp/consul/agent/structs" ) -// coordinateDisabled handles all the endpoints when coordinates are not enabled, -// returning an error message. -func coordinateDisabled(resp http.ResponseWriter, req *http.Request) (interface{}, error) { +// checkCoordinateDisabled will return a standard response if coordinates are +// disabled. This returns true if they are disabled and we should not continue. +func (s *HTTPServer) checkCoordinateDisabled(resp http.ResponseWriter, req *http.Request) bool { + if !s.agent.config.DisableCoordinates { + return false + } + resp.WriteHeader(http.StatusUnauthorized) fmt.Fprint(resp, "Coordinate support disabled") - return nil, nil + return true } // sorter wraps a coordinate list and implements the sort.Interface to sort by @@ -41,6 +45,9 @@ func (s *sorter) Less(i, j int) bool { // CoordinateDatacenters returns the WAN nodes in each datacenter, along with // raw network coordinates. func (s *HTTPServer) CoordinateDatacenters(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkCoordinateDisabled(resp, req) { + return nil, nil + } if req.Method != "GET" { return nil, MethodNotAllowedError{req.Method, []string{"GET"}} } @@ -70,6 +77,9 @@ func (s *HTTPServer) CoordinateDatacenters(resp http.ResponseWriter, req *http.R // CoordinateNodes returns the LAN nodes in the given datacenter, along with // raw network coordinates. func (s *HTTPServer) CoordinateNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkCoordinateDisabled(resp, req) { + return nil, nil + } if req.Method != "GET" { return nil, MethodNotAllowedError{req.Method, []string{"GET"}} } @@ -92,6 +102,9 @@ func (s *HTTPServer) CoordinateNodes(resp http.ResponseWriter, req *http.Request // CoordinateNode returns the LAN node in the given datacenter, along with // raw network coordinates. func (s *HTTPServer) CoordinateNode(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkCoordinateDisabled(resp, req) { + return nil, nil + } if req.Method != "GET" { return nil, MethodNotAllowedError{req.Method, []string{"GET"}} } @@ -141,6 +154,9 @@ func filterCoordinates(req *http.Request, in structs.Coordinates) structs.Coordi // CoordinateUpdate inserts or updates the LAN coordinate of a node. func (s *HTTPServer) CoordinateUpdate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkCoordinateDisabled(resp, req) { + return nil, nil + } if req.Method != "PUT" { return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} } diff --git a/agent/coordinate_endpoint_test.go b/agent/coordinate_endpoint_test.go index 09001ccabe..deb812a9b9 100644 --- a/agent/coordinate_endpoint_test.go +++ b/agent/coordinate_endpoint_test.go @@ -1,8 +1,10 @@ package agent import ( + "fmt" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -10,6 +12,40 @@ import ( "github.com/hashicorp/serf/coordinate" ) +func TestCoordinate_Disabled_Response(t *testing.T) { + t.Parallel() + a := NewTestAgent(t.Name(), ` + disable_coordinates = true +`) + defer a.Shutdown() + + tests := []func(resp http.ResponseWriter, req *http.Request) (interface{}, error){ + a.srv.CoordinateDatacenters, + a.srv.CoordinateNodes, + a.srv.CoordinateNode, + a.srv.CoordinateUpdate, + } + for i, tt := range tests { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + req, _ := http.NewRequest("PUT", "/should/not/care", nil) + resp := httptest.NewRecorder() + obj, err := tt(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if obj != nil { + t.Fatalf("bad: %#v", obj) + } + if got, want := resp.Code, http.StatusUnauthorized; got != want { + t.Fatalf("got %d want %d", got, want) + } + if !strings.Contains(resp.Body.String(), "Coordinate support disabled") { + t.Fatalf("bad: %#v", resp) + } + }) + } +} + func TestCoordinate_Datacenters(t *testing.T) { t.Parallel() a := NewTestAgent(t.Name(), "") diff --git a/agent/http.go b/agent/http.go index ca3085fd42..ed6f350922 100644 --- a/agent/http.go +++ b/agent/http.go @@ -113,17 +113,10 @@ func (s *HTTPServer) handler(enableDebug bool) http.Handler { handleFuncMetrics("/v1/catalog/services", s.wrap(s.CatalogServices)) handleFuncMetrics("/v1/catalog/service/", s.wrap(s.CatalogServiceNodes)) handleFuncMetrics("/v1/catalog/node/", s.wrap(s.CatalogNodeServices)) - if !s.agent.config.DisableCoordinates { - handleFuncMetrics("/v1/coordinate/datacenters", s.wrap(s.CoordinateDatacenters)) - handleFuncMetrics("/v1/coordinate/nodes", s.wrap(s.CoordinateNodes)) - handleFuncMetrics("/v1/coordinate/node/", s.wrap(s.CoordinateNode)) - handleFuncMetrics("/v1/coordinate/update", s.wrap(s.CoordinateUpdate)) - } else { - handleFuncMetrics("/v1/coordinate/datacenters", s.wrap(coordinateDisabled)) - handleFuncMetrics("/v1/coordinate/nodes", s.wrap(coordinateDisabled)) - handleFuncMetrics("/v1/coordinate/node/", s.wrap(coordinateDisabled)) - handleFuncMetrics("/v1/coordinate/update", s.wrap(coordinateDisabled)) - } + handleFuncMetrics("/v1/coordinate/datacenters", s.wrap(s.CoordinateDatacenters)) + handleFuncMetrics("/v1/coordinate/nodes", s.wrap(s.CoordinateNodes)) + handleFuncMetrics("/v1/coordinate/node/", s.wrap(s.CoordinateNode)) + handleFuncMetrics("/v1/coordinate/update", s.wrap(s.CoordinateUpdate)) handleFuncMetrics("/v1/event/fire/", s.wrap(s.EventFire)) handleFuncMetrics("/v1/event/list", s.wrap(s.EventList)) handleFuncMetrics("/v1/health/node/", s.wrap(s.HealthNodeChecks)) From 68f100c8df9ee4734a45c2bbf70c3f0c85f9165a Mon Sep 17 00:00:00 2001 From: James Phillips Date: Tue, 28 Nov 2017 16:06:26 -0800 Subject: [PATCH 07/14] Creates HTTP endpoint registry. --- agent/http.go | 103 +++++++++++-------------------------- agent/http_oss.go | 71 ++++++++++++++++++++++++++ agent/http_oss_test.go | 113 +++++++++++++++++++++++++++++++++++++++++ agent/http_test.go | 110 --------------------------------------- 4 files changed, 215 insertions(+), 182 deletions(-) create mode 100644 agent/http_oss.go create mode 100644 agent/http_oss_test.go diff --git a/agent/http.go b/agent/http.go index ed6f350922..930f003472 100644 --- a/agent/http.go +++ b/agent/http.go @@ -39,6 +39,29 @@ type HTTPServer struct { proto string } +// endpoint is a Consul-specific HTTP handler that takes the usual arguments in +// but returns a response object and error, both of which are handled in a +// common manner by Consul's HTTP server. +type endpoint func(resp http.ResponseWriter, req *http.Request) (interface{}, error) + +// unboundEndpoint is an endpoint method on a server. +type unboundEndpoint func(s *HTTPServer, resp http.ResponseWriter, req *http.Request) (interface{}, error) + +// endpoints is a map from URL pattern to unbound endpoint. +var endpoints map[string]unboundEndpoint + +// registerEndpoint registers a new endpoint, which should be done at package +// init() time. +func registerEndpoint(pattern string, fn unboundEndpoint) { + if endpoints == nil { + endpoints = make(map[string]unboundEndpoint) + } + if endpoints[pattern] != nil { + panic(fmt.Errorf("Pattern %q is already registered", pattern)) + } + endpoints[pattern] = fn +} + // handler is used to attach our handlers to the mux func (s *HTTPServer) handler(enableDebug bool) http.Handler { mux := http.NewServeMux() @@ -75,77 +98,13 @@ func (s *HTTPServer) handler(enableDebug bool) http.Handler { } mux.HandleFunc("/", s.Index) - - // API V1. - handleFuncMetrics("/v1/acl/bootstrap", s.wrap(s.ACLBootstrap)) - handleFuncMetrics("/v1/acl/create", s.wrap(s.ACLCreate)) - handleFuncMetrics("/v1/acl/update", s.wrap(s.ACLUpdate)) - handleFuncMetrics("/v1/acl/destroy/", s.wrap(s.ACLDestroy)) - handleFuncMetrics("/v1/acl/info/", s.wrap(s.ACLGet)) - handleFuncMetrics("/v1/acl/clone/", s.wrap(s.ACLClone)) - handleFuncMetrics("/v1/acl/list", s.wrap(s.ACLList)) - handleFuncMetrics("/v1/acl/replication", s.wrap(s.ACLReplicationStatus)) - handleFuncMetrics("/v1/agent/token/", s.wrap(s.AgentToken)) - handleFuncMetrics("/v1/agent/self", s.wrap(s.AgentSelf)) - handleFuncMetrics("/v1/agent/maintenance", s.wrap(s.AgentNodeMaintenance)) - handleFuncMetrics("/v1/agent/reload", s.wrap(s.AgentReload)) - handleFuncMetrics("/v1/agent/monitor", s.wrap(s.AgentMonitor)) - handleFuncMetrics("/v1/agent/metrics", s.wrap(s.AgentMetrics)) - handleFuncMetrics("/v1/agent/services", s.wrap(s.AgentServices)) - handleFuncMetrics("/v1/agent/checks", s.wrap(s.AgentChecks)) - handleFuncMetrics("/v1/agent/members", s.wrap(s.AgentMembers)) - handleFuncMetrics("/v1/agent/join/", s.wrap(s.AgentJoin)) - handleFuncMetrics("/v1/agent/leave", s.wrap(s.AgentLeave)) - handleFuncMetrics("/v1/agent/force-leave/", s.wrap(s.AgentForceLeave)) - handleFuncMetrics("/v1/agent/check/register", s.wrap(s.AgentRegisterCheck)) - handleFuncMetrics("/v1/agent/check/deregister/", s.wrap(s.AgentDeregisterCheck)) - handleFuncMetrics("/v1/agent/check/pass/", s.wrap(s.AgentCheckPass)) - handleFuncMetrics("/v1/agent/check/warn/", s.wrap(s.AgentCheckWarn)) - handleFuncMetrics("/v1/agent/check/fail/", s.wrap(s.AgentCheckFail)) - handleFuncMetrics("/v1/agent/check/update/", s.wrap(s.AgentCheckUpdate)) - handleFuncMetrics("/v1/agent/service/register", s.wrap(s.AgentRegisterService)) - handleFuncMetrics("/v1/agent/service/deregister/", s.wrap(s.AgentDeregisterService)) - handleFuncMetrics("/v1/agent/service/maintenance/", s.wrap(s.AgentServiceMaintenance)) - handleFuncMetrics("/v1/catalog/register", s.wrap(s.CatalogRegister)) - handleFuncMetrics("/v1/catalog/deregister", s.wrap(s.CatalogDeregister)) - handleFuncMetrics("/v1/catalog/datacenters", s.wrap(s.CatalogDatacenters)) - handleFuncMetrics("/v1/catalog/nodes", s.wrap(s.CatalogNodes)) - handleFuncMetrics("/v1/catalog/services", s.wrap(s.CatalogServices)) - handleFuncMetrics("/v1/catalog/service/", s.wrap(s.CatalogServiceNodes)) - handleFuncMetrics("/v1/catalog/node/", s.wrap(s.CatalogNodeServices)) - handleFuncMetrics("/v1/coordinate/datacenters", s.wrap(s.CoordinateDatacenters)) - handleFuncMetrics("/v1/coordinate/nodes", s.wrap(s.CoordinateNodes)) - handleFuncMetrics("/v1/coordinate/node/", s.wrap(s.CoordinateNode)) - handleFuncMetrics("/v1/coordinate/update", s.wrap(s.CoordinateUpdate)) - handleFuncMetrics("/v1/event/fire/", s.wrap(s.EventFire)) - handleFuncMetrics("/v1/event/list", s.wrap(s.EventList)) - handleFuncMetrics("/v1/health/node/", s.wrap(s.HealthNodeChecks)) - handleFuncMetrics("/v1/health/checks/", s.wrap(s.HealthServiceChecks)) - handleFuncMetrics("/v1/health/state/", s.wrap(s.HealthChecksInState)) - handleFuncMetrics("/v1/health/service/", s.wrap(s.HealthServiceNodes)) - handleFuncMetrics("/v1/internal/ui/nodes", s.wrap(s.UINodes)) - handleFuncMetrics("/v1/internal/ui/node/", s.wrap(s.UINodeInfo)) - handleFuncMetrics("/v1/internal/ui/services", s.wrap(s.UIServices)) - handleFuncMetrics("/v1/kv/", s.wrap(s.KVSEndpoint)) - handleFuncMetrics("/v1/operator/raft/configuration", s.wrap(s.OperatorRaftConfiguration)) - handleFuncMetrics("/v1/operator/raft/peer", s.wrap(s.OperatorRaftPeer)) - handleFuncMetrics("/v1/operator/keyring", s.wrap(s.OperatorKeyringEndpoint)) - handleFuncMetrics("/v1/operator/autopilot/configuration", s.wrap(s.OperatorAutopilotConfiguration)) - handleFuncMetrics("/v1/operator/autopilot/health", s.wrap(s.OperatorServerHealth)) - handleFuncMetrics("/v1/query", s.wrap(s.PreparedQueryGeneral)) - handleFuncMetrics("/v1/query/", s.wrap(s.PreparedQuerySpecific)) - handleFuncMetrics("/v1/session/create", s.wrap(s.SessionCreate)) - handleFuncMetrics("/v1/session/destroy/", s.wrap(s.SessionDestroy)) - handleFuncMetrics("/v1/session/renew/", s.wrap(s.SessionRenew)) - handleFuncMetrics("/v1/session/info/", s.wrap(s.SessionGet)) - handleFuncMetrics("/v1/session/node/", s.wrap(s.SessionsForNode)) - handleFuncMetrics("/v1/session/list", s.wrap(s.SessionList)) - handleFuncMetrics("/v1/status/leader", s.wrap(s.StatusLeader)) - handleFuncMetrics("/v1/status/peers", s.wrap(s.StatusPeers)) - handleFuncMetrics("/v1/snapshot", s.wrap(s.Snapshot)) - handleFuncMetrics("/v1/txn", s.wrap(s.Txn)) - - // Debug endpoints. + for pattern, fn := range endpoints { + thisFn := fn + bound := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + return thisFn(s, resp, req) + } + handleFuncMetrics(pattern, s.wrap(bound)) + } if enableDebug { handleFuncMetrics("/debug/pprof/", pprof.Index) handleFuncMetrics("/debug/pprof/cmdline", pprof.Cmdline) @@ -186,7 +145,7 @@ var ( ) // wrap is used to wrap functions to make them more convenient -func (s *HTTPServer) wrap(handler func(resp http.ResponseWriter, req *http.Request) (interface{}, error)) http.HandlerFunc { +func (s *HTTPServer) wrap(handler endpoint) http.HandlerFunc { return func(resp http.ResponseWriter, req *http.Request) { setHeaders(resp, s.agent.config.HTTPResponseHeaders) setTranslateAddr(resp, s.agent.config.TranslateWANAddrs) diff --git a/agent/http_oss.go b/agent/http_oss.go new file mode 100644 index 0000000000..10e9abda04 --- /dev/null +++ b/agent/http_oss.go @@ -0,0 +1,71 @@ +package agent + +func init() { + registerEndpoint("/v1/acl/bootstrap", (*HTTPServer).ACLBootstrap) + registerEndpoint("/v1/acl/create", (*HTTPServer).ACLCreate) + registerEndpoint("/v1/acl/update", (*HTTPServer).ACLUpdate) + registerEndpoint("/v1/acl/destroy/", (*HTTPServer).ACLDestroy) + registerEndpoint("/v1/acl/info/", (*HTTPServer).ACLGet) + registerEndpoint("/v1/acl/clone/", (*HTTPServer).ACLClone) + registerEndpoint("/v1/acl/list", (*HTTPServer).ACLList) + registerEndpoint("/v1/acl/replication", (*HTTPServer).ACLReplicationStatus) + registerEndpoint("/v1/agent/token/", (*HTTPServer).AgentToken) + registerEndpoint("/v1/agent/self", (*HTTPServer).AgentSelf) + registerEndpoint("/v1/agent/maintenance", (*HTTPServer).AgentNodeMaintenance) + registerEndpoint("/v1/agent/reload", (*HTTPServer).AgentReload) + registerEndpoint("/v1/agent/monitor", (*HTTPServer).AgentMonitor) + registerEndpoint("/v1/agent/metrics", (*HTTPServer).AgentMetrics) + registerEndpoint("/v1/agent/services", (*HTTPServer).AgentServices) + registerEndpoint("/v1/agent/checks", (*HTTPServer).AgentChecks) + registerEndpoint("/v1/agent/members", (*HTTPServer).AgentMembers) + registerEndpoint("/v1/agent/join/", (*HTTPServer).AgentJoin) + registerEndpoint("/v1/agent/leave", (*HTTPServer).AgentLeave) + registerEndpoint("/v1/agent/force-leave/", (*HTTPServer).AgentForceLeave) + registerEndpoint("/v1/agent/check/register", (*HTTPServer).AgentRegisterCheck) + registerEndpoint("/v1/agent/check/deregister/", (*HTTPServer).AgentDeregisterCheck) + registerEndpoint("/v1/agent/check/pass/", (*HTTPServer).AgentCheckPass) + registerEndpoint("/v1/agent/check/warn/", (*HTTPServer).AgentCheckWarn) + registerEndpoint("/v1/agent/check/fail/", (*HTTPServer).AgentCheckFail) + registerEndpoint("/v1/agent/check/update/", (*HTTPServer).AgentCheckUpdate) + registerEndpoint("/v1/agent/service/register", (*HTTPServer).AgentRegisterService) + registerEndpoint("/v1/agent/service/deregister/", (*HTTPServer).AgentDeregisterService) + registerEndpoint("/v1/agent/service/maintenance/", (*HTTPServer).AgentServiceMaintenance) + registerEndpoint("/v1/catalog/register", (*HTTPServer).CatalogRegister) + registerEndpoint("/v1/catalog/deregister", (*HTTPServer).CatalogDeregister) + registerEndpoint("/v1/catalog/datacenters", (*HTTPServer).CatalogDatacenters) + registerEndpoint("/v1/catalog/nodes", (*HTTPServer).CatalogNodes) + registerEndpoint("/v1/catalog/services", (*HTTPServer).CatalogServices) + registerEndpoint("/v1/catalog/service/", (*HTTPServer).CatalogServiceNodes) + registerEndpoint("/v1/catalog/node/", (*HTTPServer).CatalogNodeServices) + registerEndpoint("/v1/coordinate/datacenters", (*HTTPServer).CoordinateDatacenters) + registerEndpoint("/v1/coordinate/nodes", (*HTTPServer).CoordinateNodes) + registerEndpoint("/v1/coordinate/node/", (*HTTPServer).CoordinateNode) + registerEndpoint("/v1/coordinate/update", (*HTTPServer).CoordinateUpdate) + registerEndpoint("/v1/event/fire/", (*HTTPServer).EventFire) + registerEndpoint("/v1/event/list", (*HTTPServer).EventList) + registerEndpoint("/v1/health/node/", (*HTTPServer).HealthNodeChecks) + registerEndpoint("/v1/health/checks/", (*HTTPServer).HealthServiceChecks) + registerEndpoint("/v1/health/state/", (*HTTPServer).HealthChecksInState) + registerEndpoint("/v1/health/service/", (*HTTPServer).HealthServiceNodes) + registerEndpoint("/v1/internal/ui/nodes", (*HTTPServer).UINodes) + registerEndpoint("/v1/internal/ui/node/", (*HTTPServer).UINodeInfo) + registerEndpoint("/v1/internal/ui/services", (*HTTPServer).UIServices) + registerEndpoint("/v1/kv/", (*HTTPServer).KVSEndpoint) + registerEndpoint("/v1/operator/raft/configuration", (*HTTPServer).OperatorRaftConfiguration) + registerEndpoint("/v1/operator/raft/peer", (*HTTPServer).OperatorRaftPeer) + registerEndpoint("/v1/operator/keyring", (*HTTPServer).OperatorKeyringEndpoint) + registerEndpoint("/v1/operator/autopilot/configuration", (*HTTPServer).OperatorAutopilotConfiguration) + registerEndpoint("/v1/operator/autopilot/health", (*HTTPServer).OperatorServerHealth) + registerEndpoint("/v1/query", (*HTTPServer).PreparedQueryGeneral) + registerEndpoint("/v1/query/", (*HTTPServer).PreparedQuerySpecific) + registerEndpoint("/v1/session/create", (*HTTPServer).SessionCreate) + registerEndpoint("/v1/session/destroy/", (*HTTPServer).SessionDestroy) + registerEndpoint("/v1/session/renew/", (*HTTPServer).SessionRenew) + registerEndpoint("/v1/session/info/", (*HTTPServer).SessionGet) + registerEndpoint("/v1/session/node/", (*HTTPServer).SessionsForNode) + registerEndpoint("/v1/session/list", (*HTTPServer).SessionList) + registerEndpoint("/v1/status/leader", (*HTTPServer).StatusLeader) + registerEndpoint("/v1/status/peers", (*HTTPServer).StatusPeers) + registerEndpoint("/v1/snapshot", (*HTTPServer).Snapshot) + registerEndpoint("/v1/txn", (*HTTPServer).Txn) +} diff --git a/agent/http_oss_test.go b/agent/http_oss_test.go new file mode 100644 index 0000000000..d2e04419b7 --- /dev/null +++ b/agent/http_oss_test.go @@ -0,0 +1,113 @@ +package agent + +import ( + "fmt" + "net/http" + "strings" + "testing" + + "github.com/hashicorp/consul/logger" +) + +func TestHTTPAPI_MethodNotAllowed_OSS(t *testing.T) { + tests := []struct { + methods, uri string + }{ + {"PUT", "/v1/acl/bootstrap"}, + {"PUT", "/v1/acl/create"}, + {"PUT", "/v1/acl/update"}, + {"PUT", "/v1/acl/destroy/"}, + {"GET", "/v1/acl/info/"}, + {"PUT", "/v1/acl/clone/"}, + {"GET", "/v1/acl/list"}, + {"GET", "/v1/acl/replication"}, + {"PUT", "/v1/agent/token/"}, + {"GET", "/v1/agent/self"}, + {"GET", "/v1/agent/members"}, + {"PUT", "/v1/agent/check/deregister/"}, + {"PUT", "/v1/agent/check/fail/"}, + {"PUT", "/v1/agent/check/pass/"}, + {"PUT", "/v1/agent/check/register"}, + {"PUT", "/v1/agent/check/update/"}, + {"PUT", "/v1/agent/check/warn/"}, + {"GET", "/v1/agent/checks"}, + {"PUT", "/v1/agent/force-leave/"}, + {"PUT", "/v1/agent/join/"}, + {"PUT", "/v1/agent/leave"}, + {"PUT", "/v1/agent/maintenance"}, + {"GET", "/v1/agent/metrics"}, + // {"GET", "/v1/agent/monitor"}, // requires LogWriter. Hangs if LogWriter is provided + {"PUT", "/v1/agent/reload"}, + {"PUT", "/v1/agent/service/deregister/"}, + {"PUT", "/v1/agent/service/maintenance/"}, + {"PUT", "/v1/agent/service/register"}, + {"GET", "/v1/agent/services"}, + {"GET", "/v1/catalog/datacenters"}, + {"PUT", "/v1/catalog/deregister"}, + {"GET", "/v1/catalog/node/"}, + {"GET", "/v1/catalog/nodes"}, + {"PUT", "/v1/catalog/register"}, + {"GET", "/v1/catalog/service/"}, + {"GET", "/v1/catalog/services"}, + {"GET", "/v1/coordinate/datacenters"}, + {"GET", "/v1/coordinate/nodes"}, + {"GET", "/v1/coordinate/node/"}, + {"PUT", "/v1/event/fire/"}, + {"GET", "/v1/event/list"}, + {"GET", "/v1/health/checks/"}, + {"GET", "/v1/health/node/"}, + {"GET", "/v1/health/service/"}, + {"GET", "/v1/health/state/"}, + {"GET", "/v1/internal/ui/node/"}, + {"GET", "/v1/internal/ui/nodes"}, + {"GET", "/v1/internal/ui/services"}, + {"GET PUT DELETE", "/v1/kv/"}, + {"GET PUT", "/v1/operator/autopilot/configuration"}, + {"GET", "/v1/operator/autopilot/health"}, + {"GET POST PUT DELETE", "/v1/operator/keyring"}, + {"GET", "/v1/operator/raft/configuration"}, + {"DELETE", "/v1/operator/raft/peer"}, + {"GET POST", "/v1/query"}, + {"GET PUT DELETE", "/v1/query/"}, + {"GET", "/v1/query/xxx/execute"}, + {"GET", "/v1/query/xxx/explain"}, + {"PUT", "/v1/session/create"}, + {"PUT", "/v1/session/destroy/"}, + {"GET", "/v1/session/info/"}, + {"GET", "/v1/session/list"}, + {"GET", "/v1/session/node/"}, + {"PUT", "/v1/session/renew/"}, + {"GET PUT", "/v1/snapshot"}, + {"GET", "/v1/status/leader"}, + // {"GET", "/v1/status/peers"},// hangs + {"PUT", "/v1/txn"}, + } + + a := NewTestAgent(t.Name(), `acl_datacenter = "dc1"`) + a.Agent.LogWriter = logger.NewLogWriter(512) + defer a.Shutdown() + + all := []string{"GET", "PUT", "POST", "DELETE", "HEAD"} + client := http.Client{} + + for _, tt := range tests { + for _, m := range all { + t.Run(m+" "+tt.uri, func(t *testing.T) { + uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), tt.uri) + req, _ := http.NewRequest(m, uri, nil) + resp, err := client.Do(req) + if err != nil { + t.Fatal("client.Do failed: ", err) + } + + allowed := strings.Contains(tt.methods, m) + if allowed && resp.StatusCode == http.StatusMethodNotAllowed { + t.Fatalf("method allowed: got status code %d want any other code", resp.StatusCode) + } + if !allowed && resp.StatusCode != http.StatusMethodNotAllowed { + t.Fatalf("method not allowed: got status code %d want %d", resp.StatusCode, http.StatusMethodNotAllowed) + } + }) + } + } +} diff --git a/agent/http_test.go b/agent/http_test.go index a80c6df12a..8a347d046e 100644 --- a/agent/http_test.go +++ b/agent/http_test.go @@ -20,7 +20,6 @@ import ( "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" - "github.com/hashicorp/consul/logger" "github.com/hashicorp/consul/testutil" "github.com/hashicorp/go-cleanhttp" "golang.org/x/net/http2" @@ -390,115 +389,6 @@ func TestHTTPAPIResponseHeaders(t *testing.T) { } } -func TestHTTPAPI_MethodNotAllowed(t *testing.T) { - tests := []struct { - methods, uri string - }{ - {"PUT", "/v1/acl/bootstrap"}, - {"PUT", "/v1/acl/create"}, - {"PUT", "/v1/acl/update"}, - {"PUT", "/v1/acl/destroy/"}, - {"GET", "/v1/acl/info/"}, - {"PUT", "/v1/acl/clone/"}, - {"GET", "/v1/acl/list"}, - {"GET", "/v1/acl/replication"}, - {"PUT", "/v1/agent/token/"}, - {"GET", "/v1/agent/self"}, - {"GET", "/v1/agent/members"}, - {"PUT", "/v1/agent/check/deregister/"}, - {"PUT", "/v1/agent/check/fail/"}, - {"PUT", "/v1/agent/check/pass/"}, - {"PUT", "/v1/agent/check/register"}, - {"PUT", "/v1/agent/check/update/"}, - {"PUT", "/v1/agent/check/warn/"}, - {"GET", "/v1/agent/checks"}, - {"PUT", "/v1/agent/force-leave/"}, - {"PUT", "/v1/agent/join/"}, - {"PUT", "/v1/agent/leave"}, - {"PUT", "/v1/agent/maintenance"}, - {"GET", "/v1/agent/metrics"}, - // {"GET", "/v1/agent/monitor"}, // requires LogWriter. Hangs if LogWriter is provided - {"PUT", "/v1/agent/reload"}, - {"PUT", "/v1/agent/service/deregister/"}, - {"PUT", "/v1/agent/service/maintenance/"}, - {"PUT", "/v1/agent/service/register"}, - {"GET", "/v1/agent/services"}, - {"GET", "/v1/catalog/datacenters"}, - {"PUT", "/v1/catalog/deregister"}, - {"GET", "/v1/catalog/node/"}, - {"GET", "/v1/catalog/nodes"}, - {"PUT", "/v1/catalog/register"}, - {"GET", "/v1/catalog/service/"}, - {"GET", "/v1/catalog/services"}, - {"GET", "/v1/coordinate/datacenters"}, - {"GET", "/v1/coordinate/nodes"}, - {"GET", "/v1/coordinate/node/"}, - {"PUT", "/v1/event/fire/"}, - {"GET", "/v1/event/list"}, - {"GET", "/v1/health/checks/"}, - {"GET", "/v1/health/node/"}, - {"GET", "/v1/health/service/"}, - {"GET", "/v1/health/state/"}, - {"GET", "/v1/internal/ui/node/"}, - {"GET", "/v1/internal/ui/nodes"}, - {"GET", "/v1/internal/ui/services"}, - {"GET PUT DELETE", "/v1/kv/"}, - {"GET PUT", "/v1/operator/autopilot/configuration"}, - {"GET", "/v1/operator/autopilot/health"}, - {"GET POST PUT DELETE", "/v1/operator/keyring"}, - {"GET", "/v1/operator/raft/configuration"}, - {"DELETE", "/v1/operator/raft/peer"}, - {"GET POST", "/v1/query"}, - {"GET PUT DELETE", "/v1/query/"}, - {"GET", "/v1/query/xxx/execute"}, - {"GET", "/v1/query/xxx/explain"}, - {"PUT", "/v1/session/create"}, - {"PUT", "/v1/session/destroy/"}, - {"GET", "/v1/session/info/"}, - {"GET", "/v1/session/list"}, - {"GET", "/v1/session/node/"}, - {"PUT", "/v1/session/renew/"}, - {"GET PUT", "/v1/snapshot"}, - {"GET", "/v1/status/leader"}, - // {"GET", "/v1/status/peers"},// hangs - {"PUT", "/v1/txn"}, - - // enterprise only - // {"GET POST", "/v1/operator/area"}, - // {"GET PUT DELETE", "/v1/operator/area/"}, - // {"GET", "/v1/operator/area/xxx/members"}, - } - - a := NewTestAgent(t.Name(), `acl_datacenter = "dc1"`) - a.Agent.LogWriter = logger.NewLogWriter(512) - defer a.Shutdown() - - all := []string{"GET", "PUT", "POST", "DELETE", "HEAD"} - client := http.Client{} - - for _, tt := range tests { - for _, m := range all { - - t.Run(m+" "+tt.uri, func(t *testing.T) { - uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), tt.uri) - req, _ := http.NewRequest(m, uri, nil) - resp, err := client.Do(req) - if err != nil { - t.Fatal("client.Do failed: ", err) - } - - allowed := strings.Contains(tt.methods, m) - if allowed && resp.StatusCode == http.StatusMethodNotAllowed { - t.Fatalf("method allowed: got status code %d want any other code", resp.StatusCode) - } - if !allowed && resp.StatusCode != http.StatusMethodNotAllowed { - t.Fatalf("method not allowed: got status code %d want %d", resp.StatusCode, http.StatusMethodNotAllowed) - } - }) - } - } -} - func TestContentTypeIsJSON(t *testing.T) { t.Parallel() a := NewTestAgent(t.Name(), "") From 93ff33b1be8b69294f00117fc29251110b885aa9 Mon Sep 17 00:00:00 2001 From: James Phillips Date: Tue, 28 Nov 2017 16:30:07 -0800 Subject: [PATCH 08/14] Creates a registration mechanism for RPC endpoints. --- agent/consul/server.go | 58 ++++++++++---------------------------- agent/consul/server_oss.go | 15 ++++++++++ 2 files changed, 30 insertions(+), 43 deletions(-) create mode 100644 agent/consul/server_oss.go diff --git a/agent/consul/server.go b/agent/consul/server.go index 09c13a8e7a..5edc49294b 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -112,9 +112,6 @@ type Server struct { // Connection pool to other consul servers connPool *pool.ConnPool - // Endpoints holds our RPC endpoints - endpoints endpoints - // eventChLAN is used to receive events from the // serf cluster in the datacenter eventChLAN chan serf.Event @@ -218,21 +215,6 @@ type Server struct { shutdownLock sync.Mutex } -// Holds the RPC endpoints -type endpoints struct { - ACL *ACL - Catalog *Catalog - Coordinate *Coordinate - Health *Health - Internal *Internal - KVS *KVS - Operator *Operator - PreparedQuery *PreparedQuery - Session *Session - Status *Status - Txn *Txn -} - func NewServer(config *Config) (*Server, error) { return NewServerLogger(config, nil, new(token.Store)) } @@ -624,33 +606,23 @@ func (s *Server) setupRaft() error { return nil } +// endpointFactory is a function that returns an RPC endpoint bound to the given +// server. +type factory func(s *Server) interface{} + +// endpoints is a list of registered RPC endpoint factories. +var endpoints []factory + +// registerEndpoint registers a new RPC endpoint factory. +func registerEndpoint(fn factory) { + endpoints = append(endpoints, fn) +} + // setupRPC is used to setup the RPC listener func (s *Server) setupRPC(tlsWrap tlsutil.DCWrapper) error { - // Create endpoints - s.endpoints.ACL = &ACL{s} - s.endpoints.Catalog = &Catalog{s} - s.endpoints.Coordinate = NewCoordinate(s) - s.endpoints.Health = &Health{s} - s.endpoints.Internal = &Internal{s} - s.endpoints.KVS = &KVS{s} - s.endpoints.Operator = &Operator{s} - s.endpoints.PreparedQuery = &PreparedQuery{s} - s.endpoints.Session = &Session{s} - s.endpoints.Status = &Status{s} - s.endpoints.Txn = &Txn{s} - - // Register the handlers - s.rpcServer.Register(s.endpoints.ACL) - s.rpcServer.Register(s.endpoints.Catalog) - s.rpcServer.Register(s.endpoints.Coordinate) - s.rpcServer.Register(s.endpoints.Health) - s.rpcServer.Register(s.endpoints.Internal) - s.rpcServer.Register(s.endpoints.KVS) - s.rpcServer.Register(s.endpoints.Operator) - s.rpcServer.Register(s.endpoints.PreparedQuery) - s.rpcServer.Register(s.endpoints.Session) - s.rpcServer.Register(s.endpoints.Status) - s.rpcServer.Register(s.endpoints.Txn) + for _, fn := range endpoints { + s.rpcServer.Register(fn(s)) + } ln, err := net.ListenTCP("tcp", s.config.RPCAddr) if err != nil { diff --git a/agent/consul/server_oss.go b/agent/consul/server_oss.go new file mode 100644 index 0000000000..05c02e46c8 --- /dev/null +++ b/agent/consul/server_oss.go @@ -0,0 +1,15 @@ +package consul + +func init() { + registerEndpoint(func(s *Server) interface{} { return &ACL{s} }) + registerEndpoint(func(s *Server) interface{} { return &Catalog{s} }) + registerEndpoint(func(s *Server) interface{} { return NewCoordinate(s) }) + registerEndpoint(func(s *Server) interface{} { return &Health{s} }) + registerEndpoint(func(s *Server) interface{} { return &Internal{s} }) + registerEndpoint(func(s *Server) interface{} { return &KVS{s} }) + registerEndpoint(func(s *Server) interface{} { return &Operator{s} }) + registerEndpoint(func(s *Server) interface{} { return &PreparedQuery{s} }) + registerEndpoint(func(s *Server) interface{} { return &Session{s} }) + registerEndpoint(func(s *Server) interface{} { return &Status{s} }) + registerEndpoint(func(s *Server) interface{} { return &Txn{s} }) +} From aa61159b745188bbc716b1572b6a143c269fb73b Mon Sep 17 00:00:00 2001 From: James Phillips Date: Tue, 28 Nov 2017 17:03:34 -0800 Subject: [PATCH 09/14] Creates a registration mechanism for schemas. This also splits out the registration into the table-specific source files. --- agent/consul/state/acl.go | 5 + agent/consul/state/autopilot.go | 22 ++ agent/consul/state/catalog.go | 175 +++++++++++ agent/consul/state/coordinate.go | 43 +++ agent/consul/state/kvs.go | 51 ++++ agent/consul/state/prepared_query.go | 45 +++ agent/consul/state/schema.go | 427 +-------------------------- agent/consul/state/session.go | 87 ++++++ 8 files changed, 441 insertions(+), 414 deletions(-) diff --git a/agent/consul/state/acl.go b/agent/consul/state/acl.go index 8911b604d5..283b4b20b0 100644 --- a/agent/consul/state/acl.go +++ b/agent/consul/state/acl.go @@ -44,6 +44,11 @@ func aclsBootstrapTableSchema() *memdb.TableSchema { } } +func init() { + registerSchema(aclsTableSchema) + registerSchema(aclsBootstrapTableSchema) +} + // ACLs is used to pull all the ACLs from the snapshot. func (s *Snapshot) ACLs() (memdb.ResultIterator, error) { iter, err := s.tx.Get("acls", "id") diff --git a/agent/consul/state/autopilot.go b/agent/consul/state/autopilot.go index 89f81a9848..21514e5be2 100644 --- a/agent/consul/state/autopilot.go +++ b/agent/consul/state/autopilot.go @@ -7,6 +7,28 @@ import ( "github.com/hashicorp/go-memdb" ) +// autopilotConfigTableSchema returns a new table schema used for storing +// the autopilot configuration +func autopilotConfigTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "autopilot-config", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: true, + Unique: true, + Indexer: &memdb.ConditionalIndex{ + Conditional: func(obj interface{}) (bool, error) { return true, nil }, + }, + }, + }, + } +} + +func init() { + registerSchema(autopilotConfigTableSchema) +} + // Autopilot is used to pull the autopilot config from the snapshot. func (s *Snapshot) Autopilot() (*structs.AutopilotConfig, error) { c, err := s.tx.First("autopilot-config", "id") diff --git a/agent/consul/state/catalog.go b/agent/consul/state/catalog.go index 73a1c75da7..9f066d4178 100644 --- a/agent/consul/state/catalog.go +++ b/agent/consul/state/catalog.go @@ -10,6 +10,181 @@ import ( "github.com/hashicorp/go-memdb" ) +// nodesTableSchema returns a new table schema used for storing node +// information. +func nodesTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "nodes", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + }, + "uuid": &memdb.IndexSchema{ + Name: "uuid", + AllowMissing: true, + Unique: true, + Indexer: &memdb.UUIDFieldIndex{ + Field: "ID", + }, + }, + "meta": &memdb.IndexSchema{ + Name: "meta", + AllowMissing: true, + Unique: false, + Indexer: &memdb.StringMapFieldIndex{ + Field: "Meta", + Lowercase: false, + }, + }, + }, + } +} + +// servicesTableSchema returns a new table schema used to store information +// about services. +func servicesTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "services", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.CompoundIndex{ + Indexes: []memdb.Indexer{ + &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + &memdb.StringFieldIndex{ + Field: "ServiceID", + Lowercase: true, + }, + }, + }, + }, + "node": &memdb.IndexSchema{ + Name: "node", + AllowMissing: false, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + }, + "service": &memdb.IndexSchema{ + Name: "service", + AllowMissing: true, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "ServiceName", + Lowercase: true, + }, + }, + }, + } +} + +// checksTableSchema returns a new table schema used for storing and indexing +// health check information. Health checks have a number of different attributes +// we want to filter by, so this table is a bit more complex. +func checksTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "checks", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.CompoundIndex{ + Indexes: []memdb.Indexer{ + &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + &memdb.StringFieldIndex{ + Field: "CheckID", + Lowercase: true, + }, + }, + }, + }, + "status": &memdb.IndexSchema{ + Name: "status", + AllowMissing: false, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "Status", + Lowercase: false, + }, + }, + "service": &memdb.IndexSchema{ + Name: "service", + AllowMissing: true, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "ServiceName", + Lowercase: true, + }, + }, + "node": &memdb.IndexSchema{ + Name: "node", + AllowMissing: true, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + }, + "node_service_check": &memdb.IndexSchema{ + Name: "node_service_check", + AllowMissing: true, + Unique: false, + Indexer: &memdb.CompoundIndex{ + Indexes: []memdb.Indexer{ + &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + &memdb.FieldSetIndex{ + Field: "ServiceID", + }, + }, + }, + }, + "node_service": &memdb.IndexSchema{ + Name: "node_service", + AllowMissing: true, + Unique: false, + Indexer: &memdb.CompoundIndex{ + Indexes: []memdb.Indexer{ + &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + &memdb.StringFieldIndex{ + Field: "ServiceID", + Lowercase: true, + }, + }, + }, + }, + }, + } +} + +func init() { + registerSchema(nodesTableSchema) + registerSchema(servicesTableSchema) + registerSchema(checksTableSchema) +} + const ( // minUUIDLookupLen is used as a minimum length of a node name required before // we test to see if the name is actually a UUID and perform an ID-based node diff --git a/agent/consul/state/coordinate.go b/agent/consul/state/coordinate.go index c810a8151f..95b27fb1c4 100644 --- a/agent/consul/state/coordinate.go +++ b/agent/consul/state/coordinate.go @@ -8,6 +8,49 @@ import ( "github.com/hashicorp/go-memdb" ) +// coordinatesTableSchema returns a new table schema used for storing +// network coordinates. +func coordinatesTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "coordinates", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.CompoundIndex{ + // AllowMissing is required since we allow + // Segment to be an empty string. + AllowMissing: true, + Indexes: []memdb.Indexer{ + &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + &memdb.StringFieldIndex{ + Field: "Segment", + Lowercase: true, + }, + }, + }, + }, + "node": &memdb.IndexSchema{ + Name: "node", + AllowMissing: false, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + }, + }, + } +} + +func init() { + registerSchema(coordinatesTableSchema) +} + // Coordinates is used to pull all the coordinates from the snapshot. func (s *Snapshot) Coordinates() (memdb.ResultIterator, error) { iter, err := s.tx.Get("coordinates", "id") diff --git a/agent/consul/state/kvs.go b/agent/consul/state/kvs.go index f4103bb04d..070a2fc815 100644 --- a/agent/consul/state/kvs.go +++ b/agent/consul/state/kvs.go @@ -9,6 +9,57 @@ import ( "github.com/hashicorp/go-memdb" ) +// kvsTableSchema returns a new table schema used for storing key/value data for +// Consul's kv store. +func kvsTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "kvs", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Key", + Lowercase: false, + }, + }, + "session": &memdb.IndexSchema{ + Name: "session", + AllowMissing: true, + Unique: false, + Indexer: &memdb.UUIDFieldIndex{ + Field: "Session", + }, + }, + }, + } +} + +// tombstonesTableSchema returns a new table schema used for storing tombstones +// during KV delete operations to prevent the index from sliding backwards. +func tombstonesTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "tombstones", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Key", + Lowercase: false, + }, + }, + }, + } +} + +func init() { + registerSchema(kvsTableSchema) + registerSchema(tombstonesTableSchema) +} + // KVs is used to pull the full list of KVS entries for use during snapshots. func (s *Snapshot) KVs() (memdb.ResultIterator, error) { iter, err := s.tx.Get("kvs", "id_prefix") diff --git a/agent/consul/state/prepared_query.go b/agent/consul/state/prepared_query.go index 843c9ba1d0..285355785e 100644 --- a/agent/consul/state/prepared_query.go +++ b/agent/consul/state/prepared_query.go @@ -9,6 +9,51 @@ import ( "github.com/hashicorp/go-memdb" ) +// preparedQueriesTableSchema returns a new table schema used for storing +// prepared queries. +func preparedQueriesTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "prepared-queries", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.UUIDFieldIndex{ + Field: "ID", + }, + }, + "name": &memdb.IndexSchema{ + Name: "name", + AllowMissing: true, + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Name", + Lowercase: true, + }, + }, + "template": &memdb.IndexSchema{ + Name: "template", + AllowMissing: true, + Unique: true, + Indexer: &PreparedQueryIndex{}, + }, + "session": &memdb.IndexSchema{ + Name: "session", + AllowMissing: true, + Unique: false, + Indexer: &memdb.UUIDFieldIndex{ + Field: "Session", + }, + }, + }, + } +} + +func init() { + registerSchema(preparedQueriesTableSchema) +} + // validUUID is used to check if a given string looks like a UUID var validUUID = regexp.MustCompile(`(?i)^[\da-f]{8}-[\da-f]{4}-[\da-f]{4}-[\da-f]{4}-[\da-f]{12}$`) diff --git a/agent/consul/state/schema.go b/agent/consul/state/schema.go index d7fa449e88..e24b23618c 100644 --- a/agent/consul/state/schema.go +++ b/agent/consul/state/schema.go @@ -10,6 +10,15 @@ import ( // new memdb schema structs for constructing an in-memory db. type schemaFn func() *memdb.TableSchema +// schemas is used to register schemas with the state store. +var schemas []schemaFn + +// registerSchema registers a new schema with the state store. This should +// get called at package init() time. +func registerSchema(fn schemaFn) { + schemas = append(schemas, fn) +} + // stateStoreSchema is used to return the combined schema for // the state store. func stateStoreSchema() *memdb.DBSchema { @@ -18,23 +27,6 @@ func stateStoreSchema() *memdb.DBSchema { Tables: make(map[string]*memdb.TableSchema), } - // Collect the needed schemas - schemas := []schemaFn{ - indexTableSchema, - nodesTableSchema, - servicesTableSchema, - checksTableSchema, - kvsTableSchema, - tombstonesTableSchema, - sessionsTableSchema, - sessionChecksTableSchema, - aclsTableSchema, - aclsBootstrapTableSchema, - coordinatesTableSchema, - preparedQueriesTableSchema, - autopilotConfigTableSchema, - } - // Add the tables to the root schema for _, fn := range schemas { schema := fn() @@ -46,8 +38,8 @@ func stateStoreSchema() *memdb.DBSchema { return db } -// indexTableSchema returns a new table schema used for -// tracking various indexes for the Raft log. +// indexTableSchema returns a new table schema used for tracking various indexes +// for the Raft log. func indexTableSchema() *memdb.TableSchema { return &memdb.TableSchema{ Name: "index", @@ -65,399 +57,6 @@ func indexTableSchema() *memdb.TableSchema { } } -// nodesTableSchema returns a new table schema used for -// storing node information. -func nodesTableSchema() *memdb.TableSchema { - return &memdb.TableSchema{ - Name: "nodes", - Indexes: map[string]*memdb.IndexSchema{ - "id": &memdb.IndexSchema{ - Name: "id", - AllowMissing: false, - Unique: true, - Indexer: &memdb.StringFieldIndex{ - Field: "Node", - Lowercase: true, - }, - }, - "uuid": &memdb.IndexSchema{ - Name: "uuid", - AllowMissing: true, - Unique: true, - Indexer: &memdb.UUIDFieldIndex{ - Field: "ID", - }, - }, - "meta": &memdb.IndexSchema{ - Name: "meta", - AllowMissing: true, - Unique: false, - Indexer: &memdb.StringMapFieldIndex{ - Field: "Meta", - Lowercase: false, - }, - }, - }, - } -} - -// servicesTableSchema returns a new TableSchema used to -// store information about services. -func servicesTableSchema() *memdb.TableSchema { - return &memdb.TableSchema{ - Name: "services", - Indexes: map[string]*memdb.IndexSchema{ - "id": &memdb.IndexSchema{ - Name: "id", - AllowMissing: false, - Unique: true, - Indexer: &memdb.CompoundIndex{ - Indexes: []memdb.Indexer{ - &memdb.StringFieldIndex{ - Field: "Node", - Lowercase: true, - }, - &memdb.StringFieldIndex{ - Field: "ServiceID", - Lowercase: true, - }, - }, - }, - }, - "node": &memdb.IndexSchema{ - Name: "node", - AllowMissing: false, - Unique: false, - Indexer: &memdb.StringFieldIndex{ - Field: "Node", - Lowercase: true, - }, - }, - "service": &memdb.IndexSchema{ - Name: "service", - AllowMissing: true, - Unique: false, - Indexer: &memdb.StringFieldIndex{ - Field: "ServiceName", - Lowercase: true, - }, - }, - }, - } -} - -// checksTableSchema returns a new table schema used for -// storing and indexing health check information. Health -// checks have a number of different attributes we want to -// filter by, so this table is a bit more complex. -func checksTableSchema() *memdb.TableSchema { - return &memdb.TableSchema{ - Name: "checks", - Indexes: map[string]*memdb.IndexSchema{ - "id": &memdb.IndexSchema{ - Name: "id", - AllowMissing: false, - Unique: true, - Indexer: &memdb.CompoundIndex{ - Indexes: []memdb.Indexer{ - &memdb.StringFieldIndex{ - Field: "Node", - Lowercase: true, - }, - &memdb.StringFieldIndex{ - Field: "CheckID", - Lowercase: true, - }, - }, - }, - }, - "status": &memdb.IndexSchema{ - Name: "status", - AllowMissing: false, - Unique: false, - Indexer: &memdb.StringFieldIndex{ - Field: "Status", - Lowercase: false, - }, - }, - "service": &memdb.IndexSchema{ - Name: "service", - AllowMissing: true, - Unique: false, - Indexer: &memdb.StringFieldIndex{ - Field: "ServiceName", - Lowercase: true, - }, - }, - "node": &memdb.IndexSchema{ - Name: "node", - AllowMissing: true, - Unique: false, - Indexer: &memdb.StringFieldIndex{ - Field: "Node", - Lowercase: true, - }, - }, - "node_service_check": &memdb.IndexSchema{ - Name: "node_service_check", - AllowMissing: true, - Unique: false, - Indexer: &memdb.CompoundIndex{ - Indexes: []memdb.Indexer{ - &memdb.StringFieldIndex{ - Field: "Node", - Lowercase: true, - }, - &memdb.FieldSetIndex{ - Field: "ServiceID", - }, - }, - }, - }, - "node_service": &memdb.IndexSchema{ - Name: "node_service", - AllowMissing: true, - Unique: false, - Indexer: &memdb.CompoundIndex{ - Indexes: []memdb.Indexer{ - &memdb.StringFieldIndex{ - Field: "Node", - Lowercase: true, - }, - &memdb.StringFieldIndex{ - Field: "ServiceID", - Lowercase: true, - }, - }, - }, - }, - }, - } -} - -// kvsTableSchema returns a new table schema used for storing -// key/value data from consul's kv store. -func kvsTableSchema() *memdb.TableSchema { - return &memdb.TableSchema{ - Name: "kvs", - Indexes: map[string]*memdb.IndexSchema{ - "id": &memdb.IndexSchema{ - Name: "id", - AllowMissing: false, - Unique: true, - Indexer: &memdb.StringFieldIndex{ - Field: "Key", - Lowercase: false, - }, - }, - "session": &memdb.IndexSchema{ - Name: "session", - AllowMissing: true, - Unique: false, - Indexer: &memdb.UUIDFieldIndex{ - Field: "Session", - }, - }, - }, - } -} - -// tombstonesTableSchema returns a new table schema used for -// storing tombstones during KV delete operations to prevent -// the index from sliding backwards. -func tombstonesTableSchema() *memdb.TableSchema { - return &memdb.TableSchema{ - Name: "tombstones", - Indexes: map[string]*memdb.IndexSchema{ - "id": &memdb.IndexSchema{ - Name: "id", - AllowMissing: false, - Unique: true, - Indexer: &memdb.StringFieldIndex{ - Field: "Key", - Lowercase: false, - }, - }, - }, - } -} - -// sessionsTableSchema returns a new TableSchema used for -// storing session information. -func sessionsTableSchema() *memdb.TableSchema { - return &memdb.TableSchema{ - Name: "sessions", - Indexes: map[string]*memdb.IndexSchema{ - "id": &memdb.IndexSchema{ - Name: "id", - AllowMissing: false, - Unique: true, - Indexer: &memdb.UUIDFieldIndex{ - Field: "ID", - }, - }, - "node": &memdb.IndexSchema{ - Name: "node", - AllowMissing: false, - Unique: false, - Indexer: &memdb.StringFieldIndex{ - Field: "Node", - Lowercase: true, - }, - }, - }, - } -} - -// sessionChecksTableSchema returns a new table schema used -// for storing session checks. -func sessionChecksTableSchema() *memdb.TableSchema { - return &memdb.TableSchema{ - Name: "session_checks", - Indexes: map[string]*memdb.IndexSchema{ - "id": &memdb.IndexSchema{ - Name: "id", - AllowMissing: false, - Unique: true, - Indexer: &memdb.CompoundIndex{ - Indexes: []memdb.Indexer{ - &memdb.StringFieldIndex{ - Field: "Node", - Lowercase: true, - }, - &memdb.StringFieldIndex{ - Field: "CheckID", - Lowercase: true, - }, - &memdb.UUIDFieldIndex{ - Field: "Session", - }, - }, - }, - }, - "node_check": &memdb.IndexSchema{ - Name: "node_check", - AllowMissing: false, - Unique: false, - Indexer: &memdb.CompoundIndex{ - Indexes: []memdb.Indexer{ - &memdb.StringFieldIndex{ - Field: "Node", - Lowercase: true, - }, - &memdb.StringFieldIndex{ - Field: "CheckID", - Lowercase: true, - }, - }, - }, - }, - "session": &memdb.IndexSchema{ - Name: "session", - AllowMissing: false, - Unique: false, - Indexer: &memdb.UUIDFieldIndex{ - Field: "Session", - }, - }, - }, - } -} - -// coordinatesTableSchema returns a new table schema used for storing -// network coordinates. -func coordinatesTableSchema() *memdb.TableSchema { - return &memdb.TableSchema{ - Name: "coordinates", - Indexes: map[string]*memdb.IndexSchema{ - "id": &memdb.IndexSchema{ - Name: "id", - AllowMissing: false, - Unique: true, - Indexer: &memdb.CompoundIndex{ - // AllowMissing is required since we allow - // Segment to be an empty string. - AllowMissing: true, - Indexes: []memdb.Indexer{ - &memdb.StringFieldIndex{ - Field: "Node", - Lowercase: true, - }, - &memdb.StringFieldIndex{ - Field: "Segment", - Lowercase: true, - }, - }, - }, - }, - "node": &memdb.IndexSchema{ - Name: "node", - AllowMissing: false, - Unique: false, - Indexer: &memdb.StringFieldIndex{ - Field: "Node", - Lowercase: true, - }, - }, - }, - } -} - -// preparedQueriesTableSchema returns a new table schema used for storing -// prepared queries. -func preparedQueriesTableSchema() *memdb.TableSchema { - return &memdb.TableSchema{ - Name: "prepared-queries", - Indexes: map[string]*memdb.IndexSchema{ - "id": &memdb.IndexSchema{ - Name: "id", - AllowMissing: false, - Unique: true, - Indexer: &memdb.UUIDFieldIndex{ - Field: "ID", - }, - }, - "name": &memdb.IndexSchema{ - Name: "name", - AllowMissing: true, - Unique: true, - Indexer: &memdb.StringFieldIndex{ - Field: "Name", - Lowercase: true, - }, - }, - "template": &memdb.IndexSchema{ - Name: "template", - AllowMissing: true, - Unique: true, - Indexer: &PreparedQueryIndex{}, - }, - "session": &memdb.IndexSchema{ - Name: "session", - AllowMissing: true, - Unique: false, - Indexer: &memdb.UUIDFieldIndex{ - Field: "Session", - }, - }, - }, - } -} - -// autopilotConfigTableSchema returns a new table schema used for storing -// the autopilot configuration -func autopilotConfigTableSchema() *memdb.TableSchema { - return &memdb.TableSchema{ - Name: "autopilot-config", - Indexes: map[string]*memdb.IndexSchema{ - "id": &memdb.IndexSchema{ - Name: "id", - AllowMissing: true, - Unique: true, - Indexer: &memdb.ConditionalIndex{ - Conditional: func(obj interface{}) (bool, error) { return true, nil }, - }, - }, - }, - } +func init() { + registerSchema(indexTableSchema) } diff --git a/agent/consul/state/session.go b/agent/consul/state/session.go index ed958d827a..9775ff639a 100644 --- a/agent/consul/state/session.go +++ b/agent/consul/state/session.go @@ -9,6 +9,93 @@ import ( "github.com/hashicorp/go-memdb" ) +// sessionsTableSchema returns a new table schema used for storing session +// information. +func sessionsTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "sessions", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.UUIDFieldIndex{ + Field: "ID", + }, + }, + "node": &memdb.IndexSchema{ + Name: "node", + AllowMissing: false, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + }, + }, + } +} + +// sessionChecksTableSchema returns a new table schema used for storing session +// checks. +func sessionChecksTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "session_checks", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.CompoundIndex{ + Indexes: []memdb.Indexer{ + &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + &memdb.StringFieldIndex{ + Field: "CheckID", + Lowercase: true, + }, + &memdb.UUIDFieldIndex{ + Field: "Session", + }, + }, + }, + }, + "node_check": &memdb.IndexSchema{ + Name: "node_check", + AllowMissing: false, + Unique: false, + Indexer: &memdb.CompoundIndex{ + Indexes: []memdb.Indexer{ + &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + &memdb.StringFieldIndex{ + Field: "CheckID", + Lowercase: true, + }, + }, + }, + }, + "session": &memdb.IndexSchema{ + Name: "session", + AllowMissing: false, + Unique: false, + Indexer: &memdb.UUIDFieldIndex{ + Field: "Session", + }, + }, + }, + } +} + +func init() { + registerSchema(sessionsTableSchema) + registerSchema(sessionChecksTableSchema) +} + // Sessions is used to pull the full list of sessions for use during snapshots. func (s *Snapshot) Sessions() (memdb.ResultIterator, error) { iter, err := s.tx.Get("sessions", "id") From e810697e06d08b79318281c507d3688ecca4d679 Mon Sep 17 00:00:00 2001 From: James Phillips Date: Tue, 28 Nov 2017 17:26:16 -0800 Subject: [PATCH 10/14] Resolves an FSM snapshot TODO. This adds checks for sink write calls before we continue the refactor, which will resolve the other TODO comment we deleted as part of this change. --- agent/consul/fsm.go | 51 +++++++++++++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/agent/consul/fsm.go b/agent/consul/fsm.go index 0d272546f7..4c67ddae4e 100644 --- a/agent/consul/fsm.go +++ b/agent/consul/fsm.go @@ -15,13 +15,6 @@ import ( "github.com/hashicorp/raft" ) -// TODO (slackpad) - There are two refactors we should do here: -// -// 1. Register the different types from the state store and make the FSM more -// generic, especially around snapshot/restore. Those should really just -// pass the encoder into a WriteSnapshot() kind of method. -// 2. Check all the error return values from all the Write() calls. - // msgpackHandle is a shared handle for encoding/decoding msgpack payloads var msgpackHandle = &codec.MsgpackHandle{} @@ -592,7 +585,9 @@ func (s *consulSnapshot) persistNodes(sink raft.SnapshotSink, } // Register the node itself - sink.Write([]byte{byte(structs.RegisterRequestType)}) + if _, err := sink.Write([]byte{byte(structs.RegisterRequestType)}); err != nil { + return err + } if err := encoder.Encode(&req); err != nil { return err } @@ -603,7 +598,9 @@ func (s *consulSnapshot) persistNodes(sink raft.SnapshotSink, return err } for service := services.Next(); service != nil; service = services.Next() { - sink.Write([]byte{byte(structs.RegisterRequestType)}) + if _, err := sink.Write([]byte{byte(structs.RegisterRequestType)}); err != nil { + return err + } req.Service = service.(*structs.ServiceNode).ToNodeService() if err := encoder.Encode(&req); err != nil { return err @@ -617,7 +614,9 @@ func (s *consulSnapshot) persistNodes(sink raft.SnapshotSink, return err } for check := checks.Next(); check != nil; check = checks.Next() { - sink.Write([]byte{byte(structs.RegisterRequestType)}) + if _, err := sink.Write([]byte{byte(structs.RegisterRequestType)}); err != nil { + return err + } req.Check = check.(*structs.HealthCheck) if err := encoder.Encode(&req); err != nil { return err @@ -633,7 +632,9 @@ func (s *consulSnapshot) persistNodes(sink raft.SnapshotSink, return err } for coord := coords.Next(); coord != nil; coord = coords.Next() { - sink.Write([]byte{byte(structs.CoordinateBatchUpdateType)}) + if _, err := sink.Write([]byte{byte(structs.CoordinateBatchUpdateType)}); err != nil { + return err + } updates := structs.Coordinates{coord.(*structs.Coordinate)} if err := encoder.Encode(&updates); err != nil { return err @@ -650,7 +651,9 @@ func (s *consulSnapshot) persistSessions(sink raft.SnapshotSink, } for session := sessions.Next(); session != nil; session = sessions.Next() { - sink.Write([]byte{byte(structs.SessionRequestType)}) + if _, err := sink.Write([]byte{byte(structs.SessionRequestType)}); err != nil { + return err + } if err := encoder.Encode(session.(*structs.Session)); err != nil { return err } @@ -666,7 +669,9 @@ func (s *consulSnapshot) persistACLs(sink raft.SnapshotSink, } for acl := acls.Next(); acl != nil; acl = acls.Next() { - sink.Write([]byte{byte(structs.ACLRequestType)}) + if _, err := sink.Write([]byte{byte(structs.ACLRequestType)}); err != nil { + return err + } if err := encoder.Encode(acl.(*structs.ACL)); err != nil { return err } @@ -677,7 +682,9 @@ func (s *consulSnapshot) persistACLs(sink raft.SnapshotSink, return err } if bs != nil { - sink.Write([]byte{byte(structs.ACLBootstrapRequestType)}) + if _, err := sink.Write([]byte{byte(structs.ACLBootstrapRequestType)}); err != nil { + return err + } if err := encoder.Encode(bs); err != nil { return err } @@ -694,7 +701,9 @@ func (s *consulSnapshot) persistKVs(sink raft.SnapshotSink, } for entry := entries.Next(); entry != nil; entry = entries.Next() { - sink.Write([]byte{byte(structs.KVSRequestType)}) + if _, err := sink.Write([]byte{byte(structs.KVSRequestType)}); err != nil { + return err + } if err := encoder.Encode(entry.(*structs.DirEntry)); err != nil { return err } @@ -710,7 +719,9 @@ func (s *consulSnapshot) persistTombstones(sink raft.SnapshotSink, } for stone := stones.Next(); stone != nil; stone = stones.Next() { - sink.Write([]byte{byte(structs.TombstoneRequestType)}) + if _, err := sink.Write([]byte{byte(structs.TombstoneRequestType)}); err != nil { + return err + } // For historical reasons, these are serialized in the snapshots // as KV entries. We want to keep the snapshot format compatible @@ -737,7 +748,9 @@ func (s *consulSnapshot) persistPreparedQueries(sink raft.SnapshotSink, } for _, query := range queries { - sink.Write([]byte{byte(structs.PreparedQueryRequestType)}) + if _, err := sink.Write([]byte{byte(structs.PreparedQueryRequestType)}); err != nil { + return err + } if err := encoder.Encode(query); err != nil { return err } @@ -755,7 +768,9 @@ func (s *consulSnapshot) persistAutopilot(sink raft.SnapshotSink, return nil } - sink.Write([]byte{byte(structs.AutopilotRequestType)}) + if _, err := sink.Write([]byte{byte(structs.AutopilotRequestType)}); err != nil { + return err + } if err := encoder.Encode(autopilot); err != nil { return err } From 78292662d7e78a14120f7b0dcc1ed4bfed362bd4 Mon Sep 17 00:00:00 2001 From: James Phillips Date: Tue, 28 Nov 2017 18:01:17 -0800 Subject: [PATCH 11/14] Moves the FSM into its own package. This will help make it clearer what happens when we add some registration plumbing for the different operations and snapshots. --- agent/consul/{ => fsm}/fsm.go | 40 ++++++++++---------- agent/consul/{ => fsm}/fsm_test.go | 59 ++++++++++++++++++------------ agent/consul/issue_test.go | 13 ++++++- agent/consul/rpc_test.go | 18 +++++++++ agent/consul/server.go | 7 ++-- agent/consul/session_ttl_test.go | 10 +++++ 6 files changed, 100 insertions(+), 47 deletions(-) rename agent/consul/{ => fsm}/fsm.go (94%) rename agent/consul/{ => fsm}/fsm_test.go (96%) diff --git a/agent/consul/fsm.go b/agent/consul/fsm/fsm.go similarity index 94% rename from agent/consul/fsm.go rename to agent/consul/fsm/fsm.go index 4c67ddae4e..6defe6ca84 100644 --- a/agent/consul/fsm.go +++ b/agent/consul/fsm/fsm.go @@ -1,4 +1,4 @@ -package consul +package fsm import ( "fmt" @@ -18,10 +18,10 @@ import ( // msgpackHandle is a shared handle for encoding/decoding msgpack payloads var msgpackHandle = &codec.MsgpackHandle{} -// consulFSM implements a finite state machine that is used +// FSM implements a finite state machine that is used // along with Raft to provide strong consistency. We implement // this outside the Server to avoid exposing this outside the package. -type consulFSM struct { +type FSM struct { logOutput io.Writer logger *log.Logger path string @@ -50,14 +50,14 @@ type snapshotHeader struct { LastIndex uint64 } -// NewFSM is used to construct a new FSM with a blank state -func NewFSM(gc *state.TombstoneGC, logOutput io.Writer) (*consulFSM, error) { +// New is used to construct a new FSM with a blank state. +func New(gc *state.TombstoneGC, logOutput io.Writer) (*FSM, error) { stateNew, err := state.NewStateStore(gc) if err != nil { return nil, err } - fsm := &consulFSM{ + fsm := &FSM{ logOutput: logOutput, logger: log.New(logOutput, "", log.LstdFlags), state: stateNew, @@ -67,13 +67,13 @@ func NewFSM(gc *state.TombstoneGC, logOutput io.Writer) (*consulFSM, error) { } // State is used to return a handle to the current state -func (c *consulFSM) State() *state.Store { +func (c *FSM) State() *state.Store { c.stateLock.RLock() defer c.stateLock.RUnlock() return c.state } -func (c *consulFSM) Apply(log *raft.Log) interface{} { +func (c *FSM) Apply(log *raft.Log) interface{} { buf := log.Data msgType := structs.MessageType(buf[0]) @@ -116,7 +116,7 @@ func (c *consulFSM) Apply(log *raft.Log) interface{} { } } -func (c *consulFSM) applyRegister(buf []byte, index uint64) interface{} { +func (c *FSM) applyRegister(buf []byte, index uint64) interface{} { defer metrics.MeasureSince([]string{"consul", "fsm", "register"}, time.Now()) defer metrics.MeasureSince([]string{"fsm", "register"}, time.Now()) var req structs.RegisterRequest @@ -132,7 +132,7 @@ func (c *consulFSM) applyRegister(buf []byte, index uint64) interface{} { return nil } -func (c *consulFSM) applyDeregister(buf []byte, index uint64) interface{} { +func (c *FSM) applyDeregister(buf []byte, index uint64) interface{} { defer metrics.MeasureSince([]string{"consul", "fsm", "deregister"}, time.Now()) defer metrics.MeasureSince([]string{"fsm", "deregister"}, time.Now()) var req structs.DeregisterRequest @@ -162,7 +162,7 @@ func (c *consulFSM) applyDeregister(buf []byte, index uint64) interface{} { return nil } -func (c *consulFSM) applyKVSOperation(buf []byte, index uint64) interface{} { +func (c *FSM) applyKVSOperation(buf []byte, index uint64) interface{} { var req structs.KVSRequest if err := structs.Decode(buf, &req); err != nil { panic(fmt.Errorf("failed to decode request: %v", err)) @@ -209,7 +209,7 @@ func (c *consulFSM) applyKVSOperation(buf []byte, index uint64) interface{} { } } -func (c *consulFSM) applySessionOperation(buf []byte, index uint64) interface{} { +func (c *FSM) applySessionOperation(buf []byte, index uint64) interface{} { var req structs.SessionRequest if err := structs.Decode(buf, &req); err != nil { panic(fmt.Errorf("failed to decode request: %v", err)) @@ -232,7 +232,7 @@ func (c *consulFSM) applySessionOperation(buf []byte, index uint64) interface{} } } -func (c *consulFSM) applyACLOperation(buf []byte, index uint64) interface{} { +func (c *FSM) applyACLOperation(buf []byte, index uint64) interface{} { var req structs.ACLRequest if err := structs.Decode(buf, &req); err != nil { panic(fmt.Errorf("failed to decode request: %v", err)) @@ -266,7 +266,7 @@ func (c *consulFSM) applyACLOperation(buf []byte, index uint64) interface{} { } } -func (c *consulFSM) applyTombstoneOperation(buf []byte, index uint64) interface{} { +func (c *FSM) applyTombstoneOperation(buf []byte, index uint64) interface{} { var req structs.TombstoneRequest if err := structs.Decode(buf, &req); err != nil { panic(fmt.Errorf("failed to decode request: %v", err)) @@ -288,7 +288,7 @@ func (c *consulFSM) applyTombstoneOperation(buf []byte, index uint64) interface{ // them in a single underlying transaction. This interface isn't 1:1 with the outer // update interface that the coordinate endpoint exposes, so we made it single // purpose and avoided the opcode convention. -func (c *consulFSM) applyCoordinateBatchUpdate(buf []byte, index uint64) interface{} { +func (c *FSM) applyCoordinateBatchUpdate(buf []byte, index uint64) interface{} { var updates structs.Coordinates if err := structs.Decode(buf, &updates); err != nil { panic(fmt.Errorf("failed to decode batch updates: %v", err)) @@ -303,7 +303,7 @@ func (c *consulFSM) applyCoordinateBatchUpdate(buf []byte, index uint64) interfa // applyPreparedQueryOperation applies the given prepared query operation to the // state store. -func (c *consulFSM) applyPreparedQueryOperation(buf []byte, index uint64) interface{} { +func (c *FSM) applyPreparedQueryOperation(buf []byte, index uint64) interface{} { var req structs.PreparedQueryRequest if err := structs.Decode(buf, &req); err != nil { panic(fmt.Errorf("failed to decode request: %v", err)) @@ -324,7 +324,7 @@ func (c *consulFSM) applyPreparedQueryOperation(buf []byte, index uint64) interf } } -func (c *consulFSM) applyTxn(buf []byte, index uint64) interface{} { +func (c *FSM) applyTxn(buf []byte, index uint64) interface{} { var req structs.TxnRequest if err := structs.Decode(buf, &req); err != nil { panic(fmt.Errorf("failed to decode request: %v", err)) @@ -338,7 +338,7 @@ func (c *consulFSM) applyTxn(buf []byte, index uint64) interface{} { } } -func (c *consulFSM) applyAutopilotUpdate(buf []byte, index uint64) interface{} { +func (c *FSM) applyAutopilotUpdate(buf []byte, index uint64) interface{} { var req structs.AutopilotSetConfigRequest if err := structs.Decode(buf, &req); err != nil { panic(fmt.Errorf("failed to decode request: %v", err)) @@ -356,7 +356,7 @@ func (c *consulFSM) applyAutopilotUpdate(buf []byte, index uint64) interface{} { return c.state.AutopilotSetConfig(index, &req.Config) } -func (c *consulFSM) Snapshot() (raft.FSMSnapshot, error) { +func (c *FSM) Snapshot() (raft.FSMSnapshot, error) { defer func(start time.Time) { c.logger.Printf("[INFO] consul.fsm: snapshot created in %v", time.Since(start)) }(time.Now()) @@ -366,7 +366,7 @@ func (c *consulFSM) Snapshot() (raft.FSMSnapshot, error) { // Restore streams in the snapshot and replaces the current state store with a // new one based on the snapshot if all goes OK during the restore. -func (c *consulFSM) Restore(old io.ReadCloser) error { +func (c *FSM) Restore(old io.ReadCloser) error { defer old.Close() // Create a new state store. diff --git a/agent/consul/fsm_test.go b/agent/consul/fsm/fsm_test.go similarity index 96% rename from agent/consul/fsm_test.go rename to agent/consul/fsm/fsm_test.go index 1e38557903..6b8cf9a129 100644 --- a/agent/consul/fsm_test.go +++ b/agent/consul/fsm/fsm_test.go @@ -1,8 +1,9 @@ -package consul +package fsm import ( "bytes" "fmt" + "math/rand" "os" "reflect" "testing" @@ -15,6 +16,7 @@ import ( "github.com/hashicorp/consul/types" "github.com/hashicorp/go-uuid" "github.com/hashicorp/raft" + "github.com/hashicorp/serf/coordinate" "github.com/pascaldekloe/goe/verify" ) @@ -53,9 +55,20 @@ func generateUUID() (ret string) { return ret } +func generateRandomCoordinate() *coordinate.Coordinate { + config := coordinate.DefaultConfig() + coord := coordinate.NewCoordinate(config) + for i := range coord.Vec { + coord.Vec[i] = rand.NormFloat64() + } + coord.Error = rand.NormFloat64() + coord.Adjustment = rand.NormFloat64() + return coord +} + func TestFSM_RegisterNode(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -99,7 +112,7 @@ func TestFSM_RegisterNode(t *testing.T) { func TestFSM_RegisterNode_Service(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -162,7 +175,7 @@ func TestFSM_RegisterNode_Service(t *testing.T) { func TestFSM_DeregisterService(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -224,7 +237,7 @@ func TestFSM_DeregisterService(t *testing.T) { func TestFSM_DeregisterCheck(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -286,7 +299,7 @@ func TestFSM_DeregisterCheck(t *testing.T) { func TestFSM_DeregisterNode(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -363,7 +376,7 @@ func TestFSM_DeregisterNode(t *testing.T) { func TestFSM_SnapshotRestore(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -459,7 +472,7 @@ func TestFSM_SnapshotRestore(t *testing.T) { } // Try to restore on a new FSM - fsm2, err := NewFSM(nil, os.Stderr) + fsm2, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -636,7 +649,7 @@ func TestFSM_SnapshotRestore(t *testing.T) { func TestFSM_BadRestore(t *testing.T) { t.Parallel() // Create an FSM with some state. - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -674,7 +687,7 @@ func TestFSM_BadRestore(t *testing.T) { func TestFSM_KVSDelete(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -720,7 +733,7 @@ func TestFSM_KVSDelete(t *testing.T) { func TestFSM_KVSDeleteTree(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -767,7 +780,7 @@ func TestFSM_KVSDeleteTree(t *testing.T) { func TestFSM_KVSDeleteCheckAndSet(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -823,7 +836,7 @@ func TestFSM_KVSDeleteCheckAndSet(t *testing.T) { func TestFSM_KVSCheckAndSet(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -880,7 +893,7 @@ func TestFSM_KVSCheckAndSet(t *testing.T) { func TestFSM_CoordinateUpdate(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -921,7 +934,7 @@ func TestFSM_CoordinateUpdate(t *testing.T) { func TestFSM_SessionCreate_Destroy(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -1001,7 +1014,7 @@ func TestFSM_SessionCreate_Destroy(t *testing.T) { func TestFSM_KVSLock(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -1046,7 +1059,7 @@ func TestFSM_KVSLock(t *testing.T) { func TestFSM_KVSUnlock(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -1109,7 +1122,7 @@ func TestFSM_KVSUnlock(t *testing.T) { func TestFSM_ACL_CRUD(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -1229,7 +1242,7 @@ func TestFSM_ACL_CRUD(t *testing.T) { func TestFSM_PreparedQuery_CRUD(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -1327,7 +1340,7 @@ func TestFSM_PreparedQuery_CRUD(t *testing.T) { func TestFSM_TombstoneReap(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -1375,7 +1388,7 @@ func TestFSM_TombstoneReap(t *testing.T) { func TestFSM_Txn(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -1417,7 +1430,7 @@ func TestFSM_Txn(t *testing.T) { func TestFSM_Autopilot(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } @@ -1479,7 +1492,7 @@ func TestFSM_Autopilot(t *testing.T) { func TestFSM_IgnoreUnknown(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } diff --git a/agent/consul/issue_test.go b/agent/consul/issue_test.go index 9f7581ed0b..f514642ab9 100644 --- a/agent/consul/issue_test.go +++ b/agent/consul/issue_test.go @@ -5,14 +5,25 @@ import ( "reflect" "testing" + consulfsm "github.com/hashicorp/consul/agent/consul/fsm" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" + "github.com/hashicorp/raft" ) +func makeLog(buf []byte) *raft.Log { + return &raft.Log{ + Index: 1, + Term: 1, + Type: raft.LogCommand, + Data: buf, + } +} + // Testing for GH-300 and GH-279 func TestHealthCheckRace(t *testing.T) { t.Parallel() - fsm, err := NewFSM(nil, os.Stderr) + fsm, err := consulfsm.New(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } diff --git a/agent/consul/rpc_test.go b/agent/consul/rpc_test.go index 2551878313..7baa3f2359 100644 --- a/agent/consul/rpc_test.go +++ b/agent/consul/rpc_test.go @@ -77,6 +77,24 @@ func TestRPC_NoLeader_Retry(t *testing.T) { } } +type MockSink struct { + *bytes.Buffer + cancel bool +} + +func (m *MockSink) ID() string { + return "Mock" +} + +func (m *MockSink) Cancel() error { + m.cancel = true + return nil +} + +func (m *MockSink) Close() error { + return nil +} + func TestRPC_blockingQuery(t *testing.T) { t.Parallel() dir, s := testServer(t) diff --git a/agent/consul/server.go b/agent/consul/server.go index 5edc49294b..bb64e96f7e 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -18,6 +18,7 @@ import ( "time" "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/fsm" "github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/pool" @@ -122,7 +123,7 @@ type Server struct { // fsm is the state machine used with Raft to provide // strong consistency. - fsm *consulFSM + fsm *fsm.FSM // Logger uses the provided LogOutput logger *log.Logger @@ -447,7 +448,7 @@ func (s *Server) setupRaft() error { // Create the FSM. var err error - s.fsm, err = NewFSM(s.tombstoneGC, s.config.LogOutput) + s.fsm, err = fsm.New(s.tombstoneGC, s.config.LogOutput) if err != nil { return err } @@ -554,7 +555,7 @@ func (s *Server) setupRaft() error { return fmt.Errorf("recovery failed to parse peers.json: %v", err) } - tmpFsm, err := NewFSM(s.tombstoneGC, s.config.LogOutput) + tmpFsm, err := fsm.New(s.tombstoneGC, s.config.LogOutput) if err != nil { return fmt.Errorf("recovery failed to make temp FSM: %v", err) } diff --git a/agent/consul/session_ttl_test.go b/agent/consul/session_ttl_test.go index 2bc7dc7aae..5ede76462e 100644 --- a/agent/consul/session_ttl_test.go +++ b/agent/consul/session_ttl_test.go @@ -1,6 +1,7 @@ package consul import ( + "fmt" "os" "strings" "testing" @@ -9,9 +10,18 @@ import ( "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/testutil/retry" + "github.com/hashicorp/go-uuid" "github.com/hashicorp/net-rpc-msgpackrpc" ) +func generateUUID() (ret string) { + var err error + if ret, err = uuid.GenerateUUID(); err != nil { + panic(fmt.Sprintf("Unable to generate a UUID, %v", err)) + } + return ret +} + func TestInitializeSessionTimers(t *testing.T) { t.Parallel() dir1, s1 := testServer(t) From c8e763667f5f7ce5fe257b46f51e23d8e8b94de4 Mon Sep 17 00:00:00 2001 From: James Phillips Date: Wed, 29 Nov 2017 12:43:27 -0800 Subject: [PATCH 12/14] Creates a registration mechanism for FSM commands. --- agent/consul/fsm/commands_oss.go | 263 ++++++ agent/consul/fsm/commands_oss_test.go | 1149 +++++++++++++++++++++++++ agent/consul/fsm/fsm.go | 311 +------ agent/consul/fsm/fsm_test.go | 1137 ------------------------ 4 files changed, 1457 insertions(+), 1403 deletions(-) create mode 100644 agent/consul/fsm/commands_oss.go create mode 100644 agent/consul/fsm/commands_oss_test.go diff --git a/agent/consul/fsm/commands_oss.go b/agent/consul/fsm/commands_oss.go new file mode 100644 index 0000000000..2029f1723e --- /dev/null +++ b/agent/consul/fsm/commands_oss.go @@ -0,0 +1,263 @@ +package fsm + +import ( + "fmt" + "time" + + "github.com/armon/go-metrics" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/api" +) + +func init() { + registerCommand(structs.RegisterRequestType, (*FSM).applyRegister) + registerCommand(structs.DeregisterRequestType, (*FSM).applyDeregister) + registerCommand(structs.KVSRequestType, (*FSM).applyKVSOperation) + registerCommand(structs.SessionRequestType, (*FSM).applySessionOperation) + registerCommand(structs.ACLRequestType, (*FSM).applyACLOperation) + registerCommand(structs.TombstoneRequestType, (*FSM).applyTombstoneOperation) + registerCommand(structs.CoordinateBatchUpdateType, (*FSM).applyCoordinateBatchUpdate) + registerCommand(structs.PreparedQueryRequestType, (*FSM).applyPreparedQueryOperation) + registerCommand(structs.TxnRequestType, (*FSM).applyTxn) + registerCommand(structs.AutopilotRequestType, (*FSM).applyAutopilotUpdate) +} + +func (c *FSM) applyRegister(buf []byte, index uint64) interface{} { + defer metrics.MeasureSince([]string{"consul", "fsm", "register"}, time.Now()) + defer metrics.MeasureSince([]string{"fsm", "register"}, time.Now()) + var req structs.RegisterRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + + // Apply all updates in a single transaction + if err := c.state.EnsureRegistration(index, &req); err != nil { + c.logger.Printf("[WARN] consul.fsm: EnsureRegistration failed: %v", err) + return err + } + return nil +} + +func (c *FSM) applyDeregister(buf []byte, index uint64) interface{} { + defer metrics.MeasureSince([]string{"consul", "fsm", "deregister"}, time.Now()) + defer metrics.MeasureSince([]string{"fsm", "deregister"}, time.Now()) + var req structs.DeregisterRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + + // Either remove the service entry or the whole node. The precedence + // here is also baked into vetDeregisterWithACL() in acl.go, so if you + // make changes here, be sure to also adjust the code over there. + if req.ServiceID != "" { + if err := c.state.DeleteService(index, req.Node, req.ServiceID); err != nil { + c.logger.Printf("[WARN] consul.fsm: DeleteNodeService failed: %v", err) + return err + } + } else if req.CheckID != "" { + if err := c.state.DeleteCheck(index, req.Node, req.CheckID); err != nil { + c.logger.Printf("[WARN] consul.fsm: DeleteNodeCheck failed: %v", err) + return err + } + } else { + if err := c.state.DeleteNode(index, req.Node); err != nil { + c.logger.Printf("[WARN] consul.fsm: DeleteNode failed: %v", err) + return err + } + } + return nil +} + +func (c *FSM) applyKVSOperation(buf []byte, index uint64) interface{} { + var req structs.KVSRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"consul", "fsm", "kvs"}, time.Now(), + []metrics.Label{{Name: "op", Value: string(req.Op)}}) + defer metrics.MeasureSinceWithLabels([]string{"fsm", "kvs"}, time.Now(), + []metrics.Label{{Name: "op", Value: string(req.Op)}}) + switch req.Op { + case api.KVSet: + return c.state.KVSSet(index, &req.DirEnt) + case api.KVDelete: + return c.state.KVSDelete(index, req.DirEnt.Key) + case api.KVDeleteCAS: + act, err := c.state.KVSDeleteCAS(index, req.DirEnt.ModifyIndex, req.DirEnt.Key) + if err != nil { + return err + } + return act + case api.KVDeleteTree: + return c.state.KVSDeleteTree(index, req.DirEnt.Key) + case api.KVCAS: + act, err := c.state.KVSSetCAS(index, &req.DirEnt) + if err != nil { + return err + } + return act + case api.KVLock: + act, err := c.state.KVSLock(index, &req.DirEnt) + if err != nil { + return err + } + return act + case api.KVUnlock: + act, err := c.state.KVSUnlock(index, &req.DirEnt) + if err != nil { + return err + } + return act + default: + err := fmt.Errorf("Invalid KVS operation '%s'", req.Op) + c.logger.Printf("[WARN] consul.fsm: %v", err) + return err + } +} + +func (c *FSM) applySessionOperation(buf []byte, index uint64) interface{} { + var req structs.SessionRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"consul", "fsm", "session"}, time.Now(), + []metrics.Label{{Name: "op", Value: string(req.Op)}}) + defer metrics.MeasureSinceWithLabels([]string{"fsm", "session"}, time.Now(), + []metrics.Label{{Name: "op", Value: string(req.Op)}}) + switch req.Op { + case structs.SessionCreate: + if err := c.state.SessionCreate(index, &req.Session); err != nil { + return err + } + return req.Session.ID + case structs.SessionDestroy: + return c.state.SessionDestroy(index, req.Session.ID) + default: + c.logger.Printf("[WARN] consul.fsm: Invalid Session operation '%s'", req.Op) + return fmt.Errorf("Invalid Session operation '%s'", req.Op) + } +} + +func (c *FSM) applyACLOperation(buf []byte, index uint64) interface{} { + var req structs.ACLRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"consul", "fsm", "acl"}, time.Now(), + []metrics.Label{{Name: "op", Value: string(req.Op)}}) + defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl"}, time.Now(), + []metrics.Label{{Name: "op", Value: string(req.Op)}}) + switch req.Op { + case structs.ACLBootstrapInit: + enabled, err := c.state.ACLBootstrapInit(index) + if err != nil { + return err + } + return enabled + case structs.ACLBootstrapNow: + if err := c.state.ACLBootstrap(index, &req.ACL); err != nil { + return err + } + return &req.ACL + case structs.ACLForceSet, structs.ACLSet: + if err := c.state.ACLSet(index, &req.ACL); err != nil { + return err + } + return req.ACL.ID + case structs.ACLDelete: + return c.state.ACLDelete(index, req.ACL.ID) + default: + c.logger.Printf("[WARN] consul.fsm: Invalid ACL operation '%s'", req.Op) + return fmt.Errorf("Invalid ACL operation '%s'", req.Op) + } +} + +func (c *FSM) applyTombstoneOperation(buf []byte, index uint64) interface{} { + var req structs.TombstoneRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"consul", "fsm", "tombstone"}, time.Now(), + []metrics.Label{{Name: "op", Value: string(req.Op)}}) + defer metrics.MeasureSinceWithLabels([]string{"fsm", "tombstone"}, time.Now(), + []metrics.Label{{Name: "op", Value: string(req.Op)}}) + switch req.Op { + case structs.TombstoneReap: + return c.state.ReapTombstones(req.ReapIndex) + default: + c.logger.Printf("[WARN] consul.fsm: Invalid Tombstone operation '%s'", req.Op) + return fmt.Errorf("Invalid Tombstone operation '%s'", req.Op) + } +} + +// applyCoordinateBatchUpdate processes a batch of coordinate updates and applies +// them in a single underlying transaction. This interface isn't 1:1 with the outer +// update interface that the coordinate endpoint exposes, so we made it single +// purpose and avoided the opcode convention. +func (c *FSM) applyCoordinateBatchUpdate(buf []byte, index uint64) interface{} { + var updates structs.Coordinates + if err := structs.Decode(buf, &updates); err != nil { + panic(fmt.Errorf("failed to decode batch updates: %v", err)) + } + defer metrics.MeasureSince([]string{"consul", "fsm", "coordinate", "batch-update"}, time.Now()) + defer metrics.MeasureSince([]string{"fsm", "coordinate", "batch-update"}, time.Now()) + if err := c.state.CoordinateBatchUpdate(index, updates); err != nil { + return err + } + return nil +} + +// applyPreparedQueryOperation applies the given prepared query operation to the +// state store. +func (c *FSM) applyPreparedQueryOperation(buf []byte, index uint64) interface{} { + var req structs.PreparedQueryRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + + defer metrics.MeasureSinceWithLabels([]string{"consul", "fsm", "prepared-query"}, time.Now(), + []metrics.Label{{Name: "op", Value: string(req.Op)}}) + defer metrics.MeasureSinceWithLabels([]string{"fsm", "prepared-query"}, time.Now(), + []metrics.Label{{Name: "op", Value: string(req.Op)}}) + switch req.Op { + case structs.PreparedQueryCreate, structs.PreparedQueryUpdate: + return c.state.PreparedQuerySet(index, req.Query) + case structs.PreparedQueryDelete: + return c.state.PreparedQueryDelete(index, req.Query.ID) + default: + c.logger.Printf("[WARN] consul.fsm: Invalid PreparedQuery operation '%s'", req.Op) + return fmt.Errorf("Invalid PreparedQuery operation '%s'", req.Op) + } +} + +func (c *FSM) applyTxn(buf []byte, index uint64) interface{} { + var req structs.TxnRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSince([]string{"consul", "fsm", "txn"}, time.Now()) + defer metrics.MeasureSince([]string{"fsm", "txn"}, time.Now()) + results, errors := c.state.TxnRW(index, req.Ops) + return structs.TxnResponse{ + Results: results, + Errors: errors, + } +} + +func (c *FSM) applyAutopilotUpdate(buf []byte, index uint64) interface{} { + var req structs.AutopilotSetConfigRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSince([]string{"consul", "fsm", "autopilot"}, time.Now()) + defer metrics.MeasureSince([]string{"fsm", "autopilot"}, time.Now()) + + if req.CAS { + act, err := c.state.AutopilotCASConfig(index, req.Config.ModifyIndex, &req.Config) + if err != nil { + return err + } + return act + } + return c.state.AutopilotSetConfig(index, &req.Config) +} diff --git a/agent/consul/fsm/commands_oss_test.go b/agent/consul/fsm/commands_oss_test.go new file mode 100644 index 0000000000..b98d9ea5e3 --- /dev/null +++ b/agent/consul/fsm/commands_oss_test.go @@ -0,0 +1,1149 @@ +package fsm + +import ( + "fmt" + "math/rand" + "os" + "reflect" + "testing" + "time" + + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/types" + "github.com/hashicorp/go-uuid" + "github.com/hashicorp/serf/coordinate" + "github.com/pascaldekloe/goe/verify" +) + +func generateUUID() (ret string) { + var err error + if ret, err = uuid.GenerateUUID(); err != nil { + panic(fmt.Sprintf("Unable to generate a UUID, %v", err)) + } + return ret +} + +func generateRandomCoordinate() *coordinate.Coordinate { + config := coordinate.DefaultConfig() + coord := coordinate.NewCoordinate(config) + for i := range coord.Vec { + coord.Vec[i] = rand.NormFloat64() + } + coord.Error = rand.NormFloat64() + coord.Adjustment = rand.NormFloat64() + return coord +} + +func TestFSM_RegisterNode(t *testing.T) { + t.Parallel() + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + req := structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + } + buf, err := structs.Encode(structs.RegisterRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + // Verify we are registered + _, node, err := fsm.state.GetNode("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if node == nil { + t.Fatalf("not found!") + } + if node.ModifyIndex != 1 { + t.Fatalf("bad index: %d", node.ModifyIndex) + } + + // Verify service registered + _, services, err := fsm.state.NodeServices(nil, "foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if len(services.Services) != 0 { + t.Fatalf("Services: %v", services) + } +} + +func TestFSM_RegisterNode_Service(t *testing.T) { + t.Parallel() + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + req := structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + Service: &structs.NodeService{ + ID: "db", + Service: "db", + Tags: []string{"master"}, + Port: 8000, + }, + Check: &structs.HealthCheck{ + Node: "foo", + CheckID: "db", + Name: "db connectivity", + Status: api.HealthPassing, + ServiceID: "db", + }, + } + buf, err := structs.Encode(structs.RegisterRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + // Verify we are registered + _, node, err := fsm.state.GetNode("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if node == nil { + t.Fatalf("not found!") + } + + // Verify service registered + _, services, err := fsm.state.NodeServices(nil, "foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if _, ok := services.Services["db"]; !ok { + t.Fatalf("not registered!") + } + + // Verify check + _, checks, err := fsm.state.NodeChecks(nil, "foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if checks[0].CheckID != "db" { + t.Fatalf("not registered!") + } +} + +func TestFSM_DeregisterService(t *testing.T) { + t.Parallel() + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + req := structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + Service: &structs.NodeService{ + ID: "db", + Service: "db", + Tags: []string{"master"}, + Port: 8000, + }, + } + buf, err := structs.Encode(structs.RegisterRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + dereg := structs.DeregisterRequest{ + Datacenter: "dc1", + Node: "foo", + ServiceID: "db", + } + buf, err = structs.Encode(structs.DeregisterRequestType, dereg) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp = fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + // Verify we are registered + _, node, err := fsm.state.GetNode("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if node == nil { + t.Fatalf("not found!") + } + + // Verify service not registered + _, services, err := fsm.state.NodeServices(nil, "foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if _, ok := services.Services["db"]; ok { + t.Fatalf("db registered!") + } +} + +func TestFSM_DeregisterCheck(t *testing.T) { + t.Parallel() + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + req := structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + Check: &structs.HealthCheck{ + Node: "foo", + CheckID: "mem", + Name: "memory util", + Status: api.HealthPassing, + }, + } + buf, err := structs.Encode(structs.RegisterRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + dereg := structs.DeregisterRequest{ + Datacenter: "dc1", + Node: "foo", + CheckID: "mem", + } + buf, err = structs.Encode(structs.DeregisterRequestType, dereg) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp = fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + // Verify we are registered + _, node, err := fsm.state.GetNode("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if node == nil { + t.Fatalf("not found!") + } + + // Verify check not registered + _, checks, err := fsm.state.NodeChecks(nil, "foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if len(checks) != 0 { + t.Fatalf("check registered!") + } +} + +func TestFSM_DeregisterNode(t *testing.T) { + t.Parallel() + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + req := structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + Service: &structs.NodeService{ + ID: "db", + Service: "db", + Tags: []string{"master"}, + Port: 8000, + }, + Check: &structs.HealthCheck{ + Node: "foo", + CheckID: "db", + Name: "db connectivity", + Status: api.HealthPassing, + ServiceID: "db", + }, + } + buf, err := structs.Encode(structs.RegisterRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + dereg := structs.DeregisterRequest{ + Datacenter: "dc1", + Node: "foo", + } + buf, err = structs.Encode(structs.DeregisterRequestType, dereg) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp = fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + // Verify we are not registered + _, node, err := fsm.state.GetNode("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if node != nil { + t.Fatalf("found!") + } + + // Verify service not registered + _, services, err := fsm.state.NodeServices(nil, "foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if services != nil { + t.Fatalf("Services: %v", services) + } + + // Verify checks not registered + _, checks, err := fsm.state.NodeChecks(nil, "foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if len(checks) != 0 { + t.Fatalf("Services: %v", services) + } +} + +func TestFSM_KVSDelete(t *testing.T) { + t.Parallel() + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + req := structs.KVSRequest{ + Datacenter: "dc1", + Op: api.KVSet, + DirEnt: structs.DirEntry{ + Key: "/test/path", + Flags: 0, + Value: []byte("test"), + }, + } + buf, err := structs.Encode(structs.KVSRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + // Run the delete + req.Op = api.KVDelete + buf, err = structs.Encode(structs.KVSRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp = fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + // Verify key is not set + _, d, err := fsm.state.KVSGet(nil, "/test/path") + if err != nil { + t.Fatalf("err: %v", err) + } + if d != nil { + t.Fatalf("key present") + } +} + +func TestFSM_KVSDeleteTree(t *testing.T) { + t.Parallel() + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + req := structs.KVSRequest{ + Datacenter: "dc1", + Op: api.KVSet, + DirEnt: structs.DirEntry{ + Key: "/test/path", + Flags: 0, + Value: []byte("test"), + }, + } + buf, err := structs.Encode(structs.KVSRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + // Run the delete tree + req.Op = api.KVDeleteTree + req.DirEnt.Key = "/test" + buf, err = structs.Encode(structs.KVSRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp = fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + // Verify key is not set + _, d, err := fsm.state.KVSGet(nil, "/test/path") + if err != nil { + t.Fatalf("err: %v", err) + } + if d != nil { + t.Fatalf("key present") + } +} + +func TestFSM_KVSDeleteCheckAndSet(t *testing.T) { + t.Parallel() + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + req := structs.KVSRequest{ + Datacenter: "dc1", + Op: api.KVSet, + DirEnt: structs.DirEntry{ + Key: "/test/path", + Flags: 0, + Value: []byte("test"), + }, + } + buf, err := structs.Encode(structs.KVSRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + // Verify key is set + _, d, err := fsm.state.KVSGet(nil, "/test/path") + if err != nil { + t.Fatalf("err: %v", err) + } + if d == nil { + t.Fatalf("key missing") + } + + // Run the check-and-set + req.Op = api.KVDeleteCAS + req.DirEnt.ModifyIndex = d.ModifyIndex + buf, err = structs.Encode(structs.KVSRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp = fsm.Apply(makeLog(buf)) + if resp.(bool) != true { + t.Fatalf("resp: %v", resp) + } + + // Verify key is gone + _, d, err = fsm.state.KVSGet(nil, "/test/path") + if err != nil { + t.Fatalf("err: %v", err) + } + if d != nil { + t.Fatalf("bad: %v", d) + } +} + +func TestFSM_KVSCheckAndSet(t *testing.T) { + t.Parallel() + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + req := structs.KVSRequest{ + Datacenter: "dc1", + Op: api.KVSet, + DirEnt: structs.DirEntry{ + Key: "/test/path", + Flags: 0, + Value: []byte("test"), + }, + } + buf, err := structs.Encode(structs.KVSRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + // Verify key is set + _, d, err := fsm.state.KVSGet(nil, "/test/path") + if err != nil { + t.Fatalf("err: %v", err) + } + if d == nil { + t.Fatalf("key missing") + } + + // Run the check-and-set + req.Op = api.KVCAS + req.DirEnt.ModifyIndex = d.ModifyIndex + req.DirEnt.Value = []byte("zip") + buf, err = structs.Encode(structs.KVSRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp = fsm.Apply(makeLog(buf)) + if resp.(bool) != true { + t.Fatalf("resp: %v", resp) + } + + // Verify key is updated + _, d, err = fsm.state.KVSGet(nil, "/test/path") + if err != nil { + t.Fatalf("err: %v", err) + } + if string(d.Value) != "zip" { + t.Fatalf("bad: %v", d) + } +} + +func TestFSM_KVSLock(t *testing.T) { + t.Parallel() + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) + session := &structs.Session{ID: generateUUID(), Node: "foo"} + fsm.state.SessionCreate(2, session) + + req := structs.KVSRequest{ + Datacenter: "dc1", + Op: api.KVLock, + DirEnt: structs.DirEntry{ + Key: "/test/path", + Value: []byte("test"), + Session: session.ID, + }, + } + buf, err := structs.Encode(structs.KVSRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := fsm.Apply(makeLog(buf)) + if resp != true { + t.Fatalf("resp: %v", resp) + } + + // Verify key is locked + _, d, err := fsm.state.KVSGet(nil, "/test/path") + if err != nil { + t.Fatalf("err: %v", err) + } + if d == nil { + t.Fatalf("missing") + } + if d.LockIndex != 1 { + t.Fatalf("bad: %v", *d) + } + if d.Session != session.ID { + t.Fatalf("bad: %v", *d) + } +} + +func TestFSM_KVSUnlock(t *testing.T) { + t.Parallel() + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) + session := &structs.Session{ID: generateUUID(), Node: "foo"} + fsm.state.SessionCreate(2, session) + + req := structs.KVSRequest{ + Datacenter: "dc1", + Op: api.KVLock, + DirEnt: structs.DirEntry{ + Key: "/test/path", + Value: []byte("test"), + Session: session.ID, + }, + } + buf, err := structs.Encode(structs.KVSRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := fsm.Apply(makeLog(buf)) + if resp != true { + t.Fatalf("resp: %v", resp) + } + + req = structs.KVSRequest{ + Datacenter: "dc1", + Op: api.KVUnlock, + DirEnt: structs.DirEntry{ + Key: "/test/path", + Value: []byte("test"), + Session: session.ID, + }, + } + buf, err = structs.Encode(structs.KVSRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp = fsm.Apply(makeLog(buf)) + if resp != true { + t.Fatalf("resp: %v", resp) + } + + // Verify key is unlocked + _, d, err := fsm.state.KVSGet(nil, "/test/path") + if err != nil { + t.Fatalf("err: %v", err) + } + if d == nil { + t.Fatalf("missing") + } + if d.LockIndex != 1 { + t.Fatalf("bad: %v", *d) + } + if d.Session != "" { + t.Fatalf("bad: %v", *d) + } +} + +func TestFSM_CoordinateUpdate(t *testing.T) { + t.Parallel() + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Register some nodes. + fsm.state.EnsureNode(1, &structs.Node{Node: "node1", Address: "127.0.0.1"}) + fsm.state.EnsureNode(2, &structs.Node{Node: "node2", Address: "127.0.0.1"}) + + // Write a batch of two coordinates. + updates := structs.Coordinates{ + &structs.Coordinate{ + Node: "node1", + Coord: generateRandomCoordinate(), + }, + &structs.Coordinate{ + Node: "node2", + Coord: generateRandomCoordinate(), + }, + } + buf, err := structs.Encode(structs.CoordinateBatchUpdateType, updates) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + // Read back the two coordinates to make sure they got updated. + _, coords, err := fsm.state.Coordinates(nil) + if err != nil { + t.Fatalf("err: %s", err) + } + if !reflect.DeepEqual(coords, updates) { + t.Fatalf("bad: %#v", coords) + } +} + +func TestFSM_SessionCreate_Destroy(t *testing.T) { + t.Parallel() + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) + fsm.state.EnsureCheck(2, &structs.HealthCheck{ + Node: "foo", + CheckID: "web", + Status: api.HealthPassing, + }) + + // Create a new session + req := structs.SessionRequest{ + Datacenter: "dc1", + Op: structs.SessionCreate, + Session: structs.Session{ + ID: generateUUID(), + Node: "foo", + Checks: []types.CheckID{"web"}, + }, + } + buf, err := structs.Encode(structs.SessionRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := fsm.Apply(makeLog(buf)) + if err, ok := resp.(error); ok { + t.Fatalf("resp: %v", err) + } + + // Get the session + id := resp.(string) + _, session, err := fsm.state.SessionGet(nil, id) + if err != nil { + t.Fatalf("err: %v", err) + } + if session == nil { + t.Fatalf("missing") + } + + // Verify the session + if session.ID != id { + t.Fatalf("bad: %v", *session) + } + if session.Node != "foo" { + t.Fatalf("bad: %v", *session) + } + if session.Checks[0] != "web" { + t.Fatalf("bad: %v", *session) + } + + // Try to destroy + destroy := structs.SessionRequest{ + Datacenter: "dc1", + Op: structs.SessionDestroy, + Session: structs.Session{ + ID: id, + }, + } + buf, err = structs.Encode(structs.SessionRequestType, destroy) + if err != nil { + t.Fatalf("err: %v", err) + } + resp = fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + _, session, err = fsm.state.SessionGet(nil, id) + if err != nil { + t.Fatalf("err: %v", err) + } + if session != nil { + t.Fatalf("should be destroyed") + } +} + +func TestFSM_ACL_CRUD(t *testing.T) { + t.Parallel() + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Create a new ACL. + req := structs.ACLRequest{ + Datacenter: "dc1", + Op: structs.ACLSet, + ACL: structs.ACL{ + ID: generateUUID(), + Name: "User token", + Type: structs.ACLTypeClient, + }, + } + buf, err := structs.Encode(structs.ACLRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := fsm.Apply(makeLog(buf)) + if err, ok := resp.(error); ok { + t.Fatalf("resp: %v", err) + } + + // Get the ACL. + id := resp.(string) + _, acl, err := fsm.state.ACLGet(nil, id) + if err != nil { + t.Fatalf("err: %v", err) + } + if acl == nil { + t.Fatalf("missing") + } + + // Verify the ACL. + if acl.ID != id { + t.Fatalf("bad: %v", *acl) + } + if acl.Name != "User token" { + t.Fatalf("bad: %v", *acl) + } + if acl.Type != structs.ACLTypeClient { + t.Fatalf("bad: %v", *acl) + } + + // Try to destroy. + destroy := structs.ACLRequest{ + Datacenter: "dc1", + Op: structs.ACLDelete, + ACL: structs.ACL{ + ID: id, + }, + } + buf, err = structs.Encode(structs.ACLRequestType, destroy) + if err != nil { + t.Fatalf("err: %v", err) + } + resp = fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + _, acl, err = fsm.state.ACLGet(nil, id) + if err != nil { + t.Fatalf("err: %v", err) + } + if acl != nil { + t.Fatalf("should be destroyed") + } + + // Initialize bootstrap (should work since we haven't made a management + // token). + init := structs.ACLRequest{ + Datacenter: "dc1", + Op: structs.ACLBootstrapInit, + } + buf, err = structs.Encode(structs.ACLRequestType, init) + if err != nil { + t.Fatalf("err: %v", err) + } + resp = fsm.Apply(makeLog(buf)) + if enabled, ok := resp.(bool); !ok || !enabled { + t.Fatalf("resp: %v", resp) + } + gotB, err := fsm.state.ACLGetBootstrap() + if err != nil { + t.Fatalf("err: %v", err) + } + wantB := &structs.ACLBootstrap{ + AllowBootstrap: true, + RaftIndex: gotB.RaftIndex, + } + verify.Values(t, "", gotB, wantB) + + // Do a bootstrap. + bootstrap := structs.ACLRequest{ + Datacenter: "dc1", + Op: structs.ACLBootstrapNow, + ACL: structs.ACL{ + ID: generateUUID(), + Name: "Bootstrap Token", + Type: structs.ACLTypeManagement, + }, + } + buf, err = structs.Encode(structs.ACLRequestType, bootstrap) + if err != nil { + t.Fatalf("err: %v", err) + } + resp = fsm.Apply(makeLog(buf)) + respACL, ok := resp.(*structs.ACL) + if !ok { + t.Fatalf("resp: %v", resp) + } + bootstrap.ACL.CreateIndex = respACL.CreateIndex + bootstrap.ACL.ModifyIndex = respACL.ModifyIndex + verify.Values(t, "", respACL, &bootstrap.ACL) +} + +func TestFSM_PreparedQuery_CRUD(t *testing.T) { + t.Parallel() + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Register a service to query on. + fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) + fsm.state.EnsureService(2, "foo", &structs.NodeService{ID: "web", Service: "web", Tags: nil, Address: "127.0.0.1", Port: 80}) + + // Create a new query. + query := structs.PreparedQueryRequest{ + Op: structs.PreparedQueryCreate, + Query: &structs.PreparedQuery{ + ID: generateUUID(), + Service: structs.ServiceQuery{ + Service: "web", + }, + }, + } + { + buf, err := structs.Encode(structs.PreparedQueryRequestType, query) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + } + + // Verify it's in the state store. + { + _, actual, err := fsm.state.PreparedQueryGet(nil, query.Query.ID) + if err != nil { + t.Fatalf("err: %s", err) + } + + actual.CreateIndex, actual.ModifyIndex = 0, 0 + if !reflect.DeepEqual(actual, query.Query) { + t.Fatalf("bad: %v", actual) + } + } + + // Make an update to the query. + query.Op = structs.PreparedQueryUpdate + query.Query.Name = "my-query" + { + buf, err := structs.Encode(structs.PreparedQueryRequestType, query) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + } + + // Verify the update. + { + _, actual, err := fsm.state.PreparedQueryGet(nil, query.Query.ID) + if err != nil { + t.Fatalf("err: %s", err) + } + + actual.CreateIndex, actual.ModifyIndex = 0, 0 + if !reflect.DeepEqual(actual, query.Query) { + t.Fatalf("bad: %v", actual) + } + } + + // Delete the query. + query.Op = structs.PreparedQueryDelete + { + buf, err := structs.Encode(structs.PreparedQueryRequestType, query) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + } + + // Make sure it's gone. + { + _, actual, err := fsm.state.PreparedQueryGet(nil, query.Query.ID) + if err != nil { + t.Fatalf("err: %s", err) + } + + if actual != nil { + t.Fatalf("bad: %v", actual) + } + } +} + +func TestFSM_TombstoneReap(t *testing.T) { + t.Parallel() + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Create some tombstones + fsm.state.KVSSet(11, &structs.DirEntry{ + Key: "/remove", + Value: []byte("foo"), + }) + fsm.state.KVSDelete(12, "/remove") + idx, _, err := fsm.state.KVSList(nil, "/remove") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 12 { + t.Fatalf("bad index: %d", idx) + } + + // Create a new reap request + req := structs.TombstoneRequest{ + Datacenter: "dc1", + Op: structs.TombstoneReap, + ReapIndex: 12, + } + buf, err := structs.Encode(structs.TombstoneRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := fsm.Apply(makeLog(buf)) + if err, ok := resp.(error); ok { + t.Fatalf("resp: %v", err) + } + + // Verify the tombstones are gone + snap := fsm.state.Snapshot() + defer snap.Close() + stones, err := snap.Tombstones() + if err != nil { + t.Fatalf("err: %s", err) + } + if stones.Next() != nil { + t.Fatalf("unexpected extra tombstones") + } +} + +func TestFSM_Txn(t *testing.T) { + t.Parallel() + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Set a key using a transaction. + req := structs.TxnRequest{ + Datacenter: "dc1", + Ops: structs.TxnOps{ + &structs.TxnOp{ + KV: &structs.TxnKVOp{ + Verb: api.KVSet, + DirEnt: structs.DirEntry{ + Key: "/test/path", + Flags: 0, + Value: []byte("test"), + }, + }, + }, + }, + } + buf, err := structs.Encode(structs.TxnRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := fsm.Apply(makeLog(buf)) + if _, ok := resp.(structs.TxnResponse); !ok { + t.Fatalf("bad response type: %T", resp) + } + + // Verify key is set directly in the state store. + _, d, err := fsm.state.KVSGet(nil, "/test/path") + if err != nil { + t.Fatalf("err: %v", err) + } + if d == nil { + t.Fatalf("missing") + } +} + +func TestFSM_Autopilot(t *testing.T) { + t.Parallel() + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Set the autopilot config using a request. + req := structs.AutopilotSetConfigRequest{ + Datacenter: "dc1", + Config: structs.AutopilotConfig{ + CleanupDeadServers: true, + LastContactThreshold: 10 * time.Second, + MaxTrailingLogs: 300, + }, + } + buf, err := structs.Encode(structs.AutopilotRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := fsm.Apply(makeLog(buf)) + if _, ok := resp.(error); ok { + t.Fatalf("bad: %v", resp) + } + + // Verify key is set directly in the state store. + _, config, err := fsm.state.AutopilotConfig() + if err != nil { + t.Fatalf("err: %v", err) + } + if config.CleanupDeadServers != req.Config.CleanupDeadServers { + t.Fatalf("bad: %v", config.CleanupDeadServers) + } + if config.LastContactThreshold != req.Config.LastContactThreshold { + t.Fatalf("bad: %v", config.LastContactThreshold) + } + if config.MaxTrailingLogs != req.Config.MaxTrailingLogs { + t.Fatalf("bad: %v", config.MaxTrailingLogs) + } + + // Now use CAS and provide an old index + req.CAS = true + req.Config.CleanupDeadServers = false + req.Config.ModifyIndex = config.ModifyIndex - 1 + buf, err = structs.Encode(structs.AutopilotRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp = fsm.Apply(makeLog(buf)) + if _, ok := resp.(error); ok { + t.Fatalf("bad: %v", resp) + } + + _, config, err = fsm.state.AutopilotConfig() + if err != nil { + t.Fatalf("err: %v", err) + } + if !config.CleanupDeadServers { + t.Fatalf("bad: %v", config.CleanupDeadServers) + } +} diff --git a/agent/consul/fsm/fsm.go b/agent/consul/fsm/fsm.go index 6defe6ca84..ea71abb607 100644 --- a/agent/consul/fsm/fsm.go +++ b/agent/consul/fsm/fsm.go @@ -10,7 +10,6 @@ import ( "github.com/armon/go-metrics" "github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/structs" - "github.com/hashicorp/consul/api" "github.com/hashicorp/go-msgpack/codec" "github.com/hashicorp/raft" ) @@ -18,6 +17,28 @@ import ( // msgpackHandle is a shared handle for encoding/decoding msgpack payloads var msgpackHandle = &codec.MsgpackHandle{} +// command is a command method on the FSM. +type command func(buf []byte, index uint64) interface{} + +// unboundCommand is a command method on the FSM, not yet bound to an FSM +// instance. +type unboundCommand func(c *FSM, buf []byte, index uint64) interface{} + +// commands is a map from message type to unbound command. +var commands map[structs.MessageType]unboundCommand + +// registerCommand registers a new command with the FSM, which should be done +// at package init() time. +func registerCommand(msg structs.MessageType, fn unboundCommand) { + if commands == nil { + commands = make(map[structs.MessageType]unboundCommand) + } + if commands[msg] != nil { + panic(fmt.Errorf("Message %d is already registered", msg)) + } + commands[msg] = fn +} + // FSM implements a finite state machine that is used // along with Raft to provide strong consistency. We implement // this outside the Server to avoid exposing this outside the package. @@ -26,6 +47,10 @@ type FSM struct { logger *log.Logger path string + // apply is built off the commands global and is used to route apply + // operations to their appropriate handlers. + apply map[structs.MessageType]command + // stateLock is only used to protect outside callers to State() from // racing with Restore(), which is called by Raft (it puts in a totally // new state store). Everything internal here is synchronized by the @@ -60,9 +85,19 @@ func New(gc *state.TombstoneGC, logOutput io.Writer) (*FSM, error) { fsm := &FSM{ logOutput: logOutput, logger: log.New(logOutput, "", log.LstdFlags), + apply: make(map[structs.MessageType]command), state: stateNew, gc: gc, } + + // Build out the apply dispatch table based on the registered commands. + for msg, fn := range commands { + thisFn := fn + fsm.apply[msg] = func(buf []byte, index uint64) interface{} { + return thisFn(fsm, buf, index) + } + } + return fsm, nil } @@ -86,274 +121,18 @@ func (c *FSM) Apply(log *raft.Log) interface{} { ignoreUnknown = true } - switch msgType { - case structs.RegisterRequestType: - return c.applyRegister(buf[1:], log.Index) - case structs.DeregisterRequestType: - return c.applyDeregister(buf[1:], log.Index) - case structs.KVSRequestType: - return c.applyKVSOperation(buf[1:], log.Index) - case structs.SessionRequestType: - return c.applySessionOperation(buf[1:], log.Index) - case structs.ACLRequestType: - return c.applyACLOperation(buf[1:], log.Index) - case structs.TombstoneRequestType: - return c.applyTombstoneOperation(buf[1:], log.Index) - case structs.CoordinateBatchUpdateType: - return c.applyCoordinateBatchUpdate(buf[1:], log.Index) - case structs.PreparedQueryRequestType: - return c.applyPreparedQueryOperation(buf[1:], log.Index) - case structs.TxnRequestType: - return c.applyTxn(buf[1:], log.Index) - case structs.AutopilotRequestType: - return c.applyAutopilotUpdate(buf[1:], log.Index) - default: - if ignoreUnknown { - c.logger.Printf("[WARN] consul.fsm: ignoring unknown message type (%d), upgrade to newer version", msgType) - return nil - } - panic(fmt.Errorf("failed to apply request: %#v", buf)) - } -} - -func (c *FSM) applyRegister(buf []byte, index uint64) interface{} { - defer metrics.MeasureSince([]string{"consul", "fsm", "register"}, time.Now()) - defer metrics.MeasureSince([]string{"fsm", "register"}, time.Now()) - var req structs.RegisterRequest - if err := structs.Decode(buf, &req); err != nil { - panic(fmt.Errorf("failed to decode request: %v", err)) + // Apply based on the dispatch table, if possible. + if fn, ok := c.apply[msgType]; ok { + return fn(buf[1:], log.Index) } - // Apply all updates in a single transaction - if err := c.state.EnsureRegistration(index, &req); err != nil { - c.logger.Printf("[WARN] consul.fsm: EnsureRegistration failed: %v", err) - return err + // Otherwise, see if it's safe to ignore. If not, we have to panic so + // that we crash and our state doesn't diverge. + if ignoreUnknown { + c.logger.Printf("[WARN] consul.fsm: ignoring unknown message type (%d), upgrade to newer version", msgType) + return nil } - return nil -} - -func (c *FSM) applyDeregister(buf []byte, index uint64) interface{} { - defer metrics.MeasureSince([]string{"consul", "fsm", "deregister"}, time.Now()) - defer metrics.MeasureSince([]string{"fsm", "deregister"}, time.Now()) - var req structs.DeregisterRequest - if err := structs.Decode(buf, &req); err != nil { - panic(fmt.Errorf("failed to decode request: %v", err)) - } - - // Either remove the service entry or the whole node. The precedence - // here is also baked into vetDeregisterWithACL() in acl.go, so if you - // make changes here, be sure to also adjust the code over there. - if req.ServiceID != "" { - if err := c.state.DeleteService(index, req.Node, req.ServiceID); err != nil { - c.logger.Printf("[WARN] consul.fsm: DeleteNodeService failed: %v", err) - return err - } - } else if req.CheckID != "" { - if err := c.state.DeleteCheck(index, req.Node, req.CheckID); err != nil { - c.logger.Printf("[WARN] consul.fsm: DeleteNodeCheck failed: %v", err) - return err - } - } else { - if err := c.state.DeleteNode(index, req.Node); err != nil { - c.logger.Printf("[WARN] consul.fsm: DeleteNode failed: %v", err) - return err - } - } - return nil -} - -func (c *FSM) applyKVSOperation(buf []byte, index uint64) interface{} { - var req structs.KVSRequest - if err := structs.Decode(buf, &req); err != nil { - panic(fmt.Errorf("failed to decode request: %v", err)) - } - defer metrics.MeasureSinceWithLabels([]string{"consul", "fsm", "kvs"}, time.Now(), - []metrics.Label{{Name: "op", Value: string(req.Op)}}) - defer metrics.MeasureSinceWithLabels([]string{"fsm", "kvs"}, time.Now(), - []metrics.Label{{Name: "op", Value: string(req.Op)}}) - switch req.Op { - case api.KVSet: - return c.state.KVSSet(index, &req.DirEnt) - case api.KVDelete: - return c.state.KVSDelete(index, req.DirEnt.Key) - case api.KVDeleteCAS: - act, err := c.state.KVSDeleteCAS(index, req.DirEnt.ModifyIndex, req.DirEnt.Key) - if err != nil { - return err - } - return act - case api.KVDeleteTree: - return c.state.KVSDeleteTree(index, req.DirEnt.Key) - case api.KVCAS: - act, err := c.state.KVSSetCAS(index, &req.DirEnt) - if err != nil { - return err - } - return act - case api.KVLock: - act, err := c.state.KVSLock(index, &req.DirEnt) - if err != nil { - return err - } - return act - case api.KVUnlock: - act, err := c.state.KVSUnlock(index, &req.DirEnt) - if err != nil { - return err - } - return act - default: - err := fmt.Errorf("Invalid KVS operation '%s'", req.Op) - c.logger.Printf("[WARN] consul.fsm: %v", err) - return err - } -} - -func (c *FSM) applySessionOperation(buf []byte, index uint64) interface{} { - var req structs.SessionRequest - if err := structs.Decode(buf, &req); err != nil { - panic(fmt.Errorf("failed to decode request: %v", err)) - } - defer metrics.MeasureSinceWithLabels([]string{"consul", "fsm", "session"}, time.Now(), - []metrics.Label{{Name: "op", Value: string(req.Op)}}) - defer metrics.MeasureSinceWithLabels([]string{"fsm", "session"}, time.Now(), - []metrics.Label{{Name: "op", Value: string(req.Op)}}) - switch req.Op { - case structs.SessionCreate: - if err := c.state.SessionCreate(index, &req.Session); err != nil { - return err - } - return req.Session.ID - case structs.SessionDestroy: - return c.state.SessionDestroy(index, req.Session.ID) - default: - c.logger.Printf("[WARN] consul.fsm: Invalid Session operation '%s'", req.Op) - return fmt.Errorf("Invalid Session operation '%s'", req.Op) - } -} - -func (c *FSM) applyACLOperation(buf []byte, index uint64) interface{} { - var req structs.ACLRequest - if err := structs.Decode(buf, &req); err != nil { - panic(fmt.Errorf("failed to decode request: %v", err)) - } - defer metrics.MeasureSinceWithLabels([]string{"consul", "fsm", "acl"}, time.Now(), - []metrics.Label{{Name: "op", Value: string(req.Op)}}) - defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl"}, time.Now(), - []metrics.Label{{Name: "op", Value: string(req.Op)}}) - switch req.Op { - case structs.ACLBootstrapInit: - enabled, err := c.state.ACLBootstrapInit(index) - if err != nil { - return err - } - return enabled - case structs.ACLBootstrapNow: - if err := c.state.ACLBootstrap(index, &req.ACL); err != nil { - return err - } - return &req.ACL - case structs.ACLForceSet, structs.ACLSet: - if err := c.state.ACLSet(index, &req.ACL); err != nil { - return err - } - return req.ACL.ID - case structs.ACLDelete: - return c.state.ACLDelete(index, req.ACL.ID) - default: - c.logger.Printf("[WARN] consul.fsm: Invalid ACL operation '%s'", req.Op) - return fmt.Errorf("Invalid ACL operation '%s'", req.Op) - } -} - -func (c *FSM) applyTombstoneOperation(buf []byte, index uint64) interface{} { - var req structs.TombstoneRequest - if err := structs.Decode(buf, &req); err != nil { - panic(fmt.Errorf("failed to decode request: %v", err)) - } - defer metrics.MeasureSinceWithLabels([]string{"consul", "fsm", "tombstone"}, time.Now(), - []metrics.Label{{Name: "op", Value: string(req.Op)}}) - defer metrics.MeasureSinceWithLabels([]string{"fsm", "tombstone"}, time.Now(), - []metrics.Label{{Name: "op", Value: string(req.Op)}}) - switch req.Op { - case structs.TombstoneReap: - return c.state.ReapTombstones(req.ReapIndex) - default: - c.logger.Printf("[WARN] consul.fsm: Invalid Tombstone operation '%s'", req.Op) - return fmt.Errorf("Invalid Tombstone operation '%s'", req.Op) - } -} - -// applyCoordinateBatchUpdate processes a batch of coordinate updates and applies -// them in a single underlying transaction. This interface isn't 1:1 with the outer -// update interface that the coordinate endpoint exposes, so we made it single -// purpose and avoided the opcode convention. -func (c *FSM) applyCoordinateBatchUpdate(buf []byte, index uint64) interface{} { - var updates structs.Coordinates - if err := structs.Decode(buf, &updates); err != nil { - panic(fmt.Errorf("failed to decode batch updates: %v", err)) - } - defer metrics.MeasureSince([]string{"consul", "fsm", "coordinate", "batch-update"}, time.Now()) - defer metrics.MeasureSince([]string{"fsm", "coordinate", "batch-update"}, time.Now()) - if err := c.state.CoordinateBatchUpdate(index, updates); err != nil { - return err - } - return nil -} - -// applyPreparedQueryOperation applies the given prepared query operation to the -// state store. -func (c *FSM) applyPreparedQueryOperation(buf []byte, index uint64) interface{} { - var req structs.PreparedQueryRequest - if err := structs.Decode(buf, &req); err != nil { - panic(fmt.Errorf("failed to decode request: %v", err)) - } - - defer metrics.MeasureSinceWithLabels([]string{"consul", "fsm", "prepared-query"}, time.Now(), - []metrics.Label{{Name: "op", Value: string(req.Op)}}) - defer metrics.MeasureSinceWithLabels([]string{"fsm", "prepared-query"}, time.Now(), - []metrics.Label{{Name: "op", Value: string(req.Op)}}) - switch req.Op { - case structs.PreparedQueryCreate, structs.PreparedQueryUpdate: - return c.state.PreparedQuerySet(index, req.Query) - case structs.PreparedQueryDelete: - return c.state.PreparedQueryDelete(index, req.Query.ID) - default: - c.logger.Printf("[WARN] consul.fsm: Invalid PreparedQuery operation '%s'", req.Op) - return fmt.Errorf("Invalid PreparedQuery operation '%s'", req.Op) - } -} - -func (c *FSM) applyTxn(buf []byte, index uint64) interface{} { - var req structs.TxnRequest - if err := structs.Decode(buf, &req); err != nil { - panic(fmt.Errorf("failed to decode request: %v", err)) - } - defer metrics.MeasureSince([]string{"consul", "fsm", "txn"}, time.Now()) - defer metrics.MeasureSince([]string{"fsm", "txn"}, time.Now()) - results, errors := c.state.TxnRW(index, req.Ops) - return structs.TxnResponse{ - Results: results, - Errors: errors, - } -} - -func (c *FSM) applyAutopilotUpdate(buf []byte, index uint64) interface{} { - var req structs.AutopilotSetConfigRequest - if err := structs.Decode(buf, &req); err != nil { - panic(fmt.Errorf("failed to decode request: %v", err)) - } - defer metrics.MeasureSince([]string{"consul", "fsm", "autopilot"}, time.Now()) - defer metrics.MeasureSince([]string{"fsm", "autopilot"}, time.Now()) - - if req.CAS { - act, err := c.state.AutopilotCASConfig(index, req.Config.ModifyIndex, &req.Config) - if err != nil { - return err - } - return act - } - return c.state.AutopilotSetConfig(index, &req.Config) + panic(fmt.Errorf("failed to apply request: %#v", buf)) } func (c *FSM) Snapshot() (raft.FSMSnapshot, error) { diff --git a/agent/consul/fsm/fsm_test.go b/agent/consul/fsm/fsm_test.go index 6b8cf9a129..70a91d6798 100644 --- a/agent/consul/fsm/fsm_test.go +++ b/agent/consul/fsm/fsm_test.go @@ -2,8 +2,6 @@ package fsm import ( "bytes" - "fmt" - "math/rand" "os" "reflect" "testing" @@ -13,10 +11,7 @@ import ( "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/lib" - "github.com/hashicorp/consul/types" - "github.com/hashicorp/go-uuid" "github.com/hashicorp/raft" - "github.com/hashicorp/serf/coordinate" "github.com/pascaldekloe/goe/verify" ) @@ -47,333 +42,6 @@ func makeLog(buf []byte) *raft.Log { } } -func generateUUID() (ret string) { - var err error - if ret, err = uuid.GenerateUUID(); err != nil { - panic(fmt.Sprintf("Unable to generate a UUID, %v", err)) - } - return ret -} - -func generateRandomCoordinate() *coordinate.Coordinate { - config := coordinate.DefaultConfig() - coord := coordinate.NewCoordinate(config) - for i := range coord.Vec { - coord.Vec[i] = rand.NormFloat64() - } - coord.Error = rand.NormFloat64() - coord.Adjustment = rand.NormFloat64() - return coord -} - -func TestFSM_RegisterNode(t *testing.T) { - t.Parallel() - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - req := structs.RegisterRequest{ - Datacenter: "dc1", - Node: "foo", - Address: "127.0.0.1", - } - buf, err := structs.Encode(structs.RegisterRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - - resp := fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - - // Verify we are registered - _, node, err := fsm.state.GetNode("foo") - if err != nil { - t.Fatalf("err: %s", err) - } - if node == nil { - t.Fatalf("not found!") - } - if node.ModifyIndex != 1 { - t.Fatalf("bad index: %d", node.ModifyIndex) - } - - // Verify service registered - _, services, err := fsm.state.NodeServices(nil, "foo") - if err != nil { - t.Fatalf("err: %s", err) - } - if len(services.Services) != 0 { - t.Fatalf("Services: %v", services) - } -} - -func TestFSM_RegisterNode_Service(t *testing.T) { - t.Parallel() - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - req := structs.RegisterRequest{ - Datacenter: "dc1", - Node: "foo", - Address: "127.0.0.1", - Service: &structs.NodeService{ - ID: "db", - Service: "db", - Tags: []string{"master"}, - Port: 8000, - }, - Check: &structs.HealthCheck{ - Node: "foo", - CheckID: "db", - Name: "db connectivity", - Status: api.HealthPassing, - ServiceID: "db", - }, - } - buf, err := structs.Encode(structs.RegisterRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - - resp := fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - - // Verify we are registered - _, node, err := fsm.state.GetNode("foo") - if err != nil { - t.Fatalf("err: %s", err) - } - if node == nil { - t.Fatalf("not found!") - } - - // Verify service registered - _, services, err := fsm.state.NodeServices(nil, "foo") - if err != nil { - t.Fatalf("err: %s", err) - } - if _, ok := services.Services["db"]; !ok { - t.Fatalf("not registered!") - } - - // Verify check - _, checks, err := fsm.state.NodeChecks(nil, "foo") - if err != nil { - t.Fatalf("err: %s", err) - } - if checks[0].CheckID != "db" { - t.Fatalf("not registered!") - } -} - -func TestFSM_DeregisterService(t *testing.T) { - t.Parallel() - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - req := structs.RegisterRequest{ - Datacenter: "dc1", - Node: "foo", - Address: "127.0.0.1", - Service: &structs.NodeService{ - ID: "db", - Service: "db", - Tags: []string{"master"}, - Port: 8000, - }, - } - buf, err := structs.Encode(structs.RegisterRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - - resp := fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - - dereg := structs.DeregisterRequest{ - Datacenter: "dc1", - Node: "foo", - ServiceID: "db", - } - buf, err = structs.Encode(structs.DeregisterRequestType, dereg) - if err != nil { - t.Fatalf("err: %v", err) - } - - resp = fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - - // Verify we are registered - _, node, err := fsm.state.GetNode("foo") - if err != nil { - t.Fatalf("err: %s", err) - } - if node == nil { - t.Fatalf("not found!") - } - - // Verify service not registered - _, services, err := fsm.state.NodeServices(nil, "foo") - if err != nil { - t.Fatalf("err: %s", err) - } - if _, ok := services.Services["db"]; ok { - t.Fatalf("db registered!") - } -} - -func TestFSM_DeregisterCheck(t *testing.T) { - t.Parallel() - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - req := structs.RegisterRequest{ - Datacenter: "dc1", - Node: "foo", - Address: "127.0.0.1", - Check: &structs.HealthCheck{ - Node: "foo", - CheckID: "mem", - Name: "memory util", - Status: api.HealthPassing, - }, - } - buf, err := structs.Encode(structs.RegisterRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - - resp := fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - - dereg := structs.DeregisterRequest{ - Datacenter: "dc1", - Node: "foo", - CheckID: "mem", - } - buf, err = structs.Encode(structs.DeregisterRequestType, dereg) - if err != nil { - t.Fatalf("err: %v", err) - } - - resp = fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - - // Verify we are registered - _, node, err := fsm.state.GetNode("foo") - if err != nil { - t.Fatalf("err: %s", err) - } - if node == nil { - t.Fatalf("not found!") - } - - // Verify check not registered - _, checks, err := fsm.state.NodeChecks(nil, "foo") - if err != nil { - t.Fatalf("err: %s", err) - } - if len(checks) != 0 { - t.Fatalf("check registered!") - } -} - -func TestFSM_DeregisterNode(t *testing.T) { - t.Parallel() - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - req := structs.RegisterRequest{ - Datacenter: "dc1", - Node: "foo", - Address: "127.0.0.1", - Service: &structs.NodeService{ - ID: "db", - Service: "db", - Tags: []string{"master"}, - Port: 8000, - }, - Check: &structs.HealthCheck{ - Node: "foo", - CheckID: "db", - Name: "db connectivity", - Status: api.HealthPassing, - ServiceID: "db", - }, - } - buf, err := structs.Encode(structs.RegisterRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - - resp := fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - - dereg := structs.DeregisterRequest{ - Datacenter: "dc1", - Node: "foo", - } - buf, err = structs.Encode(structs.DeregisterRequestType, dereg) - if err != nil { - t.Fatalf("err: %v", err) - } - - resp = fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - - // Verify we are not registered - _, node, err := fsm.state.GetNode("foo") - if err != nil { - t.Fatalf("err: %s", err) - } - if node != nil { - t.Fatalf("found!") - } - - // Verify service not registered - _, services, err := fsm.state.NodeServices(nil, "foo") - if err != nil { - t.Fatalf("err: %s", err) - } - if services != nil { - t.Fatalf("Services: %v", services) - } - - // Verify checks not registered - _, checks, err := fsm.state.NodeChecks(nil, "foo") - if err != nil { - t.Fatalf("err: %s", err) - } - if len(checks) != 0 { - t.Fatalf("Services: %v", services) - } -} - func TestFSM_SnapshotRestore(t *testing.T) { t.Parallel() fsm, err := New(nil, os.Stderr) @@ -685,811 +353,6 @@ func TestFSM_BadRestore(t *testing.T) { } } -func TestFSM_KVSDelete(t *testing.T) { - t.Parallel() - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - req := structs.KVSRequest{ - Datacenter: "dc1", - Op: api.KVSet, - DirEnt: structs.DirEntry{ - Key: "/test/path", - Flags: 0, - Value: []byte("test"), - }, - } - buf, err := structs.Encode(structs.KVSRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - resp := fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - - // Run the delete - req.Op = api.KVDelete - buf, err = structs.Encode(structs.KVSRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - resp = fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - - // Verify key is not set - _, d, err := fsm.state.KVSGet(nil, "/test/path") - if err != nil { - t.Fatalf("err: %v", err) - } - if d != nil { - t.Fatalf("key present") - } -} - -func TestFSM_KVSDeleteTree(t *testing.T) { - t.Parallel() - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - req := structs.KVSRequest{ - Datacenter: "dc1", - Op: api.KVSet, - DirEnt: structs.DirEntry{ - Key: "/test/path", - Flags: 0, - Value: []byte("test"), - }, - } - buf, err := structs.Encode(structs.KVSRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - resp := fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - - // Run the delete tree - req.Op = api.KVDeleteTree - req.DirEnt.Key = "/test" - buf, err = structs.Encode(structs.KVSRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - resp = fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - - // Verify key is not set - _, d, err := fsm.state.KVSGet(nil, "/test/path") - if err != nil { - t.Fatalf("err: %v", err) - } - if d != nil { - t.Fatalf("key present") - } -} - -func TestFSM_KVSDeleteCheckAndSet(t *testing.T) { - t.Parallel() - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - req := structs.KVSRequest{ - Datacenter: "dc1", - Op: api.KVSet, - DirEnt: structs.DirEntry{ - Key: "/test/path", - Flags: 0, - Value: []byte("test"), - }, - } - buf, err := structs.Encode(structs.KVSRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - resp := fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - - // Verify key is set - _, d, err := fsm.state.KVSGet(nil, "/test/path") - if err != nil { - t.Fatalf("err: %v", err) - } - if d == nil { - t.Fatalf("key missing") - } - - // Run the check-and-set - req.Op = api.KVDeleteCAS - req.DirEnt.ModifyIndex = d.ModifyIndex - buf, err = structs.Encode(structs.KVSRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - resp = fsm.Apply(makeLog(buf)) - if resp.(bool) != true { - t.Fatalf("resp: %v", resp) - } - - // Verify key is gone - _, d, err = fsm.state.KVSGet(nil, "/test/path") - if err != nil { - t.Fatalf("err: %v", err) - } - if d != nil { - t.Fatalf("bad: %v", d) - } -} - -func TestFSM_KVSCheckAndSet(t *testing.T) { - t.Parallel() - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - req := structs.KVSRequest{ - Datacenter: "dc1", - Op: api.KVSet, - DirEnt: structs.DirEntry{ - Key: "/test/path", - Flags: 0, - Value: []byte("test"), - }, - } - buf, err := structs.Encode(structs.KVSRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - resp := fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - - // Verify key is set - _, d, err := fsm.state.KVSGet(nil, "/test/path") - if err != nil { - t.Fatalf("err: %v", err) - } - if d == nil { - t.Fatalf("key missing") - } - - // Run the check-and-set - req.Op = api.KVCAS - req.DirEnt.ModifyIndex = d.ModifyIndex - req.DirEnt.Value = []byte("zip") - buf, err = structs.Encode(structs.KVSRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - resp = fsm.Apply(makeLog(buf)) - if resp.(bool) != true { - t.Fatalf("resp: %v", resp) - } - - // Verify key is updated - _, d, err = fsm.state.KVSGet(nil, "/test/path") - if err != nil { - t.Fatalf("err: %v", err) - } - if string(d.Value) != "zip" { - t.Fatalf("bad: %v", d) - } -} - -func TestFSM_CoordinateUpdate(t *testing.T) { - t.Parallel() - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - // Register some nodes. - fsm.state.EnsureNode(1, &structs.Node{Node: "node1", Address: "127.0.0.1"}) - fsm.state.EnsureNode(2, &structs.Node{Node: "node2", Address: "127.0.0.1"}) - - // Write a batch of two coordinates. - updates := structs.Coordinates{ - &structs.Coordinate{ - Node: "node1", - Coord: generateRandomCoordinate(), - }, - &structs.Coordinate{ - Node: "node2", - Coord: generateRandomCoordinate(), - }, - } - buf, err := structs.Encode(structs.CoordinateBatchUpdateType, updates) - if err != nil { - t.Fatalf("err: %v", err) - } - resp := fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - - // Read back the two coordinates to make sure they got updated. - _, coords, err := fsm.state.Coordinates(nil) - if err != nil { - t.Fatalf("err: %s", err) - } - if !reflect.DeepEqual(coords, updates) { - t.Fatalf("bad: %#v", coords) - } -} - -func TestFSM_SessionCreate_Destroy(t *testing.T) { - t.Parallel() - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) - fsm.state.EnsureCheck(2, &structs.HealthCheck{ - Node: "foo", - CheckID: "web", - Status: api.HealthPassing, - }) - - // Create a new session - req := structs.SessionRequest{ - Datacenter: "dc1", - Op: structs.SessionCreate, - Session: structs.Session{ - ID: generateUUID(), - Node: "foo", - Checks: []types.CheckID{"web"}, - }, - } - buf, err := structs.Encode(structs.SessionRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - resp := fsm.Apply(makeLog(buf)) - if err, ok := resp.(error); ok { - t.Fatalf("resp: %v", err) - } - - // Get the session - id := resp.(string) - _, session, err := fsm.state.SessionGet(nil, id) - if err != nil { - t.Fatalf("err: %v", err) - } - if session == nil { - t.Fatalf("missing") - } - - // Verify the session - if session.ID != id { - t.Fatalf("bad: %v", *session) - } - if session.Node != "foo" { - t.Fatalf("bad: %v", *session) - } - if session.Checks[0] != "web" { - t.Fatalf("bad: %v", *session) - } - - // Try to destroy - destroy := structs.SessionRequest{ - Datacenter: "dc1", - Op: structs.SessionDestroy, - Session: structs.Session{ - ID: id, - }, - } - buf, err = structs.Encode(structs.SessionRequestType, destroy) - if err != nil { - t.Fatalf("err: %v", err) - } - resp = fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - - _, session, err = fsm.state.SessionGet(nil, id) - if err != nil { - t.Fatalf("err: %v", err) - } - if session != nil { - t.Fatalf("should be destroyed") - } -} - -func TestFSM_KVSLock(t *testing.T) { - t.Parallel() - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) - session := &structs.Session{ID: generateUUID(), Node: "foo"} - fsm.state.SessionCreate(2, session) - - req := structs.KVSRequest{ - Datacenter: "dc1", - Op: api.KVLock, - DirEnt: structs.DirEntry{ - Key: "/test/path", - Value: []byte("test"), - Session: session.ID, - }, - } - buf, err := structs.Encode(structs.KVSRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - resp := fsm.Apply(makeLog(buf)) - if resp != true { - t.Fatalf("resp: %v", resp) - } - - // Verify key is locked - _, d, err := fsm.state.KVSGet(nil, "/test/path") - if err != nil { - t.Fatalf("err: %v", err) - } - if d == nil { - t.Fatalf("missing") - } - if d.LockIndex != 1 { - t.Fatalf("bad: %v", *d) - } - if d.Session != session.ID { - t.Fatalf("bad: %v", *d) - } -} - -func TestFSM_KVSUnlock(t *testing.T) { - t.Parallel() - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) - session := &structs.Session{ID: generateUUID(), Node: "foo"} - fsm.state.SessionCreate(2, session) - - req := structs.KVSRequest{ - Datacenter: "dc1", - Op: api.KVLock, - DirEnt: structs.DirEntry{ - Key: "/test/path", - Value: []byte("test"), - Session: session.ID, - }, - } - buf, err := structs.Encode(structs.KVSRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - resp := fsm.Apply(makeLog(buf)) - if resp != true { - t.Fatalf("resp: %v", resp) - } - - req = structs.KVSRequest{ - Datacenter: "dc1", - Op: api.KVUnlock, - DirEnt: structs.DirEntry{ - Key: "/test/path", - Value: []byte("test"), - Session: session.ID, - }, - } - buf, err = structs.Encode(structs.KVSRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - resp = fsm.Apply(makeLog(buf)) - if resp != true { - t.Fatalf("resp: %v", resp) - } - - // Verify key is unlocked - _, d, err := fsm.state.KVSGet(nil, "/test/path") - if err != nil { - t.Fatalf("err: %v", err) - } - if d == nil { - t.Fatalf("missing") - } - if d.LockIndex != 1 { - t.Fatalf("bad: %v", *d) - } - if d.Session != "" { - t.Fatalf("bad: %v", *d) - } -} - -func TestFSM_ACL_CRUD(t *testing.T) { - t.Parallel() - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - // Create a new ACL. - req := structs.ACLRequest{ - Datacenter: "dc1", - Op: structs.ACLSet, - ACL: structs.ACL{ - ID: generateUUID(), - Name: "User token", - Type: structs.ACLTypeClient, - }, - } - buf, err := structs.Encode(structs.ACLRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - resp := fsm.Apply(makeLog(buf)) - if err, ok := resp.(error); ok { - t.Fatalf("resp: %v", err) - } - - // Get the ACL. - id := resp.(string) - _, acl, err := fsm.state.ACLGet(nil, id) - if err != nil { - t.Fatalf("err: %v", err) - } - if acl == nil { - t.Fatalf("missing") - } - - // Verify the ACL. - if acl.ID != id { - t.Fatalf("bad: %v", *acl) - } - if acl.Name != "User token" { - t.Fatalf("bad: %v", *acl) - } - if acl.Type != structs.ACLTypeClient { - t.Fatalf("bad: %v", *acl) - } - - // Try to destroy. - destroy := structs.ACLRequest{ - Datacenter: "dc1", - Op: structs.ACLDelete, - ACL: structs.ACL{ - ID: id, - }, - } - buf, err = structs.Encode(structs.ACLRequestType, destroy) - if err != nil { - t.Fatalf("err: %v", err) - } - resp = fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - - _, acl, err = fsm.state.ACLGet(nil, id) - if err != nil { - t.Fatalf("err: %v", err) - } - if acl != nil { - t.Fatalf("should be destroyed") - } - - // Initialize bootstrap (should work since we haven't made a management - // token). - init := structs.ACLRequest{ - Datacenter: "dc1", - Op: structs.ACLBootstrapInit, - } - buf, err = structs.Encode(structs.ACLRequestType, init) - if err != nil { - t.Fatalf("err: %v", err) - } - resp = fsm.Apply(makeLog(buf)) - if enabled, ok := resp.(bool); !ok || !enabled { - t.Fatalf("resp: %v", resp) - } - gotB, err := fsm.state.ACLGetBootstrap() - if err != nil { - t.Fatalf("err: %v", err) - } - wantB := &structs.ACLBootstrap{ - AllowBootstrap: true, - RaftIndex: gotB.RaftIndex, - } - verify.Values(t, "", gotB, wantB) - - // Do a bootstrap. - bootstrap := structs.ACLRequest{ - Datacenter: "dc1", - Op: structs.ACLBootstrapNow, - ACL: structs.ACL{ - ID: generateUUID(), - Name: "Bootstrap Token", - Type: structs.ACLTypeManagement, - }, - } - buf, err = structs.Encode(structs.ACLRequestType, bootstrap) - if err != nil { - t.Fatalf("err: %v", err) - } - resp = fsm.Apply(makeLog(buf)) - respACL, ok := resp.(*structs.ACL) - if !ok { - t.Fatalf("resp: %v", resp) - } - bootstrap.ACL.CreateIndex = respACL.CreateIndex - bootstrap.ACL.ModifyIndex = respACL.ModifyIndex - verify.Values(t, "", respACL, &bootstrap.ACL) -} - -func TestFSM_PreparedQuery_CRUD(t *testing.T) { - t.Parallel() - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - // Register a service to query on. - fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) - fsm.state.EnsureService(2, "foo", &structs.NodeService{ID: "web", Service: "web", Tags: nil, Address: "127.0.0.1", Port: 80}) - - // Create a new query. - query := structs.PreparedQueryRequest{ - Op: structs.PreparedQueryCreate, - Query: &structs.PreparedQuery{ - ID: generateUUID(), - Service: structs.ServiceQuery{ - Service: "web", - }, - }, - } - { - buf, err := structs.Encode(structs.PreparedQueryRequestType, query) - if err != nil { - t.Fatalf("err: %v", err) - } - resp := fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - } - - // Verify it's in the state store. - { - _, actual, err := fsm.state.PreparedQueryGet(nil, query.Query.ID) - if err != nil { - t.Fatalf("err: %s", err) - } - - actual.CreateIndex, actual.ModifyIndex = 0, 0 - if !reflect.DeepEqual(actual, query.Query) { - t.Fatalf("bad: %v", actual) - } - } - - // Make an update to the query. - query.Op = structs.PreparedQueryUpdate - query.Query.Name = "my-query" - { - buf, err := structs.Encode(structs.PreparedQueryRequestType, query) - if err != nil { - t.Fatalf("err: %v", err) - } - resp := fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - } - - // Verify the update. - { - _, actual, err := fsm.state.PreparedQueryGet(nil, query.Query.ID) - if err != nil { - t.Fatalf("err: %s", err) - } - - actual.CreateIndex, actual.ModifyIndex = 0, 0 - if !reflect.DeepEqual(actual, query.Query) { - t.Fatalf("bad: %v", actual) - } - } - - // Delete the query. - query.Op = structs.PreparedQueryDelete - { - buf, err := structs.Encode(structs.PreparedQueryRequestType, query) - if err != nil { - t.Fatalf("err: %v", err) - } - resp := fsm.Apply(makeLog(buf)) - if resp != nil { - t.Fatalf("resp: %v", resp) - } - } - - // Make sure it's gone. - { - _, actual, err := fsm.state.PreparedQueryGet(nil, query.Query.ID) - if err != nil { - t.Fatalf("err: %s", err) - } - - if actual != nil { - t.Fatalf("bad: %v", actual) - } - } -} - -func TestFSM_TombstoneReap(t *testing.T) { - t.Parallel() - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - // Create some tombstones - fsm.state.KVSSet(11, &structs.DirEntry{ - Key: "/remove", - Value: []byte("foo"), - }) - fsm.state.KVSDelete(12, "/remove") - idx, _, err := fsm.state.KVSList(nil, "/remove") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 12 { - t.Fatalf("bad index: %d", idx) - } - - // Create a new reap request - req := structs.TombstoneRequest{ - Datacenter: "dc1", - Op: structs.TombstoneReap, - ReapIndex: 12, - } - buf, err := structs.Encode(structs.TombstoneRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - resp := fsm.Apply(makeLog(buf)) - if err, ok := resp.(error); ok { - t.Fatalf("resp: %v", err) - } - - // Verify the tombstones are gone - snap := fsm.state.Snapshot() - defer snap.Close() - stones, err := snap.Tombstones() - if err != nil { - t.Fatalf("err: %s", err) - } - if stones.Next() != nil { - t.Fatalf("unexpected extra tombstones") - } -} - -func TestFSM_Txn(t *testing.T) { - t.Parallel() - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - // Set a key using a transaction. - req := structs.TxnRequest{ - Datacenter: "dc1", - Ops: structs.TxnOps{ - &structs.TxnOp{ - KV: &structs.TxnKVOp{ - Verb: api.KVSet, - DirEnt: structs.DirEntry{ - Key: "/test/path", - Flags: 0, - Value: []byte("test"), - }, - }, - }, - }, - } - buf, err := structs.Encode(structs.TxnRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - resp := fsm.Apply(makeLog(buf)) - if _, ok := resp.(structs.TxnResponse); !ok { - t.Fatalf("bad response type: %T", resp) - } - - // Verify key is set directly in the state store. - _, d, err := fsm.state.KVSGet(nil, "/test/path") - if err != nil { - t.Fatalf("err: %v", err) - } - if d == nil { - t.Fatalf("missing") - } -} - -func TestFSM_Autopilot(t *testing.T) { - t.Parallel() - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - // Set the autopilot config using a request. - req := structs.AutopilotSetConfigRequest{ - Datacenter: "dc1", - Config: structs.AutopilotConfig{ - CleanupDeadServers: true, - LastContactThreshold: 10 * time.Second, - MaxTrailingLogs: 300, - }, - } - buf, err := structs.Encode(structs.AutopilotRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - resp := fsm.Apply(makeLog(buf)) - if _, ok := resp.(error); ok { - t.Fatalf("bad: %v", resp) - } - - // Verify key is set directly in the state store. - _, config, err := fsm.state.AutopilotConfig() - if err != nil { - t.Fatalf("err: %v", err) - } - if config.CleanupDeadServers != req.Config.CleanupDeadServers { - t.Fatalf("bad: %v", config.CleanupDeadServers) - } - if config.LastContactThreshold != req.Config.LastContactThreshold { - t.Fatalf("bad: %v", config.LastContactThreshold) - } - if config.MaxTrailingLogs != req.Config.MaxTrailingLogs { - t.Fatalf("bad: %v", config.MaxTrailingLogs) - } - - // Now use CAS and provide an old index - req.CAS = true - req.Config.CleanupDeadServers = false - req.Config.ModifyIndex = config.ModifyIndex - 1 - buf, err = structs.Encode(structs.AutopilotRequestType, req) - if err != nil { - t.Fatalf("err: %v", err) - } - resp = fsm.Apply(makeLog(buf)) - if _, ok := resp.(error); ok { - t.Fatalf("bad: %v", resp) - } - - _, config, err = fsm.state.AutopilotConfig() - if err != nil { - t.Fatalf("err: %v", err) - } - if !config.CleanupDeadServers { - t.Fatalf("bad: %v", config.CleanupDeadServers) - } -} - func TestFSM_IgnoreUnknown(t *testing.T) { t.Parallel() fsm, err := New(nil, os.Stderr) From f53f5210724d8dd77e43e500d0988c41ca6cd0f8 Mon Sep 17 00:00:00 2001 From: James Phillips Date: Wed, 29 Nov 2017 16:07:18 -0800 Subject: [PATCH 13/14] Begins split out of snapshots from the main FSM class. --- agent/consul/fsm/fsm.go | 287 +--------------------------------- agent/consul/fsm/snapshot.go | 295 +++++++++++++++++++++++++++++++++++ 2 files changed, 296 insertions(+), 286 deletions(-) create mode 100644 agent/consul/fsm/snapshot.go diff --git a/agent/consul/fsm/fsm.go b/agent/consul/fsm/fsm.go index ea71abb607..bc72744df4 100644 --- a/agent/consul/fsm/fsm.go +++ b/agent/consul/fsm/fsm.go @@ -7,7 +7,6 @@ import ( "sync" "time" - "github.com/armon/go-metrics" "github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/go-msgpack/codec" @@ -61,20 +60,6 @@ type FSM struct { gc *state.TombstoneGC } -// consulSnapshot is used to provide a snapshot of the current -// state in a way that can be accessed concurrently with operations -// that may modify the live state. -type consulSnapshot struct { - state *state.Snapshot -} - -// snapshotHeader is the first entry in our snapshot -type snapshotHeader struct { - // LastIndex is the last index that affects the data. - // This is used when we do the restore for watchers. - LastIndex uint64 -} - // New is used to construct a new FSM with a blank state. func New(gc *state.TombstoneGC, logOutput io.Writer) (*FSM, error) { stateNew, err := state.NewStateStore(gc) @@ -140,7 +125,7 @@ func (c *FSM) Snapshot() (raft.FSMSnapshot, error) { c.logger.Printf("[INFO] consul.fsm: snapshot created in %v", time.Since(start)) }(time.Now()) - return &consulSnapshot{c.state.Snapshot()}, nil + return &snapshot{c.state.Snapshot()}, nil } // Restore streams in the snapshot and replaces the current state store with a @@ -290,273 +275,3 @@ func (c *FSM) Restore(old io.ReadCloser) error { stateOld.Abandon() return nil } - -func (s *consulSnapshot) Persist(sink raft.SnapshotSink) error { - defer metrics.MeasureSince([]string{"consul", "fsm", "persist"}, time.Now()) - defer metrics.MeasureSince([]string{"fsm", "persist"}, time.Now()) - - // Register the nodes - encoder := codec.NewEncoder(sink, msgpackHandle) - - // Write the header - header := snapshotHeader{ - LastIndex: s.state.LastIndex(), - } - if err := encoder.Encode(&header); err != nil { - sink.Cancel() - return err - } - - if err := s.persistNodes(sink, encoder); err != nil { - sink.Cancel() - return err - } - - if err := s.persistSessions(sink, encoder); err != nil { - sink.Cancel() - return err - } - - if err := s.persistACLs(sink, encoder); err != nil { - sink.Cancel() - return err - } - - if err := s.persistKVs(sink, encoder); err != nil { - sink.Cancel() - return err - } - - if err := s.persistTombstones(sink, encoder); err != nil { - sink.Cancel() - return err - } - - if err := s.persistPreparedQueries(sink, encoder); err != nil { - sink.Cancel() - return err - } - - if err := s.persistAutopilot(sink, encoder); err != nil { - sink.Cancel() - return err - } - - return nil -} - -func (s *consulSnapshot) persistNodes(sink raft.SnapshotSink, - encoder *codec.Encoder) error { - - // Get all the nodes - nodes, err := s.state.Nodes() - if err != nil { - return err - } - - // Register each node - for node := nodes.Next(); node != nil; node = nodes.Next() { - n := node.(*structs.Node) - req := structs.RegisterRequest{ - Node: n.Node, - Address: n.Address, - TaggedAddresses: n.TaggedAddresses, - } - - // Register the node itself - if _, err := sink.Write([]byte{byte(structs.RegisterRequestType)}); err != nil { - return err - } - if err := encoder.Encode(&req); err != nil { - return err - } - - // Register each service this node has - services, err := s.state.Services(n.Node) - if err != nil { - return err - } - for service := services.Next(); service != nil; service = services.Next() { - if _, err := sink.Write([]byte{byte(structs.RegisterRequestType)}); err != nil { - return err - } - req.Service = service.(*structs.ServiceNode).ToNodeService() - if err := encoder.Encode(&req); err != nil { - return err - } - } - - // Register each check this node has - req.Service = nil - checks, err := s.state.Checks(n.Node) - if err != nil { - return err - } - for check := checks.Next(); check != nil; check = checks.Next() { - if _, err := sink.Write([]byte{byte(structs.RegisterRequestType)}); err != nil { - return err - } - req.Check = check.(*structs.HealthCheck) - if err := encoder.Encode(&req); err != nil { - return err - } - } - } - - // Save the coordinates separately since they are not part of the - // register request interface. To avoid copying them out, we turn - // them into batches with a single coordinate each. - coords, err := s.state.Coordinates() - if err != nil { - return err - } - for coord := coords.Next(); coord != nil; coord = coords.Next() { - if _, err := sink.Write([]byte{byte(structs.CoordinateBatchUpdateType)}); err != nil { - return err - } - updates := structs.Coordinates{coord.(*structs.Coordinate)} - if err := encoder.Encode(&updates); err != nil { - return err - } - } - return nil -} - -func (s *consulSnapshot) persistSessions(sink raft.SnapshotSink, - encoder *codec.Encoder) error { - sessions, err := s.state.Sessions() - if err != nil { - return err - } - - for session := sessions.Next(); session != nil; session = sessions.Next() { - if _, err := sink.Write([]byte{byte(structs.SessionRequestType)}); err != nil { - return err - } - if err := encoder.Encode(session.(*structs.Session)); err != nil { - return err - } - } - return nil -} - -func (s *consulSnapshot) persistACLs(sink raft.SnapshotSink, - encoder *codec.Encoder) error { - acls, err := s.state.ACLs() - if err != nil { - return err - } - - for acl := acls.Next(); acl != nil; acl = acls.Next() { - if _, err := sink.Write([]byte{byte(structs.ACLRequestType)}); err != nil { - return err - } - if err := encoder.Encode(acl.(*structs.ACL)); err != nil { - return err - } - } - - bs, err := s.state.ACLBootstrap() - if err != nil { - return err - } - if bs != nil { - if _, err := sink.Write([]byte{byte(structs.ACLBootstrapRequestType)}); err != nil { - return err - } - if err := encoder.Encode(bs); err != nil { - return err - } - } - - return nil -} - -func (s *consulSnapshot) persistKVs(sink raft.SnapshotSink, - encoder *codec.Encoder) error { - entries, err := s.state.KVs() - if err != nil { - return err - } - - for entry := entries.Next(); entry != nil; entry = entries.Next() { - if _, err := sink.Write([]byte{byte(structs.KVSRequestType)}); err != nil { - return err - } - if err := encoder.Encode(entry.(*structs.DirEntry)); err != nil { - return err - } - } - return nil -} - -func (s *consulSnapshot) persistTombstones(sink raft.SnapshotSink, - encoder *codec.Encoder) error { - stones, err := s.state.Tombstones() - if err != nil { - return err - } - - for stone := stones.Next(); stone != nil; stone = stones.Next() { - if _, err := sink.Write([]byte{byte(structs.TombstoneRequestType)}); err != nil { - return err - } - - // For historical reasons, these are serialized in the snapshots - // as KV entries. We want to keep the snapshot format compatible - // with pre-0.6 versions for now. - s := stone.(*state.Tombstone) - fake := &structs.DirEntry{ - Key: s.Key, - RaftIndex: structs.RaftIndex{ - ModifyIndex: s.Index, - }, - } - if err := encoder.Encode(fake); err != nil { - return err - } - } - return nil -} - -func (s *consulSnapshot) persistPreparedQueries(sink raft.SnapshotSink, - encoder *codec.Encoder) error { - queries, err := s.state.PreparedQueries() - if err != nil { - return err - } - - for _, query := range queries { - if _, err := sink.Write([]byte{byte(structs.PreparedQueryRequestType)}); err != nil { - return err - } - if err := encoder.Encode(query); err != nil { - return err - } - } - return nil -} - -func (s *consulSnapshot) persistAutopilot(sink raft.SnapshotSink, - encoder *codec.Encoder) error { - autopilot, err := s.state.Autopilot() - if err != nil { - return err - } - if autopilot == nil { - return nil - } - - if _, err := sink.Write([]byte{byte(structs.AutopilotRequestType)}); err != nil { - return err - } - if err := encoder.Encode(autopilot); err != nil { - return err - } - - return nil -} - -func (s *consulSnapshot) Release() { - s.state.Close() -} diff --git a/agent/consul/fsm/snapshot.go b/agent/consul/fsm/snapshot.go new file mode 100644 index 0000000000..3735fea89c --- /dev/null +++ b/agent/consul/fsm/snapshot.go @@ -0,0 +1,295 @@ +package fsm + +import ( + "time" + + "github.com/armon/go-metrics" + "github.com/hashicorp/consul/agent/consul/state" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/go-msgpack/codec" + "github.com/hashicorp/raft" +) + +// snapshot is used to provide a snapshot of the current +// state in a way that can be accessed concurrently with operations +// that may modify the live state. +type snapshot struct { + state *state.Snapshot +} + +// snapshotHeader is the first entry in our snapshot +type snapshotHeader struct { + // LastIndex is the last index that affects the data. + // This is used when we do the restore for watchers. + LastIndex uint64 +} + +func (s *snapshot) Persist(sink raft.SnapshotSink) error { + defer metrics.MeasureSince([]string{"consul", "fsm", "persist"}, time.Now()) + defer metrics.MeasureSince([]string{"fsm", "persist"}, time.Now()) + + // Register the nodes + encoder := codec.NewEncoder(sink, msgpackHandle) + + // Write the header + header := snapshotHeader{ + LastIndex: s.state.LastIndex(), + } + if err := encoder.Encode(&header); err != nil { + sink.Cancel() + return err + } + + if err := s.persistNodes(sink, encoder); err != nil { + sink.Cancel() + return err + } + + if err := s.persistSessions(sink, encoder); err != nil { + sink.Cancel() + return err + } + + if err := s.persistACLs(sink, encoder); err != nil { + sink.Cancel() + return err + } + + if err := s.persistKVs(sink, encoder); err != nil { + sink.Cancel() + return err + } + + if err := s.persistTombstones(sink, encoder); err != nil { + sink.Cancel() + return err + } + + if err := s.persistPreparedQueries(sink, encoder); err != nil { + sink.Cancel() + return err + } + + if err := s.persistAutopilot(sink, encoder); err != nil { + sink.Cancel() + return err + } + + return nil +} + +func (s *snapshot) persistNodes(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + + // Get all the nodes + nodes, err := s.state.Nodes() + if err != nil { + return err + } + + // Register each node + for node := nodes.Next(); node != nil; node = nodes.Next() { + n := node.(*structs.Node) + req := structs.RegisterRequest{ + Node: n.Node, + Address: n.Address, + TaggedAddresses: n.TaggedAddresses, + } + + // Register the node itself + if _, err := sink.Write([]byte{byte(structs.RegisterRequestType)}); err != nil { + return err + } + if err := encoder.Encode(&req); err != nil { + return err + } + + // Register each service this node has + services, err := s.state.Services(n.Node) + if err != nil { + return err + } + for service := services.Next(); service != nil; service = services.Next() { + if _, err := sink.Write([]byte{byte(structs.RegisterRequestType)}); err != nil { + return err + } + req.Service = service.(*structs.ServiceNode).ToNodeService() + if err := encoder.Encode(&req); err != nil { + return err + } + } + + // Register each check this node has + req.Service = nil + checks, err := s.state.Checks(n.Node) + if err != nil { + return err + } + for check := checks.Next(); check != nil; check = checks.Next() { + if _, err := sink.Write([]byte{byte(structs.RegisterRequestType)}); err != nil { + return err + } + req.Check = check.(*structs.HealthCheck) + if err := encoder.Encode(&req); err != nil { + return err + } + } + } + + // Save the coordinates separately since they are not part of the + // register request interface. To avoid copying them out, we turn + // them into batches with a single coordinate each. + coords, err := s.state.Coordinates() + if err != nil { + return err + } + for coord := coords.Next(); coord != nil; coord = coords.Next() { + if _, err := sink.Write([]byte{byte(structs.CoordinateBatchUpdateType)}); err != nil { + return err + } + updates := structs.Coordinates{coord.(*structs.Coordinate)} + if err := encoder.Encode(&updates); err != nil { + return err + } + } + return nil +} + +func (s *snapshot) persistSessions(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + sessions, err := s.state.Sessions() + if err != nil { + return err + } + + for session := sessions.Next(); session != nil; session = sessions.Next() { + if _, err := sink.Write([]byte{byte(structs.SessionRequestType)}); err != nil { + return err + } + if err := encoder.Encode(session.(*structs.Session)); err != nil { + return err + } + } + return nil +} + +func (s *snapshot) persistACLs(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + acls, err := s.state.ACLs() + if err != nil { + return err + } + + for acl := acls.Next(); acl != nil; acl = acls.Next() { + if _, err := sink.Write([]byte{byte(structs.ACLRequestType)}); err != nil { + return err + } + if err := encoder.Encode(acl.(*structs.ACL)); err != nil { + return err + } + } + + bs, err := s.state.ACLBootstrap() + if err != nil { + return err + } + if bs != nil { + if _, err := sink.Write([]byte{byte(structs.ACLBootstrapRequestType)}); err != nil { + return err + } + if err := encoder.Encode(bs); err != nil { + return err + } + } + + return nil +} + +func (s *snapshot) persistKVs(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + entries, err := s.state.KVs() + if err != nil { + return err + } + + for entry := entries.Next(); entry != nil; entry = entries.Next() { + if _, err := sink.Write([]byte{byte(structs.KVSRequestType)}); err != nil { + return err + } + if err := encoder.Encode(entry.(*structs.DirEntry)); err != nil { + return err + } + } + return nil +} + +func (s *snapshot) persistTombstones(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + stones, err := s.state.Tombstones() + if err != nil { + return err + } + + for stone := stones.Next(); stone != nil; stone = stones.Next() { + if _, err := sink.Write([]byte{byte(structs.TombstoneRequestType)}); err != nil { + return err + } + + // For historical reasons, these are serialized in the snapshots + // as KV entries. We want to keep the snapshot format compatible + // with pre-0.6 versions for now. + s := stone.(*state.Tombstone) + fake := &structs.DirEntry{ + Key: s.Key, + RaftIndex: structs.RaftIndex{ + ModifyIndex: s.Index, + }, + } + if err := encoder.Encode(fake); err != nil { + return err + } + } + return nil +} + +func (s *snapshot) persistPreparedQueries(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + queries, err := s.state.PreparedQueries() + if err != nil { + return err + } + + for _, query := range queries { + if _, err := sink.Write([]byte{byte(structs.PreparedQueryRequestType)}); err != nil { + return err + } + if err := encoder.Encode(query); err != nil { + return err + } + } + return nil +} + +func (s *snapshot) persistAutopilot(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + autopilot, err := s.state.Autopilot() + if err != nil { + return err + } + if autopilot == nil { + return nil + } + + if _, err := sink.Write([]byte{byte(structs.AutopilotRequestType)}); err != nil { + return err + } + if err := encoder.Encode(autopilot); err != nil { + return err + } + + return nil +} + +func (s *snapshot) Release() { + s.state.Close() +} From 3e4654408534782ae48ba023ca8df2a175d345e4 Mon Sep 17 00:00:00 2001 From: James Phillips Date: Wed, 29 Nov 2017 17:33:57 -0800 Subject: [PATCH 14/14] Creates a registration mechanism for snapshot and restore. --- agent/consul/fsm/fsm.go | 99 +------ agent/consul/fsm/fsm_test.go | 317 ---------------------- agent/consul/fsm/snapshot.go | 284 +++----------------- agent/consul/fsm/snapshot_oss.go | 365 ++++++++++++++++++++++++++ agent/consul/fsm/snapshot_oss_test.go | 326 +++++++++++++++++++++++ 5 files changed, 734 insertions(+), 657 deletions(-) create mode 100644 agent/consul/fsm/snapshot_oss.go create mode 100644 agent/consul/fsm/snapshot_oss_test.go diff --git a/agent/consul/fsm/fsm.go b/agent/consul/fsm/fsm.go index bc72744df4..87824b8723 100644 --- a/agent/consul/fsm/fsm.go +++ b/agent/consul/fsm/fsm.go @@ -107,7 +107,7 @@ func (c *FSM) Apply(log *raft.Log) interface{} { } // Apply based on the dispatch table, if possible. - if fn, ok := c.apply[msgType]; ok { + if fn := c.apply[msgType]; fn != nil { return fn(buf[1:], log.Index) } @@ -164,102 +164,15 @@ func (c *FSM) Restore(old io.ReadCloser) error { } // Decode - switch structs.MessageType(msgType[0]) { - case structs.RegisterRequestType: - var req structs.RegisterRequest - if err := dec.Decode(&req); err != nil { + msg := structs.MessageType(msgType[0]) + if fn := restorers[msg]; fn != nil { + if err := fn(&header, restore, dec); err != nil { return err } - if err := restore.Registration(header.LastIndex, &req); err != nil { - return err - } - - case structs.KVSRequestType: - var req structs.DirEntry - if err := dec.Decode(&req); err != nil { - return err - } - if err := restore.KVS(&req); err != nil { - return err - } - - case structs.TombstoneRequestType: - var req structs.DirEntry - if err := dec.Decode(&req); err != nil { - return err - } - - // For historical reasons, these are serialized in the - // snapshots as KV entries. We want to keep the snapshot - // format compatible with pre-0.6 versions for now. - stone := &state.Tombstone{ - Key: req.Key, - Index: req.ModifyIndex, - } - if err := restore.Tombstone(stone); err != nil { - return err - } - - case structs.SessionRequestType: - var req structs.Session - if err := dec.Decode(&req); err != nil { - return err - } - if err := restore.Session(&req); err != nil { - return err - } - - case structs.ACLRequestType: - var req structs.ACL - if err := dec.Decode(&req); err != nil { - return err - } - if err := restore.ACL(&req); err != nil { - return err - } - - case structs.ACLBootstrapRequestType: - var req structs.ACLBootstrap - if err := dec.Decode(&req); err != nil { - return err - } - if err := restore.ACLBootstrap(&req); err != nil { - return err - } - - case structs.CoordinateBatchUpdateType: - var req structs.Coordinates - if err := dec.Decode(&req); err != nil { - return err - - } - if err := restore.Coordinates(header.LastIndex, req); err != nil { - return err - } - - case structs.PreparedQueryRequestType: - var req structs.PreparedQuery - if err := dec.Decode(&req); err != nil { - return err - } - if err := restore.PreparedQuery(&req); err != nil { - return err - } - - case structs.AutopilotRequestType: - var req structs.AutopilotConfig - if err := dec.Decode(&req); err != nil { - return err - } - if err := restore.Autopilot(&req); err != nil { - return err - } - - default: - return fmt.Errorf("Unrecognized msg type: %v", msgType) + } else { + return fmt.Errorf("Unrecognized msg type %d", msg) } } - restore.Commit() // External code might be calling State(), so we need to synchronize diff --git a/agent/consul/fsm/fsm_test.go b/agent/consul/fsm/fsm_test.go index 70a91d6798..de763abfba 100644 --- a/agent/consul/fsm/fsm_test.go +++ b/agent/consul/fsm/fsm_test.go @@ -3,16 +3,10 @@ package fsm import ( "bytes" "os" - "reflect" "testing" - "time" - "github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/structs" - "github.com/hashicorp/consul/api" - "github.com/hashicorp/consul/lib" "github.com/hashicorp/raft" - "github.com/pascaldekloe/goe/verify" ) type MockSink struct { @@ -42,317 +36,6 @@ func makeLog(buf []byte) *raft.Log { } } -func TestFSM_SnapshotRestore(t *testing.T) { - t.Parallel() - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - // Add some state - fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) - fsm.state.EnsureNode(2, &structs.Node{Node: "baz", Address: "127.0.0.2", TaggedAddresses: map[string]string{"hello": "1.2.3.4"}}) - fsm.state.EnsureService(3, "foo", &structs.NodeService{ID: "web", Service: "web", Tags: nil, Address: "127.0.0.1", Port: 80}) - fsm.state.EnsureService(4, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"primary"}, Address: "127.0.0.1", Port: 5000}) - fsm.state.EnsureService(5, "baz", &structs.NodeService{ID: "web", Service: "web", Tags: nil, Address: "127.0.0.2", Port: 80}) - fsm.state.EnsureService(6, "baz", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"secondary"}, Address: "127.0.0.2", Port: 5000}) - fsm.state.EnsureCheck(7, &structs.HealthCheck{ - Node: "foo", - CheckID: "web", - Name: "web connectivity", - Status: api.HealthPassing, - ServiceID: "web", - }) - fsm.state.KVSSet(8, &structs.DirEntry{ - Key: "/test", - Value: []byte("foo"), - }) - session := &structs.Session{ID: generateUUID(), Node: "foo"} - fsm.state.SessionCreate(9, session) - acl := &structs.ACL{ID: generateUUID(), Name: "User Token"} - fsm.state.ACLSet(10, acl) - if _, err := fsm.state.ACLBootstrapInit(10); err != nil { - t.Fatalf("err: %v", err) - } - - fsm.state.KVSSet(11, &structs.DirEntry{ - Key: "/remove", - Value: []byte("foo"), - }) - fsm.state.KVSDelete(12, "/remove") - idx, _, err := fsm.state.KVSList(nil, "/remove") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 12 { - t.Fatalf("bad index: %d", idx) - } - - updates := structs.Coordinates{ - &structs.Coordinate{ - Node: "baz", - Coord: generateRandomCoordinate(), - }, - &structs.Coordinate{ - Node: "foo", - Coord: generateRandomCoordinate(), - }, - } - if err := fsm.state.CoordinateBatchUpdate(13, updates); err != nil { - t.Fatalf("err: %s", err) - } - - query := structs.PreparedQuery{ - ID: generateUUID(), - Service: structs.ServiceQuery{ - Service: "web", - }, - RaftIndex: structs.RaftIndex{ - CreateIndex: 14, - ModifyIndex: 14, - }, - } - if err := fsm.state.PreparedQuerySet(14, &query); err != nil { - t.Fatalf("err: %s", err) - } - - autopilotConf := &structs.AutopilotConfig{ - CleanupDeadServers: true, - LastContactThreshold: 100 * time.Millisecond, - MaxTrailingLogs: 222, - } - if err := fsm.state.AutopilotSetConfig(15, autopilotConf); err != nil { - t.Fatalf("err: %s", err) - } - - // Snapshot - snap, err := fsm.Snapshot() - if err != nil { - t.Fatalf("err: %v", err) - } - defer snap.Release() - - // Persist - buf := bytes.NewBuffer(nil) - sink := &MockSink{buf, false} - if err := snap.Persist(sink); err != nil { - t.Fatalf("err: %v", err) - } - - // Try to restore on a new FSM - fsm2, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - - // Do a restore - if err := fsm2.Restore(sink); err != nil { - t.Fatalf("err: %v", err) - } - - // Verify the contents - _, nodes, err := fsm2.state.Nodes(nil) - if err != nil { - t.Fatalf("err: %s", err) - } - if len(nodes) != 2 { - t.Fatalf("bad: %v", nodes) - } - if nodes[0].Node != "baz" || - nodes[0].Address != "127.0.0.2" || - len(nodes[0].TaggedAddresses) != 1 || - nodes[0].TaggedAddresses["hello"] != "1.2.3.4" { - t.Fatalf("bad: %v", nodes[0]) - } - if nodes[1].Node != "foo" || - nodes[1].Address != "127.0.0.1" || - len(nodes[1].TaggedAddresses) != 0 { - t.Fatalf("bad: %v", nodes[1]) - } - - _, fooSrv, err := fsm2.state.NodeServices(nil, "foo") - if err != nil { - t.Fatalf("err: %s", err) - } - if len(fooSrv.Services) != 2 { - t.Fatalf("Bad: %v", fooSrv) - } - if !lib.StrContains(fooSrv.Services["db"].Tags, "primary") { - t.Fatalf("Bad: %v", fooSrv) - } - if fooSrv.Services["db"].Port != 5000 { - t.Fatalf("Bad: %v", fooSrv) - } - - _, checks, err := fsm2.state.NodeChecks(nil, "foo") - if err != nil { - t.Fatalf("err: %s", err) - } - if len(checks) != 1 { - t.Fatalf("Bad: %v", checks) - } - - // Verify key is set - _, d, err := fsm2.state.KVSGet(nil, "/test") - if err != nil { - t.Fatalf("err: %v", err) - } - if string(d.Value) != "foo" { - t.Fatalf("bad: %v", d) - } - - // Verify session is restored - idx, s, err := fsm2.state.SessionGet(nil, session.ID) - if err != nil { - t.Fatalf("err: %v", err) - } - if s.Node != "foo" { - t.Fatalf("bad: %v", s) - } - if idx <= 1 { - t.Fatalf("bad index: %d", idx) - } - - // Verify ACL is restored - _, a, err := fsm2.state.ACLGet(nil, acl.ID) - if err != nil { - t.Fatalf("err: %v", err) - } - if a.Name != "User Token" { - t.Fatalf("bad: %v", a) - } - if a.ModifyIndex <= 1 { - t.Fatalf("bad index: %d", idx) - } - gotB, err := fsm2.state.ACLGetBootstrap() - if err != nil { - t.Fatalf("err: %v", err) - } - wantB := &structs.ACLBootstrap{ - AllowBootstrap: true, - RaftIndex: structs.RaftIndex{ - CreateIndex: 10, - ModifyIndex: 10, - }, - } - verify.Values(t, "", gotB, wantB) - - // Verify tombstones are restored - func() { - snap := fsm2.state.Snapshot() - defer snap.Close() - stones, err := snap.Tombstones() - if err != nil { - t.Fatalf("err: %s", err) - } - stone := stones.Next().(*state.Tombstone) - if stone == nil { - t.Fatalf("missing tombstone") - } - if stone.Key != "/remove" || stone.Index != 12 { - t.Fatalf("bad: %v", stone) - } - if stones.Next() != nil { - t.Fatalf("unexpected extra tombstones") - } - }() - - // Verify coordinates are restored - _, coords, err := fsm2.state.Coordinates(nil) - if err != nil { - t.Fatalf("err: %s", err) - } - if !reflect.DeepEqual(coords, updates) { - t.Fatalf("bad: %#v", coords) - } - - // Verify queries are restored. - _, queries, err := fsm2.state.PreparedQueryList(nil) - if err != nil { - t.Fatalf("err: %s", err) - } - if len(queries) != 1 { - t.Fatalf("bad: %#v", queries) - } - if !reflect.DeepEqual(queries[0], &query) { - t.Fatalf("bad: %#v", queries[0]) - } - - // Verify autopilot config is restored. - _, restoredConf, err := fsm2.state.AutopilotConfig() - if err != nil { - t.Fatalf("err: %s", err) - } - if !reflect.DeepEqual(restoredConf, autopilotConf) { - t.Fatalf("bad: %#v, %#v", restoredConf, autopilotConf) - } - - // Snapshot - snap, err = fsm2.Snapshot() - if err != nil { - t.Fatalf("err: %v", err) - } - defer snap.Release() - - // Persist - buf = bytes.NewBuffer(nil) - sink = &MockSink{buf, false} - if err := snap.Persist(sink); err != nil { - t.Fatalf("err: %v", err) - } - - // Try to restore on the old FSM and make sure it abandons the old state - // store. - abandonCh := fsm.state.AbandonCh() - if err := fsm.Restore(sink); err != nil { - t.Fatalf("err: %v", err) - } - select { - case <-abandonCh: - default: - t.Fatalf("bad") - } -} - -func TestFSM_BadRestore(t *testing.T) { - t.Parallel() - // Create an FSM with some state. - fsm, err := New(nil, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) - abandonCh := fsm.state.AbandonCh() - - // Do a bad restore. - buf := bytes.NewBuffer([]byte("bad snapshot")) - sink := &MockSink{buf, false} - if err := fsm.Restore(sink); err == nil { - t.Fatalf("err: %v", err) - } - - // Verify the contents didn't get corrupted. - _, nodes, err := fsm.state.Nodes(nil) - if err != nil { - t.Fatalf("err: %s", err) - } - if len(nodes) != 1 { - t.Fatalf("bad: %v", nodes) - } - if nodes[0].Node != "foo" || - nodes[0].Address != "127.0.0.1" || - len(nodes[0].TaggedAddresses) != 0 { - t.Fatalf("bad: %v", nodes[0]) - } - - // Verify the old state store didn't get abandoned. - select { - case <-abandonCh: - t.Fatalf("bad") - default: - } -} - func TestFSM_IgnoreUnknown(t *testing.T) { t.Parallel() fsm, err := New(nil, os.Stderr) diff --git a/agent/consul/fsm/snapshot.go b/agent/consul/fsm/snapshot.go index 3735fea89c..3721f07569 100644 --- a/agent/consul/fsm/snapshot.go +++ b/agent/consul/fsm/snapshot.go @@ -1,6 +1,7 @@ package fsm import ( + "fmt" "time" "github.com/armon/go-metrics" @@ -24,272 +25,61 @@ type snapshotHeader struct { LastIndex uint64 } +// persister is a function used to help snapshot the FSM state. +type persister func(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) error + +// persisters is a list of snapshot functions. +var persisters []persister + +// registerPersister adds a new helper. This should be called at package +// init() time. +func registerPersister(fn persister) { + persisters = append(persisters, fn) +} + +// restorer is a function used to load back a snapshot of the FSM state. +type restorer func(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error + +// restorers is a map of restore functions by message type. +var restorers map[structs.MessageType]restorer + +// registerRestorer adds a new helper. This should be called at package +// init() time. +func registerRestorer(msg structs.MessageType, fn restorer) { + if restorers == nil { + restorers = make(map[structs.MessageType]restorer) + } + if restorers[msg] != nil { + panic(fmt.Errorf("Message %d is already registered", msg)) + } + restorers[msg] = fn +} + +// Persist saves the FSM snapshot out to the given sink. func (s *snapshot) Persist(sink raft.SnapshotSink) error { defer metrics.MeasureSince([]string{"consul", "fsm", "persist"}, time.Now()) defer metrics.MeasureSince([]string{"fsm", "persist"}, time.Now()) - // Register the nodes - encoder := codec.NewEncoder(sink, msgpackHandle) - // Write the header header := snapshotHeader{ LastIndex: s.state.LastIndex(), } + encoder := codec.NewEncoder(sink, msgpackHandle) if err := encoder.Encode(&header); err != nil { sink.Cancel() return err } - if err := s.persistNodes(sink, encoder); err != nil { - sink.Cancel() - return err - } - - if err := s.persistSessions(sink, encoder); err != nil { - sink.Cancel() - return err - } - - if err := s.persistACLs(sink, encoder); err != nil { - sink.Cancel() - return err - } - - if err := s.persistKVs(sink, encoder); err != nil { - sink.Cancel() - return err - } - - if err := s.persistTombstones(sink, encoder); err != nil { - sink.Cancel() - return err - } - - if err := s.persistPreparedQueries(sink, encoder); err != nil { - sink.Cancel() - return err - } - - if err := s.persistAutopilot(sink, encoder); err != nil { - sink.Cancel() - return err - } - - return nil -} - -func (s *snapshot) persistNodes(sink raft.SnapshotSink, - encoder *codec.Encoder) error { - - // Get all the nodes - nodes, err := s.state.Nodes() - if err != nil { - return err - } - - // Register each node - for node := nodes.Next(); node != nil; node = nodes.Next() { - n := node.(*structs.Node) - req := structs.RegisterRequest{ - Node: n.Node, - Address: n.Address, - TaggedAddresses: n.TaggedAddresses, - } - - // Register the node itself - if _, err := sink.Write([]byte{byte(structs.RegisterRequestType)}); err != nil { - return err - } - if err := encoder.Encode(&req); err != nil { - return err - } - - // Register each service this node has - services, err := s.state.Services(n.Node) - if err != nil { - return err - } - for service := services.Next(); service != nil; service = services.Next() { - if _, err := sink.Write([]byte{byte(structs.RegisterRequestType)}); err != nil { - return err - } - req.Service = service.(*structs.ServiceNode).ToNodeService() - if err := encoder.Encode(&req); err != nil { - return err - } - } - - // Register each check this node has - req.Service = nil - checks, err := s.state.Checks(n.Node) - if err != nil { - return err - } - for check := checks.Next(); check != nil; check = checks.Next() { - if _, err := sink.Write([]byte{byte(structs.RegisterRequestType)}); err != nil { - return err - } - req.Check = check.(*structs.HealthCheck) - if err := encoder.Encode(&req); err != nil { - return err - } - } - } - - // Save the coordinates separately since they are not part of the - // register request interface. To avoid copying them out, we turn - // them into batches with a single coordinate each. - coords, err := s.state.Coordinates() - if err != nil { - return err - } - for coord := coords.Next(); coord != nil; coord = coords.Next() { - if _, err := sink.Write([]byte{byte(structs.CoordinateBatchUpdateType)}); err != nil { - return err - } - updates := structs.Coordinates{coord.(*structs.Coordinate)} - if err := encoder.Encode(&updates); err != nil { + // Run all the persisters to write the FSM state. + for _, fn := range persisters { + if err := fn(s, sink, encoder); err != nil { + sink.Cancel() return err } } return nil } -func (s *snapshot) persistSessions(sink raft.SnapshotSink, - encoder *codec.Encoder) error { - sessions, err := s.state.Sessions() - if err != nil { - return err - } - - for session := sessions.Next(); session != nil; session = sessions.Next() { - if _, err := sink.Write([]byte{byte(structs.SessionRequestType)}); err != nil { - return err - } - if err := encoder.Encode(session.(*structs.Session)); err != nil { - return err - } - } - return nil -} - -func (s *snapshot) persistACLs(sink raft.SnapshotSink, - encoder *codec.Encoder) error { - acls, err := s.state.ACLs() - if err != nil { - return err - } - - for acl := acls.Next(); acl != nil; acl = acls.Next() { - if _, err := sink.Write([]byte{byte(structs.ACLRequestType)}); err != nil { - return err - } - if err := encoder.Encode(acl.(*structs.ACL)); err != nil { - return err - } - } - - bs, err := s.state.ACLBootstrap() - if err != nil { - return err - } - if bs != nil { - if _, err := sink.Write([]byte{byte(structs.ACLBootstrapRequestType)}); err != nil { - return err - } - if err := encoder.Encode(bs); err != nil { - return err - } - } - - return nil -} - -func (s *snapshot) persistKVs(sink raft.SnapshotSink, - encoder *codec.Encoder) error { - entries, err := s.state.KVs() - if err != nil { - return err - } - - for entry := entries.Next(); entry != nil; entry = entries.Next() { - if _, err := sink.Write([]byte{byte(structs.KVSRequestType)}); err != nil { - return err - } - if err := encoder.Encode(entry.(*structs.DirEntry)); err != nil { - return err - } - } - return nil -} - -func (s *snapshot) persistTombstones(sink raft.SnapshotSink, - encoder *codec.Encoder) error { - stones, err := s.state.Tombstones() - if err != nil { - return err - } - - for stone := stones.Next(); stone != nil; stone = stones.Next() { - if _, err := sink.Write([]byte{byte(structs.TombstoneRequestType)}); err != nil { - return err - } - - // For historical reasons, these are serialized in the snapshots - // as KV entries. We want to keep the snapshot format compatible - // with pre-0.6 versions for now. - s := stone.(*state.Tombstone) - fake := &structs.DirEntry{ - Key: s.Key, - RaftIndex: structs.RaftIndex{ - ModifyIndex: s.Index, - }, - } - if err := encoder.Encode(fake); err != nil { - return err - } - } - return nil -} - -func (s *snapshot) persistPreparedQueries(sink raft.SnapshotSink, - encoder *codec.Encoder) error { - queries, err := s.state.PreparedQueries() - if err != nil { - return err - } - - for _, query := range queries { - if _, err := sink.Write([]byte{byte(structs.PreparedQueryRequestType)}); err != nil { - return err - } - if err := encoder.Encode(query); err != nil { - return err - } - } - return nil -} - -func (s *snapshot) persistAutopilot(sink raft.SnapshotSink, - encoder *codec.Encoder) error { - autopilot, err := s.state.Autopilot() - if err != nil { - return err - } - if autopilot == nil { - return nil - } - - if _, err := sink.Write([]byte{byte(structs.AutopilotRequestType)}); err != nil { - return err - } - if err := encoder.Encode(autopilot); err != nil { - return err - } - - return nil -} - func (s *snapshot) Release() { s.state.Close() } diff --git a/agent/consul/fsm/snapshot_oss.go b/agent/consul/fsm/snapshot_oss.go new file mode 100644 index 0000000000..4bf7537222 --- /dev/null +++ b/agent/consul/fsm/snapshot_oss.go @@ -0,0 +1,365 @@ +package fsm + +import ( + "github.com/hashicorp/consul/agent/consul/state" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/go-msgpack/codec" + "github.com/hashicorp/raft" +) + +func init() { + registerPersister(persistOSS) + + registerRestorer(structs.RegisterRequestType, restoreRegistration) + registerRestorer(structs.KVSRequestType, restoreKV) + registerRestorer(structs.TombstoneRequestType, restoreTombstone) + registerRestorer(structs.SessionRequestType, restoreSession) + registerRestorer(structs.ACLRequestType, restoreACL) + registerRestorer(structs.ACLBootstrapRequestType, restoreACLBootstrap) + registerRestorer(structs.CoordinateBatchUpdateType, restoreCoordinates) + registerRestorer(structs.PreparedQueryRequestType, restorePreparedQuery) + registerRestorer(structs.AutopilotRequestType, restoreAutopilot) +} + +func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) error { + if err := s.persistNodes(sink, encoder); err != nil { + return err + } + if err := s.persistSessions(sink, encoder); err != nil { + return err + } + if err := s.persistACLs(sink, encoder); err != nil { + return err + } + if err := s.persistKVs(sink, encoder); err != nil { + return err + } + if err := s.persistTombstones(sink, encoder); err != nil { + return err + } + if err := s.persistPreparedQueries(sink, encoder); err != nil { + return err + } + if err := s.persistAutopilot(sink, encoder); err != nil { + return err + } + return nil +} + +func (s *snapshot) persistNodes(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + + // Get all the nodes + nodes, err := s.state.Nodes() + if err != nil { + return err + } + + // Register each node + for node := nodes.Next(); node != nil; node = nodes.Next() { + n := node.(*structs.Node) + req := structs.RegisterRequest{ + Node: n.Node, + Address: n.Address, + TaggedAddresses: n.TaggedAddresses, + } + + // Register the node itself + if _, err := sink.Write([]byte{byte(structs.RegisterRequestType)}); err != nil { + return err + } + if err := encoder.Encode(&req); err != nil { + return err + } + + // Register each service this node has + services, err := s.state.Services(n.Node) + if err != nil { + return err + } + for service := services.Next(); service != nil; service = services.Next() { + if _, err := sink.Write([]byte{byte(structs.RegisterRequestType)}); err != nil { + return err + } + req.Service = service.(*structs.ServiceNode).ToNodeService() + if err := encoder.Encode(&req); err != nil { + return err + } + } + + // Register each check this node has + req.Service = nil + checks, err := s.state.Checks(n.Node) + if err != nil { + return err + } + for check := checks.Next(); check != nil; check = checks.Next() { + if _, err := sink.Write([]byte{byte(structs.RegisterRequestType)}); err != nil { + return err + } + req.Check = check.(*structs.HealthCheck) + if err := encoder.Encode(&req); err != nil { + return err + } + } + } + + // Save the coordinates separately since they are not part of the + // register request interface. To avoid copying them out, we turn + // them into batches with a single coordinate each. + coords, err := s.state.Coordinates() + if err != nil { + return err + } + for coord := coords.Next(); coord != nil; coord = coords.Next() { + if _, err := sink.Write([]byte{byte(structs.CoordinateBatchUpdateType)}); err != nil { + return err + } + updates := structs.Coordinates{coord.(*structs.Coordinate)} + if err := encoder.Encode(&updates); err != nil { + return err + } + } + return nil +} + +func (s *snapshot) persistSessions(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + sessions, err := s.state.Sessions() + if err != nil { + return err + } + + for session := sessions.Next(); session != nil; session = sessions.Next() { + if _, err := sink.Write([]byte{byte(structs.SessionRequestType)}); err != nil { + return err + } + if err := encoder.Encode(session.(*structs.Session)); err != nil { + return err + } + } + return nil +} + +func (s *snapshot) persistACLs(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + acls, err := s.state.ACLs() + if err != nil { + return err + } + + for acl := acls.Next(); acl != nil; acl = acls.Next() { + if _, err := sink.Write([]byte{byte(structs.ACLRequestType)}); err != nil { + return err + } + if err := encoder.Encode(acl.(*structs.ACL)); err != nil { + return err + } + } + + bs, err := s.state.ACLBootstrap() + if err != nil { + return err + } + if bs != nil { + if _, err := sink.Write([]byte{byte(structs.ACLBootstrapRequestType)}); err != nil { + return err + } + if err := encoder.Encode(bs); err != nil { + return err + } + } + + return nil +} + +func (s *snapshot) persistKVs(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + entries, err := s.state.KVs() + if err != nil { + return err + } + + for entry := entries.Next(); entry != nil; entry = entries.Next() { + if _, err := sink.Write([]byte{byte(structs.KVSRequestType)}); err != nil { + return err + } + if err := encoder.Encode(entry.(*structs.DirEntry)); err != nil { + return err + } + } + return nil +} + +func (s *snapshot) persistTombstones(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + stones, err := s.state.Tombstones() + if err != nil { + return err + } + + for stone := stones.Next(); stone != nil; stone = stones.Next() { + if _, err := sink.Write([]byte{byte(structs.TombstoneRequestType)}); err != nil { + return err + } + + // For historical reasons, these are serialized in the snapshots + // as KV entries. We want to keep the snapshot format compatible + // with pre-0.6 versions for now. + s := stone.(*state.Tombstone) + fake := &structs.DirEntry{ + Key: s.Key, + RaftIndex: structs.RaftIndex{ + ModifyIndex: s.Index, + }, + } + if err := encoder.Encode(fake); err != nil { + return err + } + } + return nil +} + +func (s *snapshot) persistPreparedQueries(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + queries, err := s.state.PreparedQueries() + if err != nil { + return err + } + + for _, query := range queries { + if _, err := sink.Write([]byte{byte(structs.PreparedQueryRequestType)}); err != nil { + return err + } + if err := encoder.Encode(query); err != nil { + return err + } + } + return nil +} + +func (s *snapshot) persistAutopilot(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + autopilot, err := s.state.Autopilot() + if err != nil { + return err + } + if autopilot == nil { + return nil + } + + if _, err := sink.Write([]byte{byte(structs.AutopilotRequestType)}); err != nil { + return err + } + if err := encoder.Encode(autopilot); err != nil { + return err + } + return nil +} + +func restoreRegistration(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { + var req structs.RegisterRequest + if err := decoder.Decode(&req); err != nil { + return err + } + if err := restore.Registration(header.LastIndex, &req); err != nil { + return err + } + return nil +} + +func restoreKV(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { + var req structs.DirEntry + if err := decoder.Decode(&req); err != nil { + return err + } + if err := restore.KVS(&req); err != nil { + return err + } + return nil +} + +func restoreTombstone(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { + var req structs.DirEntry + if err := decoder.Decode(&req); err != nil { + return err + } + + // For historical reasons, these are serialized in the + // snapshots as KV entries. We want to keep the snapshot + // format compatible with pre-0.6 versions for now. + stone := &state.Tombstone{ + Key: req.Key, + Index: req.ModifyIndex, + } + if err := restore.Tombstone(stone); err != nil { + return err + } + return nil +} + +func restoreSession(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { + var req structs.Session + if err := decoder.Decode(&req); err != nil { + return err + } + if err := restore.Session(&req); err != nil { + return err + } + return nil +} + +func restoreACL(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { + var req structs.ACL + if err := decoder.Decode(&req); err != nil { + return err + } + if err := restore.ACL(&req); err != nil { + return err + } + return nil +} + +func restoreACLBootstrap(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { + var req structs.ACLBootstrap + if err := decoder.Decode(&req); err != nil { + return err + } + if err := restore.ACLBootstrap(&req); err != nil { + return err + } + return nil +} + +func restoreCoordinates(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { + var req structs.Coordinates + if err := decoder.Decode(&req); err != nil { + return err + } + if err := restore.Coordinates(header.LastIndex, req); err != nil { + return err + } + return nil +} + +func restorePreparedQuery(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { + var req structs.PreparedQuery + if err := decoder.Decode(&req); err != nil { + return err + } + if err := restore.PreparedQuery(&req); err != nil { + return err + } + return nil +} + +func restoreAutopilot(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { + var req structs.AutopilotConfig + if err := decoder.Decode(&req); err != nil { + return err + } + if err := restore.Autopilot(&req); err != nil { + return err + } + return nil +} diff --git a/agent/consul/fsm/snapshot_oss_test.go b/agent/consul/fsm/snapshot_oss_test.go new file mode 100644 index 0000000000..ff2419f693 --- /dev/null +++ b/agent/consul/fsm/snapshot_oss_test.go @@ -0,0 +1,326 @@ +package fsm + +import ( + "bytes" + "os" + "reflect" + "testing" + "time" + + "github.com/hashicorp/consul/agent/consul/state" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/lib" + "github.com/pascaldekloe/goe/verify" +) + +func TestFSM_SnapshotRestore_OSS(t *testing.T) { + t.Parallel() + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Add some state + fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) + fsm.state.EnsureNode(2, &structs.Node{Node: "baz", Address: "127.0.0.2", TaggedAddresses: map[string]string{"hello": "1.2.3.4"}}) + fsm.state.EnsureService(3, "foo", &structs.NodeService{ID: "web", Service: "web", Tags: nil, Address: "127.0.0.1", Port: 80}) + fsm.state.EnsureService(4, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"primary"}, Address: "127.0.0.1", Port: 5000}) + fsm.state.EnsureService(5, "baz", &structs.NodeService{ID: "web", Service: "web", Tags: nil, Address: "127.0.0.2", Port: 80}) + fsm.state.EnsureService(6, "baz", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"secondary"}, Address: "127.0.0.2", Port: 5000}) + fsm.state.EnsureCheck(7, &structs.HealthCheck{ + Node: "foo", + CheckID: "web", + Name: "web connectivity", + Status: api.HealthPassing, + ServiceID: "web", + }) + fsm.state.KVSSet(8, &structs.DirEntry{ + Key: "/test", + Value: []byte("foo"), + }) + session := &structs.Session{ID: generateUUID(), Node: "foo"} + fsm.state.SessionCreate(9, session) + acl := &structs.ACL{ID: generateUUID(), Name: "User Token"} + fsm.state.ACLSet(10, acl) + if _, err := fsm.state.ACLBootstrapInit(10); err != nil { + t.Fatalf("err: %v", err) + } + + fsm.state.KVSSet(11, &structs.DirEntry{ + Key: "/remove", + Value: []byte("foo"), + }) + fsm.state.KVSDelete(12, "/remove") + idx, _, err := fsm.state.KVSList(nil, "/remove") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 12 { + t.Fatalf("bad index: %d", idx) + } + + updates := structs.Coordinates{ + &structs.Coordinate{ + Node: "baz", + Coord: generateRandomCoordinate(), + }, + &structs.Coordinate{ + Node: "foo", + Coord: generateRandomCoordinate(), + }, + } + if err := fsm.state.CoordinateBatchUpdate(13, updates); err != nil { + t.Fatalf("err: %s", err) + } + + query := structs.PreparedQuery{ + ID: generateUUID(), + Service: structs.ServiceQuery{ + Service: "web", + }, + RaftIndex: structs.RaftIndex{ + CreateIndex: 14, + ModifyIndex: 14, + }, + } + if err := fsm.state.PreparedQuerySet(14, &query); err != nil { + t.Fatalf("err: %s", err) + } + + autopilotConf := &structs.AutopilotConfig{ + CleanupDeadServers: true, + LastContactThreshold: 100 * time.Millisecond, + MaxTrailingLogs: 222, + } + if err := fsm.state.AutopilotSetConfig(15, autopilotConf); err != nil { + t.Fatalf("err: %s", err) + } + + // Snapshot + snap, err := fsm.Snapshot() + if err != nil { + t.Fatalf("err: %v", err) + } + defer snap.Release() + + // Persist + buf := bytes.NewBuffer(nil) + sink := &MockSink{buf, false} + if err := snap.Persist(sink); err != nil { + t.Fatalf("err: %v", err) + } + + // Try to restore on a new FSM + fsm2, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Do a restore + if err := fsm2.Restore(sink); err != nil { + t.Fatalf("err: %v", err) + } + + // Verify the contents + _, nodes, err := fsm2.state.Nodes(nil) + if err != nil { + t.Fatalf("err: %s", err) + } + if len(nodes) != 2 { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].Node != "baz" || + nodes[0].Address != "127.0.0.2" || + len(nodes[0].TaggedAddresses) != 1 || + nodes[0].TaggedAddresses["hello"] != "1.2.3.4" { + t.Fatalf("bad: %v", nodes[0]) + } + if nodes[1].Node != "foo" || + nodes[1].Address != "127.0.0.1" || + len(nodes[1].TaggedAddresses) != 0 { + t.Fatalf("bad: %v", nodes[1]) + } + + _, fooSrv, err := fsm2.state.NodeServices(nil, "foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if len(fooSrv.Services) != 2 { + t.Fatalf("Bad: %v", fooSrv) + } + if !lib.StrContains(fooSrv.Services["db"].Tags, "primary") { + t.Fatalf("Bad: %v", fooSrv) + } + if fooSrv.Services["db"].Port != 5000 { + t.Fatalf("Bad: %v", fooSrv) + } + + _, checks, err := fsm2.state.NodeChecks(nil, "foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if len(checks) != 1 { + t.Fatalf("Bad: %v", checks) + } + + // Verify key is set + _, d, err := fsm2.state.KVSGet(nil, "/test") + if err != nil { + t.Fatalf("err: %v", err) + } + if string(d.Value) != "foo" { + t.Fatalf("bad: %v", d) + } + + // Verify session is restored + idx, s, err := fsm2.state.SessionGet(nil, session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if s.Node != "foo" { + t.Fatalf("bad: %v", s) + } + if idx <= 1 { + t.Fatalf("bad index: %d", idx) + } + + // Verify ACL is restored + _, a, err := fsm2.state.ACLGet(nil, acl.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if a.Name != "User Token" { + t.Fatalf("bad: %v", a) + } + if a.ModifyIndex <= 1 { + t.Fatalf("bad index: %d", idx) + } + gotB, err := fsm2.state.ACLGetBootstrap() + if err != nil { + t.Fatalf("err: %v", err) + } + wantB := &structs.ACLBootstrap{ + AllowBootstrap: true, + RaftIndex: structs.RaftIndex{ + CreateIndex: 10, + ModifyIndex: 10, + }, + } + verify.Values(t, "", gotB, wantB) + + // Verify tombstones are restored + func() { + snap := fsm2.state.Snapshot() + defer snap.Close() + stones, err := snap.Tombstones() + if err != nil { + t.Fatalf("err: %s", err) + } + stone := stones.Next().(*state.Tombstone) + if stone == nil { + t.Fatalf("missing tombstone") + } + if stone.Key != "/remove" || stone.Index != 12 { + t.Fatalf("bad: %v", stone) + } + if stones.Next() != nil { + t.Fatalf("unexpected extra tombstones") + } + }() + + // Verify coordinates are restored + _, coords, err := fsm2.state.Coordinates(nil) + if err != nil { + t.Fatalf("err: %s", err) + } + if !reflect.DeepEqual(coords, updates) { + t.Fatalf("bad: %#v", coords) + } + + // Verify queries are restored. + _, queries, err := fsm2.state.PreparedQueryList(nil) + if err != nil { + t.Fatalf("err: %s", err) + } + if len(queries) != 1 { + t.Fatalf("bad: %#v", queries) + } + if !reflect.DeepEqual(queries[0], &query) { + t.Fatalf("bad: %#v", queries[0]) + } + + // Verify autopilot config is restored. + _, restoredConf, err := fsm2.state.AutopilotConfig() + if err != nil { + t.Fatalf("err: %s", err) + } + if !reflect.DeepEqual(restoredConf, autopilotConf) { + t.Fatalf("bad: %#v, %#v", restoredConf, autopilotConf) + } + + // Snapshot + snap, err = fsm2.Snapshot() + if err != nil { + t.Fatalf("err: %v", err) + } + defer snap.Release() + + // Persist + buf = bytes.NewBuffer(nil) + sink = &MockSink{buf, false} + if err := snap.Persist(sink); err != nil { + t.Fatalf("err: %v", err) + } + + // Try to restore on the old FSM and make sure it abandons the old state + // store. + abandonCh := fsm.state.AbandonCh() + if err := fsm.Restore(sink); err != nil { + t.Fatalf("err: %v", err) + } + select { + case <-abandonCh: + default: + t.Fatalf("bad") + } +} + +func TestFSM_BadRestore_OSS(t *testing.T) { + t.Parallel() + // Create an FSM with some state. + fsm, err := New(nil, os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) + abandonCh := fsm.state.AbandonCh() + + // Do a bad restore. + buf := bytes.NewBuffer([]byte("bad snapshot")) + sink := &MockSink{buf, false} + if err := fsm.Restore(sink); err == nil { + t.Fatalf("err: %v", err) + } + + // Verify the contents didn't get corrupted. + _, nodes, err := fsm.state.Nodes(nil) + if err != nil { + t.Fatalf("err: %s", err) + } + if len(nodes) != 1 { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].Node != "foo" || + nodes[0].Address != "127.0.0.1" || + len(nodes[0].TaggedAddresses) != 0 { + t.Fatalf("bad: %v", nodes[0]) + } + + // Verify the old state store didn't get abandoned. + select { + case <-abandonCh: + t.Fatalf("bad") + default: + } +}