diff --git a/consul/acl_endpoint.go b/consul/acl_endpoint.go index 0cebc9e88a..f5ef90bd82 100644 --- a/consul/acl_endpoint.go +++ b/consul/acl_endpoint.go @@ -123,7 +123,7 @@ func (a *ACL) Get(args *structs.ACLSpecificRequest, state := a.srv.fsm.StateNew() return a.srv.blockingRPCNew(&args.QueryOptions, &reply.QueryMeta, - state.GetTableWatch("acls"), + state.GetQueryWatch("ACLGet"), func() error { acl, err := state.ACLGet(args.ACL) if acl != nil { @@ -194,7 +194,7 @@ func (a *ACL) List(args *structs.DCSpecificRequest, state := a.srv.fsm.StateNew() return a.srv.blockingRPCNew(&args.QueryOptions, &reply.QueryMeta, - state.GetTableWatch("acls"), + state.GetQueryWatch("ACLList"), func() error { var err error reply.Index, reply.ACLs, err = state.ACLList() diff --git a/consul/catalog_endpoint.go b/consul/catalog_endpoint.go index 865372964d..132734ff5e 100644 --- a/consul/catalog_endpoint.go +++ b/consul/catalog_endpoint.go @@ -6,7 +6,6 @@ import ( "time" "github.com/armon/go-metrics" - state_store "github.com/hashicorp/consul/consul/state" "github.com/hashicorp/consul/consul/structs" ) @@ -125,7 +124,7 @@ func (c *Catalog) ListNodes(args *structs.DCSpecificRequest, reply *structs.Inde return c.srv.blockingRPCNew( &args.QueryOptions, &reply.QueryMeta, - state.GetTableWatch("nodes"), + state.GetQueryWatch("Nodes"), func() error { index, nodes, err := state.Nodes() if err != nil { @@ -148,7 +147,7 @@ func (c *Catalog) ListServices(args *structs.DCSpecificRequest, reply *structs.I return c.srv.blockingRPCNew( &args.QueryOptions, &reply.QueryMeta, - state.GetTableWatch("services"), + state.GetQueryWatch("Services"), func() error { index, services, err := state.Services() if err != nil { @@ -176,9 +175,7 @@ func (c *Catalog) ServiceNodes(args *structs.ServiceSpecificRequest, reply *stru err := c.srv.blockingRPCNew( &args.QueryOptions, &reply.QueryMeta, - state_store.NewMultiWatch( - state.GetTableWatch("nodes"), - state.GetTableWatch("services")), + state.GetQueryWatch("ServiceNodes"), func() error { var index uint64 var services structs.ServiceNodes @@ -224,9 +221,7 @@ func (c *Catalog) NodeServices(args *structs.NodeSpecificRequest, reply *structs return c.srv.blockingRPCNew( &args.QueryOptions, &reply.QueryMeta, - state_store.NewMultiWatch( - state.GetTableWatch("nodes"), - state.GetTableWatch("services")), + state.GetQueryWatch("NodeServices"), func() error { index, services, err := state.NodeServices(args.Node) if err != nil { diff --git a/consul/health_endpoint.go b/consul/health_endpoint.go index 8f76f7000a..7085ad720d 100644 --- a/consul/health_endpoint.go +++ b/consul/health_endpoint.go @@ -3,7 +3,6 @@ package consul import ( "fmt" "github.com/armon/go-metrics" - state_store "github.com/hashicorp/consul/consul/state" "github.com/hashicorp/consul/consul/structs" ) @@ -24,7 +23,7 @@ func (h *Health) ChecksInState(args *structs.ChecksInStateRequest, return h.srv.blockingRPCNew( &args.QueryOptions, &reply.QueryMeta, - state.GetTableWatch("checks"), + state.GetQueryWatch("ChecksInState"), func() error { index, checks, err := state.ChecksInState(args.State) if err != nil { @@ -47,7 +46,7 @@ func (h *Health) NodeChecks(args *structs.NodeSpecificRequest, return h.srv.blockingRPCNew( &args.QueryOptions, &reply.QueryMeta, - state.GetTableWatch("checks"), + state.GetQueryWatch("NodeChecks"), func() error { index, checks, err := state.NodeChecks(args.Node) if err != nil { @@ -76,7 +75,7 @@ func (h *Health) ServiceChecks(args *structs.ServiceSpecificRequest, return h.srv.blockingRPCNew( &args.QueryOptions, &reply.QueryMeta, - state.GetTableWatch("checks"), + state.GetQueryWatch("ServiceChecks"), func() error { index, checks, err := state.ServiceChecks(args.ServiceName) if err != nil { @@ -103,10 +102,7 @@ func (h *Health) ServiceNodes(args *structs.ServiceSpecificRequest, reply *struc err := h.srv.blockingRPCNew( &args.QueryOptions, &reply.QueryMeta, - state_store.NewMultiWatch( - state.GetTableWatch("nodes"), - state.GetTableWatch("services"), - state.GetTableWatch("checks")), + state.GetQueryWatch("CheckServiceNodes"), func() error { var index uint64 var nodes structs.CheckServiceNodes diff --git a/consul/session_endpoint.go b/consul/session_endpoint.go index 8675446101..39b923ed49 100644 --- a/consul/session_endpoint.go +++ b/consul/session_endpoint.go @@ -112,7 +112,7 @@ func (s *Session) Get(args *structs.SessionSpecificRequest, return s.srv.blockingRPCNew( &args.QueryOptions, &reply.QueryMeta, - state.GetTableWatch("sessions"), + state.GetQueryWatch("SessionGet"), func() error { index, session, err := state.SessionGet(args.Session) if err != nil { @@ -141,7 +141,7 @@ func (s *Session) List(args *structs.DCSpecificRequest, return s.srv.blockingRPCNew( &args.QueryOptions, &reply.QueryMeta, - state.GetTableWatch("sessions"), + state.GetQueryWatch("SessionList"), func() error { index, sessions, err := state.SessionList() if err != nil { @@ -165,7 +165,7 @@ func (s *Session) NodeSessions(args *structs.NodeSpecificRequest, return s.srv.blockingRPCNew( &args.QueryOptions, &reply.QueryMeta, - state.GetTableWatch("sessions"), + state.GetQueryWatch("NodeSessions"), func() error { index, sessions, err := state.NodeSessions(args.Node) if err != nil { diff --git a/consul/state/state_store.go b/consul/state/state_store.go index 882a469783..b9d871a90e 100644 --- a/consul/state/state_store.go +++ b/consul/state/state_store.go @@ -282,8 +282,9 @@ func (s *StateStore) ReapTombstones(index uint64) error { return nil } -// GetTableWatch returns a watch for the given table. -func (s *StateStore) GetTableWatch(table string) Watch { +// getTableWatch returns a full table watch for the given table. This will panic +// if the table doesn't have a full table watch. +func (s *StateStore) getTableWatch(table string) Watch { if watch, ok := s.tableWatches[table]; ok { return watch } @@ -291,6 +292,32 @@ func (s *StateStore) GetTableWatch(table string) Watch { panic(fmt.Sprintf("Unknown watch for table %#s", table)) } +// GetQueryWatch returns a watch for the given query method. This is +// used for all methods except for KV; you should call GetKVSWatch instead. +func (s *StateStore) GetQueryWatch(method string) Watch { + switch method { + case "GetNode", "Nodes": + return s.getTableWatch("nodes") + case "Services": + return s.getTableWatch("services") + case "ServiceNodes", "NodeServices": + return NewMultiWatch(s.getTableWatch("nodes"), + s.getTableWatch("services")) + case "NodeChecks", "ServiceChecks", "ChecksInState": + return s.getTableWatch("checks") + case "CheckServiceNodes", "NodeInfo", "NodeDump": + return NewMultiWatch(s.getTableWatch("nodes"), + s.getTableWatch("services"), + s.getTableWatch("checks")) + case "SessionGet", "SessionList", "NodeSessions": + return s.getTableWatch("sessions") + case "ACLGet", "ACLList": + return s.getTableWatch("acls") + } + + panic(fmt.Sprintf("Unknown method %#s", method)) +} + // GetKVSWatch returns a watch for the given prefix in the key value store. func (s *StateStore) GetKVSWatch(prefix string) Watch { return s.kvsWatch.GetSubwatch(prefix) diff --git a/consul/state/state_store_test.go b/consul/state/state_store_test.go index 3b6dab0feb..951e65f903 100644 --- a/consul/state/state_store_test.go +++ b/consul/state/state_store_test.go @@ -236,7 +236,7 @@ func TestStateStore_ReapTombstones(t *testing.T) { } } -func TestStateStore_GetTableWatch(t *testing.T) { +func TestStateStore_GetWatches(t *testing.T) { s := testStateStore(t) // This test does two things - it makes sure there's no full table @@ -248,7 +248,7 @@ func TestStateStore_GetTableWatch(t *testing.T) { t.Fatalf("didn't get expected panic") } }() - s.GetTableWatch("kvs") + s.getTableWatch("kvs") }() // Similar for tombstones; those don't support watches at all. @@ -258,8 +258,26 @@ func TestStateStore_GetTableWatch(t *testing.T) { t.Fatalf("didn't get expected panic") } }() - s.GetTableWatch("tombstones") + s.getTableWatch("tombstones") }() + + // Make sure requesting a bogus method causes a panic. + func() { + defer func() { + if r:= recover(); r == nil { + t.Fatalf("didn't get expected panic") + } + }() + s.GetQueryWatch("dogs") + }() + + // Request valid watches. + if w := s.GetQueryWatch("Nodes"); w == nil { + t.Fatalf("didn't get a watch") + } + if w := s.GetKVSWatch("/dogs"); w == nil { + t.Fatalf("didn't get a watch") + } } func TestStateStore_EnsureRegistration(t *testing.T) { @@ -401,9 +419,9 @@ func TestStateStore_EnsureRegistration_Watches(t *testing.T) { } // The nodes watch should fire for this one. - verifyWatch(t, s.GetTableWatch("nodes"), func() { - verifyNoWatch(t, s.GetTableWatch("services"), func() { - verifyNoWatch(t, s.GetTableWatch("checks"), func() { + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyNoWatch(t, s.getTableWatch("services"), func() { + verifyNoWatch(t, s.getTableWatch("checks"), func() { if err := s.EnsureRegistration(1, req); err != nil { t.Fatalf("err: %s", err) } @@ -419,9 +437,9 @@ func TestStateStore_EnsureRegistration_Watches(t *testing.T) { Address: "1.1.1.1", Port: 8080, } - verifyWatch(t, s.GetTableWatch("nodes"), func() { - verifyWatch(t, s.GetTableWatch("services"), func() { - verifyNoWatch(t, s.GetTableWatch("checks"), func() { + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyWatch(t, s.getTableWatch("services"), func() { + verifyNoWatch(t, s.getTableWatch("checks"), func() { if err := s.EnsureRegistration(2, req); err != nil { t.Fatalf("err: %s", err) } @@ -435,9 +453,9 @@ func TestStateStore_EnsureRegistration_Watches(t *testing.T) { CheckID: "check1", Name: "check", } - verifyWatch(t, s.GetTableWatch("nodes"), func() { - verifyWatch(t, s.GetTableWatch("services"), func() { - verifyWatch(t, s.GetTableWatch("checks"), func() { + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyWatch(t, s.getTableWatch("services"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { if err := s.EnsureRegistration(3, req); err != nil { t.Fatalf("err: %s", err) } @@ -656,7 +674,7 @@ func TestStateStore_Node_Watches(t *testing.T) { // Call functions that update the nodes table and make sure a watch fires // each time. - verifyWatch(t, s.GetTableWatch("nodes"), func() { + verifyWatch(t, s.getTableWatch("nodes"), func() { req := &structs.RegisterRequest{ Node: "node1", } @@ -664,13 +682,13 @@ func TestStateStore_Node_Watches(t *testing.T) { t.Fatalf("err: %s", err) } }) - verifyWatch(t, s.GetTableWatch("nodes"), func() { + verifyWatch(t, s.getTableWatch("nodes"), func() { node := &structs.Node{Node: "node2"} if err := s.EnsureNode(2, node); err != nil { t.Fatalf("err: %s", err) } }) - verifyWatch(t, s.GetTableWatch("nodes"), func() { + verifyWatch(t, s.getTableWatch("nodes"), func() { if err := s.DeleteNode(3, "node2"); err != nil { t.Fatalf("err: %s", err) } @@ -681,9 +699,9 @@ func TestStateStore_Node_Watches(t *testing.T) { testRegisterNode(t, s, 4, "node1") testRegisterService(t, s, 5, "node1", "service1") testRegisterCheck(t, s, 6, "node1", "service1", "check3", structs.HealthPassing) - verifyWatch(t, s.GetTableWatch("nodes"), func() { - verifyWatch(t, s.GetTableWatch("services"), func() { - verifyWatch(t, s.GetTableWatch("checks"), func() { + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyWatch(t, s.getTableWatch("services"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { if err := s.DeleteNode(7, "node1"); err != nil { t.Fatalf("err: %s", err) } @@ -1200,12 +1218,12 @@ func TestStateStore_Service_Watches(t *testing.T) { // Call functions that update the services table and make sure a watch // fires each time. - verifyWatch(t, s.GetTableWatch("services"), func() { + verifyWatch(t, s.getTableWatch("services"), func() { if err := s.EnsureService(2, "node1", ns); err != nil { t.Fatalf("err: %s", err) } }) - verifyWatch(t, s.GetTableWatch("services"), func() { + verifyWatch(t, s.getTableWatch("services"), func() { if err := s.DeleteService(3, "node1", "service2"); err != nil { t.Fatalf("err: %s", err) } @@ -1215,8 +1233,8 @@ func TestStateStore_Service_Watches(t *testing.T) { // shot. testRegisterService(t, s, 4, "node1", "service1") testRegisterCheck(t, s, 5, "node1", "service1", "check3", structs.HealthPassing) - verifyWatch(t, s.GetTableWatch("services"), func() { - verifyWatch(t, s.GetTableWatch("checks"), func() { + verifyWatch(t, s.getTableWatch("services"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { if err := s.DeleteService(6, "node1", "service1"); err != nil { t.Fatalf("err: %s", err) } @@ -1678,18 +1696,18 @@ func TestStateStore_Check_Watches(t *testing.T) { // Call functions that update the checks table and make sure a watch fires // each time. - verifyWatch(t, s.GetTableWatch("checks"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { if err := s.EnsureCheck(1, hc); err != nil { t.Fatalf("err: %s", err) } }) - verifyWatch(t, s.GetTableWatch("checks"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { hc.Status = structs.HealthCritical if err := s.EnsureCheck(2, hc); err != nil { t.Fatalf("err: %s", err) } }) - verifyWatch(t, s.GetTableWatch("checks"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { if err := s.DeleteCheck(3, "node1", "check1"); err != nil { t.Fatalf("err: %s", err) } @@ -3541,7 +3559,7 @@ func TestStateStore_Session_Watches(t *testing.T) { // This just covers the basics. The session invalidation tests above // cover the more nuanced multiple table watches. - verifyWatch(t, s.GetTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("sessions"), func() { session := &structs.Session{ ID: "session1", Node: "node1", @@ -3551,12 +3569,12 @@ func TestStateStore_Session_Watches(t *testing.T) { t.Fatalf("err: %s", err) } }) - verifyWatch(t, s.GetTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("sessions"), func() { if err := s.SessionDestroy(3, "session1"); err != nil { t.Fatalf("err: %s", err) } }) - verifyWatch(t, s.GetTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("sessions"), func() { session := &structs.Session{ ID: "session1", Node: "node1", @@ -3584,8 +3602,8 @@ func TestStateStore_Session_Invalidate_DeleteNode(t *testing.T) { } // Delete the node and make sure the watches fire. - verifyWatch(t, s.GetTableWatch("sessions"), func() { - verifyWatch(t, s.GetTableWatch("nodes"), func() { + verifyWatch(t, s.getTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("nodes"), func() { if err := s.DeleteNode(15, "foo"); err != nil { t.Fatalf("err: %v", err) } @@ -3635,9 +3653,9 @@ func TestStateStore_Session_Invalidate_DeleteService(t *testing.T) { } // Delete the service and make sure the watches fire. - verifyWatch(t, s.GetTableWatch("sessions"), func() { - verifyWatch(t, s.GetTableWatch("services"), func() { - verifyWatch(t, s.GetTableWatch("checks"), func() { + verifyWatch(t, s.getTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("services"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { if err := s.DeleteService(15, "foo", "api"); err != nil { t.Fatalf("err: %v", err) } @@ -3683,8 +3701,8 @@ func TestStateStore_Session_Invalidate_Critical_Check(t *testing.T) { } // Invalidate the check and make sure the watches fire. - verifyWatch(t, s.GetTableWatch("sessions"), func() { - verifyWatch(t, s.GetTableWatch("checks"), func() { + verifyWatch(t, s.getTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { check.Status = structs.HealthCritical if err := s.EnsureCheck(15, check); err != nil { t.Fatalf("err: %v", err) @@ -3730,8 +3748,8 @@ func TestStateStore_Session_Invalidate_DeleteCheck(t *testing.T) { } // Delete the check and make sure the watches fire. - verifyWatch(t, s.GetTableWatch("sessions"), func() { - verifyWatch(t, s.GetTableWatch("checks"), func() { + verifyWatch(t, s.getTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { if err := s.DeleteCheck(15, "foo", "bar"); err != nil { t.Fatalf("err: %v", err) } @@ -3794,8 +3812,8 @@ func TestStateStore_Session_Invalidate_Key_Unlock_Behavior(t *testing.T) { } // Delete the node and make sure the watches fire. - verifyWatch(t, s.GetTableWatch("sessions"), func() { - verifyWatch(t, s.GetTableWatch("nodes"), func() { + verifyWatch(t, s.getTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("nodes"), func() { verifyWatch(t, s.GetKVSWatch("/f"), func() { if err := s.DeleteNode(6, "foo"); err != nil { t.Fatalf("err: %v", err) @@ -3871,8 +3889,8 @@ func TestStateStore_Session_Invalidate_Key_Delete_Behavior(t *testing.T) { } // Delete the node and make sure the watches fire. - verifyWatch(t, s.GetTableWatch("sessions"), func() { - verifyWatch(t, s.GetTableWatch("nodes"), func() { + verifyWatch(t, s.getTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("nodes"), func() { verifyWatch(t, s.GetKVSWatch("/b"), func() { if err := s.DeleteNode(6, "foo"); err != nil { t.Fatalf("err: %v", err) @@ -4171,17 +4189,17 @@ func TestStateStore_ACL_Watches(t *testing.T) { // Call functions that update the acls table and make sure a watch fires // each time. - verifyWatch(t, s.GetTableWatch("acls"), func() { + verifyWatch(t, s.getTableWatch("acls"), func() { if err := s.ACLSet(1, &structs.ACL{ID: "acl1"}); err != nil { t.Fatalf("err: %s", err) } }) - verifyWatch(t, s.GetTableWatch("acls"), func() { + verifyWatch(t, s.getTableWatch("acls"), func() { if err := s.ACLDelete(2, "acl1"); err != nil { t.Fatalf("err: %s", err) } }) - verifyWatch(t, s.GetTableWatch("acls"), func() { + verifyWatch(t, s.getTableWatch("acls"), func() { if err := s.ACLRestore(&structs.ACL{ID: "acl1"}); err != nil { t.Fatalf("err: %s", err) }