diff --git a/command/agent/event_endpoint.go b/command/agent/event_endpoint.go index 84d2550c52..6735d66b64 100644 --- a/command/agent/event_endpoint.go +++ b/command/agent/event_endpoint.go @@ -97,7 +97,7 @@ func (s *HTTPServer) EventList(resp http.ResponseWriter, req *http.Request) (int nameFilter = filt } - // Lots of this logic is borrowed from consul/rpc.go:blockingRPC + // Lots of this logic is borrowed from consul/rpc.go:blockingQuery // However we cannot use that directly since this code has some // slight semantics differences... var timeout <-chan time.Time diff --git a/consul/acl.go b/consul/acl.go index 50b7b675c1..c51a9df799 100644 --- a/consul/acl.go +++ b/consul/acl.go @@ -62,7 +62,7 @@ func (s *Server) aclLocalFault(id string) (string, string, error) { // Query the state store. state := s.fsm.State() - _, acl, err := state.ACLGet(id) + _, acl, err := state.ACLGet(nil, id) if err != nil { return "", "", err } diff --git a/consul/acl_endpoint.go b/consul/acl_endpoint.go index 4f90410da3..97607efa34 100644 --- a/consul/acl_endpoint.go +++ b/consul/acl_endpoint.go @@ -6,7 +6,9 @@ import ( "github.com/armon/go-metrics" "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/consul/state" "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/go-memdb" "github.com/hashicorp/go-uuid" ) @@ -108,7 +110,7 @@ func (a *ACL) Apply(args *structs.ACLRequest, reply *string) error { return err } - _, acl, err := state.ACLGet(args.ACL.ID) + _, acl, err := state.ACLGet(nil, args.ACL.ID) if err != nil { a.srv.logger.Printf("[ERR] consul.acl: ACL lookup failed: %v", err) return err @@ -144,13 +146,10 @@ func (a *ACL) Get(args *structs.ACLSpecificRequest, return fmt.Errorf(aclDisabled) } - // Get the local state - state := a.srv.fsm.State() - return a.srv.blockingRPC(&args.QueryOptions, + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, - state.GetQueryWatch("ACLGet"), - func() error { - index, acl, err := state.ACLGet(args.ACL) + func(ws memdb.WatchSet, state *state.StateStore) error { + index, acl, err := state.ACLGet(ws, args.ACL) if err != nil { return err } @@ -224,13 +223,10 @@ func (a *ACL) List(args *structs.DCSpecificRequest, return permissionDeniedErr } - // Get the local state - state := a.srv.fsm.State() - return a.srv.blockingRPC(&args.QueryOptions, + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, - state.GetQueryWatch("ACLList"), - func() error { - index, acls, err := state.ACLList() + func(ws memdb.WatchSet, state *state.StateStore) error { + index, acls, err := state.ACLList(ws) if err != nil { return err } diff --git a/consul/acl_endpoint_test.go b/consul/acl_endpoint_test.go index e031730bf6..a59fd43f73 100644 --- a/consul/acl_endpoint_test.go +++ b/consul/acl_endpoint_test.go @@ -41,7 +41,7 @@ func TestACLEndpoint_Apply(t *testing.T) { // Verify state := s1.fsm.State() - _, s, err := state.ACLGet(out) + _, s, err := state.ACLGet(nil, out) if err != nil { t.Fatalf("err: %v", err) } @@ -63,7 +63,7 @@ func TestACLEndpoint_Apply(t *testing.T) { } // Verify - _, s, err = state.ACLGet(id) + _, s, err = state.ACLGet(nil, id) if err != nil { t.Fatalf("err: %v", err) } @@ -182,7 +182,7 @@ func TestACLEndpoint_Apply_CustomID(t *testing.T) { // Verify state := s1.fsm.State() - _, s, err := state.ACLGet(out) + _, s, err := state.ACLGet(nil, out) if err != nil { t.Fatalf("err: %v", err) } diff --git a/consul/acl_replication.go b/consul/acl_replication.go index 57ffce2555..4a89cc59e0 100644 --- a/consul/acl_replication.go +++ 
b/consul/acl_replication.go @@ -139,7 +139,7 @@ func reconcileACLs(local, remote structs.ACLs, lastRemoteIndex uint64) structs.A // FetchLocalACLs returns the ACLs in the local state store. func (s *Server) fetchLocalACLs() (structs.ACLs, error) { - _, local, err := s.fsm.State().ACLList() + _, local, err := s.fsm.State().ACLList(nil) if err != nil { return nil, err } diff --git a/consul/acl_replication_test.go b/consul/acl_replication_test.go index f7f10752b8..1b419fea58 100644 --- a/consul/acl_replication_test.go +++ b/consul/acl_replication_test.go @@ -364,11 +364,11 @@ func TestACLReplication(t *testing.T) { } checkSame := func() (bool, error) { - index, remote, err := s1.fsm.State().ACLList() + index, remote, err := s1.fsm.State().ACLList(nil) if err != nil { return false, err } - _, local, err := s2.fsm.State().ACLList() + _, local, err := s2.fsm.State().ACLList(nil) if err != nil { return false, err } diff --git a/consul/acl_test.go b/consul/acl_test.go index 65ed6f7aa3..c0f0fd091e 100644 --- a/consul/acl_test.go +++ b/consul/acl_test.go @@ -688,14 +688,14 @@ func TestACL_Replication(t *testing.T) { // Wait for replication to occur. testutil.WaitForResult(func() (bool, error) { - _, acl, err := s2.fsm.State().ACLGet(id) + _, acl, err := s2.fsm.State().ACLGet(nil, id) if err != nil { return false, err } if acl == nil { return false, nil } - _, acl, err = s3.fsm.State().ACLGet(id) + _, acl, err = s3.fsm.State().ACLGet(nil, id) if err != nil { return false, err } diff --git a/consul/catalog_endpoint.go b/consul/catalog_endpoint.go index 8ee16c2a48..1108f5dfee 100644 --- a/consul/catalog_endpoint.go +++ b/consul/catalog_endpoint.go @@ -5,8 +5,10 @@ import ( "time" "github.com/armon/go-metrics" + "github.com/hashicorp/consul/consul/state" "github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/consul/types" + "github.com/hashicorp/go-memdb" "github.com/hashicorp/go-uuid" ) @@ -79,7 +81,7 @@ func (c *Catalog) Register(args *structs.RegisterRequest, reply *struct{}) error // Check the complete register request against the given ACL policy. if acl != nil && c.srv.config.ACLEnforceVersion8 { state := c.srv.fsm.State() - _, ns, err := state.NodeServices(args.Node) + _, ns, err := state.NodeServices(nil, args.Node) if err != nil { return fmt.Errorf("Node lookup failed: %v", err) } @@ -162,20 +164,17 @@ func (c *Catalog) ListNodes(args *structs.DCSpecificRequest, reply *structs.Inde return err } - // Get the list of nodes. - state := c.srv.fsm.State() - return c.srv.blockingRPC( + return c.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, - state.GetQueryWatch("Nodes"), - func() error { + func(ws memdb.WatchSet, state *state.StateStore) error { var index uint64 var nodes structs.Nodes var err error if len(args.NodeMetaFilters) > 0 { - index, nodes, err = state.NodesByMeta(args.NodeMetaFilters) + index, nodes, err = state.NodesByMeta(ws, args.NodeMetaFilters) } else { - index, nodes, err = state.Nodes() + index, nodes, err = state.Nodes(ws) } if err != nil { return err @@ -195,20 +194,17 @@ func (c *Catalog) ListServices(args *structs.DCSpecificRequest, reply *structs.I return err } - // Get the list of services and their tags. 
- state := c.srv.fsm.State() - return c.srv.blockingRPC( + return c.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, - state.GetQueryWatch("Services"), - func() error { + func(ws memdb.WatchSet, state *state.StateStore) error { var index uint64 var services structs.Services var err error if len(args.NodeMetaFilters) > 0 { - index, services, err = state.ServicesByNodeMeta(args.NodeMetaFilters) + index, services, err = state.ServicesByNodeMeta(ws, args.NodeMetaFilters) } else { - index, services, err = state.Services() + index, services, err = state.Services(ws) } if err != nil { return err @@ -230,20 +226,17 @@ func (c *Catalog) ServiceNodes(args *structs.ServiceSpecificRequest, reply *stru return fmt.Errorf("Must provide service name") } - // Get the nodes - state := c.srv.fsm.State() - err := c.srv.blockingRPC( + err := c.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, - state.GetQueryWatch("ServiceNodes"), - func() error { + func(ws memdb.WatchSet, state *state.StateStore) error { var index uint64 var services structs.ServiceNodes var err error if args.TagFilter { - index, services, err = state.ServiceTagNodes(args.ServiceName, args.ServiceTag) + index, services, err = state.ServiceTagNodes(ws, args.ServiceName, args.ServiceTag) } else { - index, services, err = state.ServiceNodes(args.ServiceName) + index, services, err = state.ServiceNodes(ws, args.ServiceName) } if err != nil { return err @@ -288,14 +281,11 @@ func (c *Catalog) NodeServices(args *structs.NodeSpecificRequest, reply *structs return fmt.Errorf("Must provide node") } - // Get the node services - state := c.srv.fsm.State() - return c.srv.blockingRPC( + return c.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, - state.GetQueryWatch("NodeServices"), - func() error { - index, services, err := state.NodeServices(args.Node) + func(ws memdb.WatchSet, state *state.StateStore) error { + index, services, err := state.NodeServices(ws, args.Node) if err != nil { return err } diff --git a/consul/coordinate_endpoint.go b/consul/coordinate_endpoint.go index 6a30baa452..b818f904cd 100644 --- a/consul/coordinate_endpoint.go +++ b/consul/coordinate_endpoint.go @@ -7,7 +7,9 @@ import ( "sync" "time" + "github.com/hashicorp/consul/consul/state" "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/go-memdb" "github.com/hashicorp/serf/coordinate" ) @@ -173,12 +175,10 @@ func (c *Coordinate) ListNodes(args *structs.DCSpecificRequest, reply *structs.I return err } - state := c.srv.fsm.State() - return c.srv.blockingRPC(&args.QueryOptions, + return c.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, - state.GetQueryWatch("Coordinates"), - func() error { - index, coords, err := state.Coordinates() + func(ws memdb.WatchSet, state *state.StateStore) error { + index, coords, err := state.Coordinates(ws) if err != nil { return err } diff --git a/consul/fsm.go b/consul/fsm.go index 8e098eb3e1..6bfb102973 100644 --- a/consul/fsm.go +++ b/consul/fsm.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "log" + "sync" "time" "github.com/armon/go-metrics" @@ -24,8 +25,15 @@ type consulFSM struct { logOutput io.Writer logger *log.Logger path string + + // stateLock is only used to protect outside callers to State() from + // racing with Restore(), which is called by Raft (it puts in a totally + // new state store). Everything internal here is synchronized by the + // Raft side, so doesn't need to lock this. 
+ stateLock sync.RWMutex state *state.StateStore - gc *state.TombstoneGC + + gc *state.TombstoneGC } // consulSnapshot is used to provide a snapshot of the current @@ -60,6 +68,8 @@ func NewFSM(gc *state.TombstoneGC, logOutput io.Writer) (*consulFSM, error) { // State is used to return a handle to the current state func (c *consulFSM) State() *state.StateStore { + c.stateLock.RLock() + defer c.stateLock.RUnlock() return c.state } @@ -316,7 +326,18 @@ func (c *consulFSM) Restore(old io.ReadCloser) error { if err != nil { return err } + + // External code might be calling State(), so we need to synchronize + // here to make sure we swap in the new state store atomically. + c.stateLock.Lock() + stateOld := c.state c.state = stateNew + c.stateLock.Unlock() + + // The old state store has been abandoned already since we've replaced + // it with an empty one, but we defer telling watchers about it until + // the restore is done, so they wake up once we have the latest data. + defer stateOld.Abandon() // Set up a new restore transaction restore := c.state.Restore() diff --git a/consul/fsm_test.go b/consul/fsm_test.go index 2f63fd89ac..46608aa075 100644 --- a/consul/fsm_test.go +++ b/consul/fsm_test.go @@ -84,7 +84,7 @@ func TestFSM_RegisterNode(t *testing.T) { } // Verify service registered - _, services, err := fsm.state.NodeServices("foo") + _, services, err := fsm.state.NodeServices(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -137,7 +137,7 @@ func TestFSM_RegisterNode_Service(t *testing.T) { } // Verify service registered - _, services, err := fsm.state.NodeServices("foo") + _, services, err := fsm.state.NodeServices(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -146,7 +146,7 @@ func TestFSM_RegisterNode_Service(t *testing.T) { } // Verify check - _, checks, err := fsm.state.NodeChecks("foo") + _, checks, err := fsm.state.NodeChecks(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -207,7 +207,7 @@ func TestFSM_DeregisterService(t *testing.T) { } // Verify service not registered - _, services, err := fsm.state.NodeServices("foo") + _, services, err := fsm.state.NodeServices(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -268,7 +268,7 @@ func TestFSM_DeregisterCheck(t *testing.T) { } // Verify check not registered - _, checks, err := fsm.state.NodeChecks("foo") + _, checks, err := fsm.state.NodeChecks(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -335,7 +335,7 @@ func TestFSM_DeregisterNode(t *testing.T) { } // Verify service not registered - _, services, err := fsm.state.NodeServices("foo") + _, services, err := fsm.state.NodeServices(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -344,7 +344,7 @@ func TestFSM_DeregisterNode(t *testing.T) { } // Verify checks not registered - _, checks, err := fsm.state.NodeChecks("foo") + _, checks, err := fsm.state.NodeChecks(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -387,7 +387,7 @@ func TestFSM_SnapshotRestore(t *testing.T) { Value: []byte("foo"), }) fsm.state.KVSDelete(12, "/remove") - idx, _, err := fsm.state.KVSList("/remove") + idx, _, err := fsm.state.KVSList(nil, "/remove") if err != nil { t.Fatalf("err: %s", err) } @@ -449,7 +449,7 @@ func TestFSM_SnapshotRestore(t *testing.T) { } // Verify the contents - _, nodes, err := fsm2.state.Nodes() + _, nodes, err := fsm2.state.Nodes(nil) if err != nil { t.Fatalf("err: %s", err) } @@ -468,7 +468,7 @@ func TestFSM_SnapshotRestore(t *testing.T) { t.Fatalf("bad: %v", nodes[1]) } - _, fooSrv, err := fsm2.state.NodeServices("foo") 
+ _, fooSrv, err := fsm2.state.NodeServices(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -482,7 +482,7 @@ func TestFSM_SnapshotRestore(t *testing.T) { t.Fatalf("Bad: %v", fooSrv) } - _, checks, err := fsm2.state.NodeChecks("foo") + _, checks, err := fsm2.state.NodeChecks(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -491,7 +491,7 @@ func TestFSM_SnapshotRestore(t *testing.T) { } // Verify key is set - _, d, err := fsm2.state.KVSGet("/test") + _, d, err := fsm2.state.KVSGet(nil, "/test") if err != nil { t.Fatalf("err: %v", err) } @@ -500,7 +500,7 @@ func TestFSM_SnapshotRestore(t *testing.T) { } // Verify session is restored - idx, s, err := fsm2.state.SessionGet(session.ID) + idx, s, err := fsm2.state.SessionGet(nil, session.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -512,7 +512,7 @@ func TestFSM_SnapshotRestore(t *testing.T) { } // Verify ACL is restored - _, a, err := fsm2.state.ACLGet(acl.ID) + _, a, err := fsm2.state.ACLGet(nil, acl.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -544,7 +544,7 @@ func TestFSM_SnapshotRestore(t *testing.T) { }() // Verify coordinates are restored - _, coords, err := fsm2.state.Coordinates() + _, coords, err := fsm2.state.Coordinates(nil) if err != nil { t.Fatalf("err: %s", err) } @@ -553,7 +553,7 @@ func TestFSM_SnapshotRestore(t *testing.T) { } // Verify queries are restored. - _, queries, err := fsm2.state.PreparedQueryList() + _, queries, err := fsm2.state.PreparedQueryList(nil) if err != nil { t.Fatalf("err: %s", err) } @@ -563,6 +563,33 @@ func TestFSM_SnapshotRestore(t *testing.T) { if !reflect.DeepEqual(queries[0], &query) { t.Fatalf("bad: %#v", queries[0]) } + + // Snapshot + snap, err = fsm2.Snapshot() + if err != nil { + t.Fatalf("err: %v", err) + } + defer snap.Release() + + // Persist + buf = bytes.NewBuffer(nil) + sink = &MockSink{buf, false} + if err := snap.Persist(sink); err != nil { + t.Fatalf("err: %v", err) + } + + // Try to restore on the old FSM and make sure it abandons the old state + // store. 
+ abandonCh := fsm.state.AbandonCh() + if err := fsm.Restore(sink); err != nil { + t.Fatalf("err: %v", err) + } + select { + case <-abandonCh: + default: + t.Fatalf("bad") + } + } func TestFSM_KVSSet(t *testing.T) { @@ -590,7 +617,7 @@ func TestFSM_KVSSet(t *testing.T) { } // Verify key is set - _, d, err := fsm.state.KVSGet("/test/path") + _, d, err := fsm.state.KVSGet(nil, "/test/path") if err != nil { t.Fatalf("err: %v", err) } @@ -635,7 +662,7 @@ func TestFSM_KVSDelete(t *testing.T) { } // Verify key is not set - _, d, err := fsm.state.KVSGet("/test/path") + _, d, err := fsm.state.KVSGet(nil, "/test/path") if err != nil { t.Fatalf("err: %v", err) } @@ -681,7 +708,7 @@ func TestFSM_KVSDeleteTree(t *testing.T) { } // Verify key is not set - _, d, err := fsm.state.KVSGet("/test/path") + _, d, err := fsm.state.KVSGet(nil, "/test/path") if err != nil { t.Fatalf("err: %v", err) } @@ -715,7 +742,7 @@ func TestFSM_KVSDeleteCheckAndSet(t *testing.T) { } // Verify key is set - _, d, err := fsm.state.KVSGet("/test/path") + _, d, err := fsm.state.KVSGet(nil, "/test/path") if err != nil { t.Fatalf("err: %v", err) } @@ -736,7 +763,7 @@ func TestFSM_KVSDeleteCheckAndSet(t *testing.T) { } // Verify key is gone - _, d, err = fsm.state.KVSGet("/test/path") + _, d, err = fsm.state.KVSGet(nil, "/test/path") if err != nil { t.Fatalf("err: %v", err) } @@ -770,7 +797,7 @@ func TestFSM_KVSCheckAndSet(t *testing.T) { } // Verify key is set - _, d, err := fsm.state.KVSGet("/test/path") + _, d, err := fsm.state.KVSGet(nil, "/test/path") if err != nil { t.Fatalf("err: %v", err) } @@ -792,7 +819,7 @@ func TestFSM_KVSCheckAndSet(t *testing.T) { } // Verify key is updated - _, d, err = fsm.state.KVSGet("/test/path") + _, d, err = fsm.state.KVSGet(nil, "/test/path") if err != nil { t.Fatalf("err: %v", err) } @@ -832,7 +859,7 @@ func TestFSM_CoordinateUpdate(t *testing.T) { } // Read back the two coordinates to make sure they got updated. 
- _, coords, err := fsm.state.Coordinates() + _, coords, err := fsm.state.Coordinates(nil) if err != nil { t.Fatalf("err: %s", err) } @@ -875,7 +902,7 @@ func TestFSM_SessionCreate_Destroy(t *testing.T) { // Get the session id := resp.(string) - _, session, err := fsm.state.SessionGet(id) + _, session, err := fsm.state.SessionGet(nil, id) if err != nil { t.Fatalf("err: %v", err) } @@ -911,7 +938,7 @@ func TestFSM_SessionCreate_Destroy(t *testing.T) { t.Fatalf("resp: %v", resp) } - _, session, err = fsm.state.SessionGet(id) + _, session, err = fsm.state.SessionGet(nil, id) if err != nil { t.Fatalf("err: %v", err) } @@ -949,7 +976,7 @@ func TestFSM_KVSLock(t *testing.T) { } // Verify key is locked - _, d, err := fsm.state.KVSGet("/test/path") + _, d, err := fsm.state.KVSGet(nil, "/test/path") if err != nil { t.Fatalf("err: %v", err) } @@ -1011,7 +1038,7 @@ func TestFSM_KVSUnlock(t *testing.T) { } // Verify key is unlocked - _, d, err := fsm.state.KVSGet("/test/path") + _, d, err := fsm.state.KVSGet(nil, "/test/path") if err != nil { t.Fatalf("err: %v", err) } @@ -1053,7 +1080,7 @@ func TestFSM_ACL_Set_Delete(t *testing.T) { // Get the ACL id := resp.(string) - _, acl, err := fsm.state.ACLGet(id) + _, acl, err := fsm.state.ACLGet(nil, id) if err != nil { t.Fatalf("err: %v", err) } @@ -1089,7 +1116,7 @@ func TestFSM_ACL_Set_Delete(t *testing.T) { t.Fatalf("resp: %v", resp) } - _, acl, err = fsm.state.ACLGet(id) + _, acl, err = fsm.state.ACLGet(nil, id) if err != nil { t.Fatalf("err: %v", err) } @@ -1131,7 +1158,7 @@ func TestFSM_PreparedQuery_CRUD(t *testing.T) { // Verify it's in the state store. { - _, actual, err := fsm.state.PreparedQueryGet(query.Query.ID) + _, actual, err := fsm.state.PreparedQueryGet(nil, query.Query.ID) if err != nil { t.Fatalf("err: %s", err) } @@ -1158,7 +1185,7 @@ func TestFSM_PreparedQuery_CRUD(t *testing.T) { // Verify the update. { - _, actual, err := fsm.state.PreparedQueryGet(query.Query.ID) + _, actual, err := fsm.state.PreparedQueryGet(nil, query.Query.ID) if err != nil { t.Fatalf("err: %s", err) } @@ -1184,7 +1211,7 @@ func TestFSM_PreparedQuery_CRUD(t *testing.T) { // Make sure it's gone. { - _, actual, err := fsm.state.PreparedQueryGet(query.Query.ID) + _, actual, err := fsm.state.PreparedQueryGet(nil, query.Query.ID) if err != nil { t.Fatalf("err: %s", err) } @@ -1207,7 +1234,7 @@ func TestFSM_TombstoneReap(t *testing.T) { Value: []byte("foo"), }) fsm.state.KVSDelete(12, "/remove") - idx, _, err := fsm.state.KVSList("/remove") + idx, _, err := fsm.state.KVSList(nil, "/remove") if err != nil { t.Fatalf("err: %s", err) } @@ -1274,7 +1301,7 @@ func TestFSM_Txn(t *testing.T) { } // Verify key is set directly in the state store. 
- _, d, err := fsm.state.KVSGet("/test/path") + _, d, err := fsm.state.KVSGet(nil, "/test/path") if err != nil { t.Fatalf("err: %v", err) } diff --git a/consul/health_endpoint.go b/consul/health_endpoint.go index 7eb28ce2ac..aa225fb830 100644 --- a/consul/health_endpoint.go +++ b/consul/health_endpoint.go @@ -3,7 +3,9 @@ package consul import ( "fmt" "github.com/armon/go-metrics" + "github.com/hashicorp/consul/consul/state" "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/go-memdb" ) // Health endpoint is used to query the health information @@ -18,20 +20,17 @@ func (h *Health) ChecksInState(args *structs.ChecksInStateRequest, return err } - // Get the state specific checks - state := h.srv.fsm.State() - return h.srv.blockingRPC( + return h.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, - state.GetQueryWatch("ChecksInState"), - func() error { + func(ws memdb.WatchSet, state *state.StateStore) error { var index uint64 var checks structs.HealthChecks var err error if len(args.NodeMetaFilters) > 0 { - index, checks, err = state.ChecksInStateByNodeMeta(args.State, args.NodeMetaFilters) + index, checks, err = state.ChecksInStateByNodeMeta(ws, args.State, args.NodeMetaFilters) } else { - index, checks, err = state.ChecksInState(args.State) + index, checks, err = state.ChecksInState(ws, args.State) } if err != nil { return err @@ -51,14 +50,11 @@ func (h *Health) NodeChecks(args *structs.NodeSpecificRequest, return err } - // Get the node checks - state := h.srv.fsm.State() - return h.srv.blockingRPC( + return h.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, - state.GetQueryWatch("NodeChecks"), - func() error { - index, checks, err := state.NodeChecks(args.Node) + func(ws memdb.WatchSet, state *state.StateStore) error { + index, checks, err := state.NodeChecks(ws, args.Node) if err != nil { return err } @@ -80,20 +76,17 @@ func (h *Health) ServiceChecks(args *structs.ServiceSpecificRequest, return err } - // Get the service checks - state := h.srv.fsm.State() - return h.srv.blockingRPC( + return h.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, - state.GetQueryWatch("ServiceChecks"), - func() error { + func(ws memdb.WatchSet, state *state.StateStore) error { var index uint64 var checks structs.HealthChecks var err error if len(args.NodeMetaFilters) > 0 { - index, checks, err = state.ServiceChecksByNodeMeta(args.ServiceName, args.NodeMetaFilters) + index, checks, err = state.ServiceChecksByNodeMeta(ws, args.ServiceName, args.NodeMetaFilters) } else { - index, checks, err = state.ServiceChecks(args.ServiceName) + index, checks, err = state.ServiceChecks(ws, args.ServiceName) } if err != nil { return err @@ -117,20 +110,17 @@ func (h *Health) ServiceNodes(args *structs.ServiceSpecificRequest, reply *struc return fmt.Errorf("Must provide service name") } - // Get the nodes - state := h.srv.fsm.State() - err := h.srv.blockingRPC( + err := h.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, - state.GetQueryWatch("CheckServiceNodes"), - func() error { + func(ws memdb.WatchSet, state *state.StateStore) error { var index uint64 var nodes structs.CheckServiceNodes var err error if args.TagFilter { - index, nodes, err = state.CheckServiceTagNodes(args.ServiceName, args.ServiceTag) + index, nodes, err = state.CheckServiceTagNodes(ws, args.ServiceName, args.ServiceTag) } else { - index, nodes, err = state.CheckServiceNodes(args.ServiceName) + index, nodes, err = state.CheckServiceNodes(ws, args.ServiceName) } if err != nil { return err diff --git 
a/consul/internal_endpoint.go b/consul/internal_endpoint.go index a30086f94c..2d0c059619 100644 --- a/consul/internal_endpoint.go +++ b/consul/internal_endpoint.go @@ -3,7 +3,9 @@ package consul import ( "fmt" + "github.com/hashicorp/consul/consul/state" "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/go-memdb" "github.com/hashicorp/serf/serf" ) @@ -21,14 +23,11 @@ func (m *Internal) NodeInfo(args *structs.NodeSpecificRequest, return err } - // Get the node info - state := m.srv.fsm.State() - return m.srv.blockingRPC( + return m.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, - state.GetQueryWatch("NodeInfo"), - func() error { - index, dump, err := state.NodeInfo(args.Node) + func(ws memdb.WatchSet, state *state.StateStore) error { + index, dump, err := state.NodeInfo(ws, args.Node) if err != nil { return err } @@ -45,14 +44,11 @@ func (m *Internal) NodeDump(args *structs.DCSpecificRequest, return err } - // Get all the node info - state := m.srv.fsm.State() - return m.srv.blockingRPC( + return m.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, - state.GetQueryWatch("NodeDump"), - func() error { - index, dump, err := state.NodeDump() + func(ws memdb.WatchSet, state *state.StateStore) error { + index, dump, err := state.NodeDump(ws) if err != nil { return err } diff --git a/consul/issue_test.go b/consul/issue_test.go index 45f9e91c6a..e2fdf9470e 100644 --- a/consul/issue_test.go +++ b/consul/issue_test.go @@ -45,7 +45,7 @@ func TestHealthCheckRace(t *testing.T) { } // Verify the index - idx, out1, err := state.CheckServiceNodes("db") + idx, out1, err := state.CheckServiceNodes(nil, "db") if err != nil { t.Fatalf("err: %s", err) } @@ -68,7 +68,7 @@ func TestHealthCheckRace(t *testing.T) { } // Verify the index changed - idx, out2, err := state.CheckServiceNodes("db") + idx, out2, err := state.CheckServiceNodes(nil, "db") if err != nil { t.Fatalf("err: %s", err) } diff --git a/consul/kvs_endpoint.go b/consul/kvs_endpoint.go index 95ce7576ea..9f0d4cd0c5 100644 --- a/consul/kvs_endpoint.go +++ b/consul/kvs_endpoint.go @@ -6,7 +6,9 @@ import ( "github.com/armon/go-metrics" "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/consul/state" "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/go-memdb" ) // KVS endpoint is used to manipulate the Key-Value store @@ -117,14 +119,11 @@ func (k *KVS) Get(args *structs.KeyRequest, reply *structs.IndexedDirEntries) er return err } - // Get the local state - state := k.srv.fsm.State() - return k.srv.blockingRPC( + return k.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, - state.GetKVSWatch(args.Key), - func() error { - index, ent, err := state.KVSGet(args.Key) + func(ws memdb.WatchSet, state *state.StateStore) error { + index, ent, err := state.KVSGet(ws, args.Key) if err != nil { return err } @@ -159,14 +158,11 @@ func (k *KVS) List(args *structs.KeyRequest, reply *structs.IndexedDirEntries) e return err } - // Get the local state - state := k.srv.fsm.State() - return k.srv.blockingRPC( + return k.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, - state.GetKVSWatch(args.Key), - func() error { - index, ent, err := state.KVSList(args.Key) + func(ws memdb.WatchSet, state *state.StateStore) error { + index, ent, err := state.KVSList(ws, args.Key) if err != nil { return err } @@ -202,14 +198,11 @@ func (k *KVS) ListKeys(args *structs.KeyListRequest, reply *structs.IndexedKeyLi return err } - // Get the local state - state := k.srv.fsm.State() - return k.srv.blockingRPC( + 
return k.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, - state.GetKVSWatch(args.Prefix), - func() error { - index, keys, err := state.KVSListKeys(args.Prefix, args.Seperator) + func(ws memdb.WatchSet, state *state.StateStore) error { + index, keys, err := state.KVSListKeys(ws, args.Prefix, args.Seperator) if err != nil { return err } diff --git a/consul/kvs_endpoint_test.go b/consul/kvs_endpoint_test.go index 50bd58b257..86cd36476c 100644 --- a/consul/kvs_endpoint_test.go +++ b/consul/kvs_endpoint_test.go @@ -36,7 +36,7 @@ func TestKVS_Apply(t *testing.T) { // Verify state := s1.fsm.State() - _, d, err := state.KVSGet("test") + _, d, err := state.KVSGet(nil, "test") if err != nil { t.Fatalf("err: %v", err) } @@ -58,7 +58,7 @@ func TestKVS_Apply(t *testing.T) { } // Verify - _, d, err = state.KVSGet("test") + _, d, err = state.KVSGet(nil, "test") if err != nil { t.Fatalf("err: %v", err) } diff --git a/consul/leader.go b/consul/leader.go index 4efededfee..5cfcec4632 100644 --- a/consul/leader.go +++ b/consul/leader.go @@ -185,7 +185,7 @@ func (s *Server) initializeACL() error { // Look for the anonymous token state := s.fsm.State() - _, acl, err := state.ACLGet(anonymousToken) + _, acl, err := state.ACLGet(nil, anonymousToken) if err != nil { return fmt.Errorf("failed to get anonymous token: %v", err) } @@ -214,7 +214,7 @@ func (s *Server) initializeACL() error { } // Look for the master token - _, acl, err = state.ACLGet(master) + _, acl, err = state.ACLGet(nil, master) if err != nil { return fmt.Errorf("failed to get master token: %v", err) } @@ -262,7 +262,7 @@ func (s *Server) reconcile() (err error) { // a "reap" event to cause the node to be cleaned up. func (s *Server) reconcileReaped(known map[string]struct{}) error { state := s.fsm.State() - _, checks, err := state.ChecksInState(structs.HealthAny) + _, checks, err := state.ChecksInState(nil, structs.HealthAny) if err != nil { return err } @@ -287,7 +287,7 @@ func (s *Server) reconcileReaped(known map[string]struct{}) error { } // Get the node services, look for ConsulServiceID - _, services, err := state.NodeServices(check.Node) + _, services, err := state.NodeServices(nil, check.Node) if err != nil { return err } @@ -385,7 +385,7 @@ func (s *Server) handleAliveMember(member serf.Member) error { // Check if the associated service is available if service != nil { match := false - _, services, err := state.NodeServices(member.Name) + _, services, err := state.NodeServices(nil, member.Name) if err != nil { return err } @@ -402,7 +402,7 @@ func (s *Server) handleAliveMember(member serf.Member) error { } // Check if the serfCheck is in the passing state - _, checks, err := state.NodeChecks(member.Name) + _, checks, err := state.NodeChecks(nil, member.Name) if err != nil { return err } @@ -446,7 +446,7 @@ func (s *Server) handleFailedMember(member serf.Member) error { } if node != nil && node.Address == member.Addr.String() { // Check if the serfCheck is in the critical state - _, checks, err := state.NodeChecks(member.Name) + _, checks, err := state.NodeChecks(nil, member.Name) if err != nil { return err } diff --git a/consul/leader_test.go b/consul/leader_test.go index 6e0f6d5f33..c4312dfbad 100644 --- a/consul/leader_test.go +++ b/consul/leader_test.go @@ -44,7 +44,7 @@ func TestLeader_RegisterMember(t *testing.T) { }) // Should have a check - _, checks, err := state.NodeChecks(c1.config.NodeName) + _, checks, err := state.NodeChecks(nil, c1.config.NodeName) if err != nil { t.Fatalf("err: %v", err) } @@ -71,7 +71,7 @@ func 
TestLeader_RegisterMember(t *testing.T) { } // Service should be registered - _, services, err := state.NodeServices(s1.config.NodeName) + _, services, err := state.NodeServices(nil, s1.config.NodeName) if err != nil { t.Fatalf("err: %v", err) } @@ -114,7 +114,7 @@ func TestLeader_FailedMember(t *testing.T) { }) // Should have a check - _, checks, err := state.NodeChecks(c1.config.NodeName) + _, checks, err := state.NodeChecks(nil, c1.config.NodeName) if err != nil { t.Fatalf("err: %v", err) } @@ -129,7 +129,7 @@ func TestLeader_FailedMember(t *testing.T) { } testutil.WaitForResult(func() (bool, error) { - _, checks, err = state.NodeChecks(c1.config.NodeName) + _, checks, err = state.NodeChecks(nil, c1.config.NodeName) if err != nil { t.Fatalf("err: %v", err) } diff --git a/consul/prepared_query_endpoint.go b/consul/prepared_query_endpoint.go index d53d8fc021..84ad808148 100644 --- a/consul/prepared_query_endpoint.go +++ b/consul/prepared_query_endpoint.go @@ -8,7 +8,9 @@ import ( "time" "github.com/armon/go-metrics" + "github.com/hashicorp/consul/consul/state" "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/go-memdb" "github.com/hashicorp/go-uuid" ) @@ -45,7 +47,7 @@ func (p *PreparedQuery) Apply(args *structs.PreparedQueryRequest, reply *string) if args.Query.ID, err = uuid.GenerateUUID(); err != nil { return fmt.Errorf("UUID generation for prepared query failed: %v", err) } - _, query, err := state.PreparedQueryGet(args.Query.ID) + _, query, err := state.PreparedQueryGet(nil, args.Query.ID) if err != nil { return fmt.Errorf("Prepared query lookup failed: %v", err) } @@ -77,7 +79,7 @@ func (p *PreparedQuery) Apply(args *structs.PreparedQueryRequest, reply *string) // access to whatever they are changing, if prefix ACLs apply to it. if args.Op != structs.PreparedQueryCreate { state := p.srv.fsm.State() - _, query, err := state.PreparedQueryGet(args.Query.ID) + _, query, err := state.PreparedQueryGet(nil, args.Query.ID) if err != nil { return fmt.Errorf("Prepared Query lookup failed: %v", err) } @@ -216,14 +218,11 @@ func (p *PreparedQuery) Get(args *structs.PreparedQuerySpecificRequest, return err } - // Get the requested query. - state := p.srv.fsm.State() - return p.srv.blockingRPC( + return p.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, - state.GetQueryWatch("PreparedQueryGet"), - func() error { - index, query, err := state.PreparedQueryGet(args.QueryID) + func(ws memdb.WatchSet, state *state.StateStore) error { + index, query, err := state.PreparedQueryGet(ws, args.QueryID) if err != nil { return err } @@ -263,14 +262,11 @@ func (p *PreparedQuery) List(args *structs.DCSpecificRequest, reply *structs.Ind return err } - // Get the list of queries. 
- state := p.srv.fsm.State() - return p.srv.blockingRPC( + return p.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, - state.GetQueryWatch("PreparedQueryList"), - func() error { - index, queries, err := state.PreparedQueryList() + func(ws memdb.WatchSet, state *state.StateStore) error { + index, queries, err := state.PreparedQueryList(ws) if err != nil { return err } @@ -489,7 +485,7 @@ func (p *PreparedQuery) ExecuteRemote(args *structs.PreparedQueryExecuteRemoteRe func (p *PreparedQuery) execute(query *structs.PreparedQuery, reply *structs.PreparedQueryExecuteResponse) error { state := p.srv.fsm.State() - _, nodes, err := state.CheckServiceNodes(query.Service.Service) + _, nodes, err := state.CheckServiceNodes(nil, query.Service.Service) if err != nil { return err } diff --git a/consul/rpc.go b/consul/rpc.go index 315b7f1d27..cf7b558902 100644 --- a/consul/rpc.go +++ b/consul/rpc.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/consul/consul/state" "github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/consul/lib" + "github.com/hashicorp/go-memdb" "github.com/hashicorp/memberlist" "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/yamux" @@ -352,23 +353,23 @@ func (s *Server) raftApply(t structs.MessageType, msg interface{}) (interface{}, return future.Response(), nil } -// blockingRPC is used for queries that need to wait for a minimum index. This -// is used to block and wait for changes. -func (s *Server) blockingRPC(queryOpts *structs.QueryOptions, queryMeta *structs.QueryMeta, - watch state.Watch, run func() error) error { +// queryFn is used to perform a query operation. If a re-query is needed, the +// passed-in watch set will be used to block for changes. The passed-in state +// store should be used (vs. calling fsm.State()) since the given state store +// will be correctly watched for changes if the state store is restored from +// a snapshot. +type queryFn func(memdb.WatchSet, *state.StateStore) error + +// blockingQuery is used to process a potentially blocking query operation. +func (s *Server) blockingQuery(queryOpts *structs.QueryOptions, queryMeta *structs.QueryMeta, + fn queryFn) error { var timeout *time.Timer - var notifyCh chan struct{} // Fast path right to the non-blocking query. if queryOpts.MinQueryIndex == 0 { goto RUN_QUERY } - // Make sure a watch was given if we were asked to block. - if watch == nil { - panic("no watch given for blocking query") - } - // Restrict the max query time, and ensure there is always one. if queryOpts.MaxQueryTime > maxQueryTime { queryOpts.MaxQueryTime = maxQueryTime @@ -381,20 +382,7 @@ func (s *Server) blockingRPC(queryOpts *structs.QueryOptions, queryMeta *structs // Setup a query timeout. timeout = time.NewTimer(queryOpts.MaxQueryTime) - - // Setup the notify channel. - notifyCh = make(chan struct{}, 1) - - // Ensure we tear down any watches on return. - defer func() { - timeout.Stop() - watch.Clear(notifyCh) - }() - -REGISTER_NOTIFY: - // Register the notification channel. This may be done multiple times if - // we haven't reached the target wait index. - watch.Wait(notifyCh) + defer timeout.Stop() RUN_QUERY: // Update the query metadata. @@ -409,14 +397,27 @@ RUN_QUERY: // Run the query. metrics.IncrCounter([]string{"consul", "rpc", "query"}, 1) - err := run() - // Check for minimum query time. + // Operate on a consistent set of state. This makes sure that the + // abandon channel goes with the state that the caller is using to + // build watches. 
+ state := s.fsm.State() + + // We can skip all watch tracking if this isn't a blocking query. + var ws memdb.WatchSet + if queryOpts.MinQueryIndex > 0 { + ws = memdb.NewWatchSet() + + // This channel will be closed if a snapshot is restored and the + // whole state store is abandoned. + ws.Add(state.AbandonCh()) + } + + // Block up to the timeout if we didn't see anything fresh. + err := fn(ws, state) if err == nil && queryMeta.Index > 0 && queryMeta.Index <= queryOpts.MinQueryIndex { - select { - case <-notifyCh: - goto REGISTER_NOTIFY - case <-timeout.C: + if expired := ws.Watch(timeout.C); !expired { + goto RUN_QUERY } } return err diff --git a/consul/session_endpoint.go b/consul/session_endpoint.go index 5e6c4ab232..557535c56f 100644 --- a/consul/session_endpoint.go +++ b/consul/session_endpoint.go @@ -5,7 +5,9 @@ import ( "time" "github.com/armon/go-metrics" + "github.com/hashicorp/consul/consul/state" "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/go-memdb" "github.com/hashicorp/go-uuid" ) @@ -39,7 +41,7 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error { switch args.Op { case structs.SessionDestroy: state := s.srv.fsm.State() - _, existing, err := state.SessionGet(args.Session.ID) + _, existing, err := state.SessionGet(nil, args.Session.ID) if err != nil { return fmt.Errorf("Unknown session %q", args.Session.ID) } @@ -94,7 +96,7 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error { s.srv.logger.Printf("[ERR] consul.session: UUID generation failed: %v", err) return err } - _, sess, err := state.SessionGet(args.Session.ID) + _, sess, err := state.SessionGet(nil, args.Session.ID) if err != nil { s.srv.logger.Printf("[ERR] consul.session: Session lookup failed: %v", err) return err @@ -139,14 +141,11 @@ func (s *Session) Get(args *structs.SessionSpecificRequest, return err } - // Get the local state - state := s.srv.fsm.State() - return s.srv.blockingRPC( + return s.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, - state.GetQueryWatch("SessionGet"), - func() error { - index, session, err := state.SessionGet(args.Session) + func(ws memdb.WatchSet, state *state.StateStore) error { + index, session, err := state.SessionGet(ws, args.Session) if err != nil { return err } @@ -171,14 +170,11 @@ func (s *Session) List(args *structs.DCSpecificRequest, return err } - // Get the local state - state := s.srv.fsm.State() - return s.srv.blockingRPC( + return s.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, - state.GetQueryWatch("SessionList"), - func() error { - index, sessions, err := state.SessionList() + func(ws memdb.WatchSet, state *state.StateStore) error { + index, sessions, err := state.SessionList(ws) if err != nil { return err } @@ -198,14 +194,11 @@ func (s *Session) NodeSessions(args *structs.NodeSpecificRequest, return err } - // Get the local state - state := s.srv.fsm.State() - return s.srv.blockingRPC( + return s.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, - state.GetQueryWatch("NodeSessions"), - func() error { - index, sessions, err := state.NodeSessions(args.Node) + func(ws memdb.WatchSet, state *state.StateStore) error { + index, sessions, err := state.NodeSessions(ws, args.Node) if err != nil { return err } @@ -228,7 +221,7 @@ func (s *Session) Renew(args *structs.SessionSpecificRequest, // Get the session, from local state. 
state := s.srv.fsm.State() - index, session, err := state.SessionGet(args.Session) + index, session, err := state.SessionGet(nil, args.Session) if err != nil { return err } diff --git a/consul/session_endpoint_test.go b/consul/session_endpoint_test.go index 275a53aadb..326c36c1a5 100644 --- a/consul/session_endpoint_test.go +++ b/consul/session_endpoint_test.go @@ -40,7 +40,7 @@ func TestSession_Apply(t *testing.T) { // Verify state := s1.fsm.State() - _, s, err := state.SessionGet(out) + _, s, err := state.SessionGet(nil, out) if err != nil { t.Fatalf("err: %v", err) } @@ -62,7 +62,7 @@ func TestSession_Apply(t *testing.T) { } // Verify - _, s, err = state.SessionGet(id) + _, s, err = state.SessionGet(nil, id) if err != nil { t.Fatalf("err: %v", err) } @@ -100,7 +100,7 @@ func TestSession_DeleteApply(t *testing.T) { // Verify state := s1.fsm.State() - _, s, err := state.SessionGet(out) + _, s, err := state.SessionGet(nil, out) if err != nil { t.Fatalf("err: %v", err) } @@ -125,7 +125,7 @@ func TestSession_DeleteApply(t *testing.T) { } // Verify - _, s, err = state.SessionGet(id) + _, s, err = state.SessionGet(nil, id) if err != nil { t.Fatalf("err: %v", err) } diff --git a/consul/session_ttl.go b/consul/session_ttl.go index c8ce91283c..c4398fe63f 100644 --- a/consul/session_ttl.go +++ b/consul/session_ttl.go @@ -22,7 +22,7 @@ const ( func (s *Server) initializeSessionTimers() error { // Scan all sessions and reset their timer state := s.fsm.State() - _, sessions, err := state.SessionList() + _, sessions, err := state.SessionList(nil) if err != nil { return err } @@ -41,7 +41,7 @@ func (s *Server) resetSessionTimer(id string, session *structs.Session) error { // Fault the session in if not given if session == nil { state := s.fsm.State() - _, s, err := state.SessionGet(id) + _, s, err := state.SessionGet(nil, id) if err != nil { return err } diff --git a/consul/session_ttl_test.go b/consul/session_ttl_test.go index 7ef647beb4..880211899c 100644 --- a/consul/session_ttl_test.go +++ b/consul/session_ttl_test.go @@ -225,7 +225,7 @@ func TestInvalidateSession(t *testing.T) { s1.invalidateSession(session.ID) // Check it is gone - _, sess, err := state.SessionGet(session.ID) + _, sess, err := state.SessionGet(nil, session.ID) if err != nil { t.Fatalf("err: %v", err) } diff --git a/consul/state/acl.go b/consul/state/acl.go index 3ce94e9a13..c99600fe85 100644 --- a/consul/state/acl.go +++ b/consul/state/acl.go @@ -26,7 +26,6 @@ func (s *StateRestore) ACL(acl *structs.ACL) error { return fmt.Errorf("failed updating index: %s", err) } - s.watches.Arm("acls") return nil } @@ -75,23 +74,24 @@ func (s *StateStore) aclSetTxn(tx *memdb.Txn, idx uint64, acl *structs.ACL) erro return fmt.Errorf("failed updating index: %s", err) } - tx.Defer(func() { s.tableWatches["acls"].Notify() }) return nil } // ACLGet is used to look up an existing ACL by ID. -func (s *StateStore) ACLGet(aclID string) (uint64, *structs.ACL, error) { +func (s *StateStore) ACLGet(ws memdb.WatchSet, aclID string) (uint64, *structs.ACL, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("ACLGet")...) 
+ idx := maxIndexTxn(tx, "acls") // Query for the existing ACL - acl, err := tx.First("acls", "id", aclID) + watchCh, acl, err := tx.FirstWatch("acls", "id", aclID) if err != nil { return 0, nil, fmt.Errorf("failed acl lookup: %s", err) } + ws.Add(watchCh) + if acl != nil { return idx, acl.(*structs.ACL), nil } @@ -99,15 +99,15 @@ func (s *StateStore) ACLGet(aclID string) (uint64, *structs.ACL, error) { } // ACLList is used to list out all of the ACLs in the state store. -func (s *StateStore) ACLList() (uint64, structs.ACLs, error) { +func (s *StateStore) ACLList(ws memdb.WatchSet) (uint64, structs.ACLs, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("ACLList")...) + idx := maxIndexTxn(tx, "acls") // Return the ACLs. - acls, err := s.aclListTxn(tx) + acls, err := s.aclListTxn(tx, ws) if err != nil { return 0, nil, fmt.Errorf("failed acl lookup: %s", err) } @@ -116,16 +116,17 @@ func (s *StateStore) ACLList() (uint64, structs.ACLs, error) { // aclListTxn is used to list out all of the ACLs in the state store. This is a // function vs. a method so it can be called from the snapshotter. -func (s *StateStore) aclListTxn(tx *memdb.Txn) (structs.ACLs, error) { +func (s *StateStore) aclListTxn(tx *memdb.Txn, ws memdb.WatchSet) (structs.ACLs, error) { // Query all of the ACLs in the state store - acls, err := tx.Get("acls", "id") + iter, err := tx.Get("acls", "id") if err != nil { return nil, fmt.Errorf("failed acl lookup: %s", err) } + ws.Add(iter.WatchCh()) // Go over all of the ACLs and build the response var result structs.ACLs - for acl := acls.Next(); acl != nil; acl = acls.Next() { + for acl := iter.Next(); acl != nil; acl = iter.Next() { a := acl.(*structs.ACL) result = append(result, a) } @@ -167,6 +168,5 @@ func (s *StateStore) aclDeleteTxn(tx *memdb.Txn, idx uint64, aclID string) error return fmt.Errorf("failed updating index: %s", err) } - tx.Defer(func() { s.tableWatches["acls"].Notify() }) return nil } diff --git a/consul/state/acl_test.go b/consul/state/acl_test.go index 94bab3fedd..2d7bb21396 100644 --- a/consul/state/acl_test.go +++ b/consul/state/acl_test.go @@ -5,13 +5,15 @@ import ( "testing" "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/go-memdb" ) func TestStateStore_ACLSet_ACLGet(t *testing.T) { s := testStateStore(t) // Querying ACLs with no results returns nil - idx, res, err := s.ACLGet("nope") + ws := memdb.NewWatchSet() + idx, res, err := s.ACLGet(ws, "nope") if idx != 0 || res != nil || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) } @@ -20,6 +22,9 @@ func TestStateStore_ACLSet_ACLGet(t *testing.T) { if err := s.ACLSet(1, &structs.ACL{}); err == nil { t.Fatalf("expected %#v, got: %#v", ErrMissingACLID, err) } + if watchFired(ws) { + t.Fatalf("bad") + } // Index is not updated if nothing is saved if idx := s.maxIndex("acls"); idx != 0 { @@ -36,6 +41,9 @@ func TestStateStore_ACLSet_ACLGet(t *testing.T) { if err := s.ACLSet(1, acl); err != nil { t.Fatalf("err: %s", err) } + if !watchFired(ws) { + t.Fatalf("bad") + } // Check that the index was updated if idx := s.maxIndex("acls"); idx != 1 { @@ -43,7 +51,8 @@ func TestStateStore_ACLSet_ACLGet(t *testing.T) { } // Retrieve the ACL again - idx, result, err := s.ACLGet("acl1") + ws = memdb.NewWatchSet() + idx, result, err := s.ACLGet(ws, "acl1") if err != nil { t.Fatalf("err: %s", err) } @@ -76,6 +85,9 @@ func TestStateStore_ACLSet_ACLGet(t *testing.T) { if err := s.ACLSet(2, acl); err != nil { 
t.Fatalf("err: %s", err) } + if !watchFired(ws) { + t.Fatalf("bad") + } // Index was updated if idx := s.maxIndex("acls"); idx != 2 { @@ -102,7 +114,8 @@ func TestStateStore_ACLList(t *testing.T) { s := testStateStore(t) // Listing when no ACLs exist returns nil - idx, res, err := s.ACLList() + ws := memdb.NewWatchSet() + idx, res, err := s.ACLList(ws) if idx != 0 || res != nil || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) } @@ -133,9 +146,12 @@ func TestStateStore_ACLList(t *testing.T) { t.Fatalf("err: %s", err) } } + if !watchFired(ws) { + t.Fatalf("bad") + } // Query the ACLs - idx, res, err = s.ACLList() + idx, res, err = s.ACLList(nil) if err != nil { t.Fatalf("err: %s", err) } @@ -255,7 +271,7 @@ func TestStateStore_ACL_Snapshot_Restore(t *testing.T) { restore.Commit() // Read the restored ACLs back out and verify that they match. - idx, res, err := s.ACLList() + idx, res, err := s.ACLList(nil) if err != nil { t.Fatalf("err: %s", err) } @@ -272,27 +288,3 @@ func TestStateStore_ACL_Snapshot_Restore(t *testing.T) { } }() } - -func TestStateStore_ACL_Watches(t *testing.T) { - s := testStateStore(t) - - // Call functions that update the acls table and make sure a watch fires - // each time. - verifyWatch(t, s.getTableWatch("acls"), func() { - if err := s.ACLSet(1, &structs.ACL{ID: "acl1"}); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("acls"), func() { - if err := s.ACLDelete(2, "acl1"); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("acls"), func() { - restore := s.Restore() - if err := restore.ACL(&structs.ACL{ID: "acl1"}); err != nil { - t.Fatalf("err: %s", err) - } - restore.Commit() - }) -} diff --git a/consul/state/catalog.go b/consul/state/catalog.go index a6c4448e33..f10a351898 100644 --- a/consul/state/catalog.go +++ b/consul/state/catalog.go @@ -42,7 +42,7 @@ func (s *StateSnapshot) Checks(node string) (memdb.ResultIterator, error) { // performed within a single transaction to avoid race conditions on state // updates. func (s *StateRestore) Registration(idx uint64, req *structs.RegisterRequest) error { - if err := s.store.ensureRegistrationTxn(s.tx, idx, s.watches, req); err != nil { + if err := s.store.ensureRegistrationTxn(s.tx, idx, req); err != nil { return err } return nil @@ -55,12 +55,10 @@ func (s *StateStore) EnsureRegistration(idx uint64, req *structs.RegisterRequest tx := s.db.Txn(true) defer tx.Abort() - watches := NewDumbWatchManager(s.tableWatches) - if err := s.ensureRegistrationTxn(tx, idx, watches, req); err != nil { + if err := s.ensureRegistrationTxn(tx, idx, req); err != nil { return err } - tx.Defer(func() { watches.Notify() }) tx.Commit() return nil } @@ -68,9 +66,8 @@ func (s *StateStore) EnsureRegistration(idx uint64, req *structs.RegisterRequest // ensureRegistrationTxn is used to make sure a node, service, and check // registration is performed within a single transaction to avoid race // conditions on state updates. -func (s *StateStore) ensureRegistrationTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, - req *structs.RegisterRequest) error { - // Add the node. +func (s *StateStore) ensureRegistrationTxn(tx *memdb.Txn, idx uint64, req *structs.RegisterRequest) error { + // Create a node structure. 
node := &structs.Node{ ID: req.ID, Node: req.Node, @@ -78,14 +75,37 @@ func (s *StateStore) ensureRegistrationTxn(tx *memdb.Txn, idx uint64, watches *D TaggedAddresses: req.TaggedAddresses, Meta: req.NodeMeta, } - if err := s.ensureNodeTxn(tx, idx, watches, node); err != nil { - return fmt.Errorf("failed inserting node: %s", err) + + // Since this gets called for all node operations (service and check + // updates), and the node itself rarely changes after it registers for + // the first time, check whether we actually need to modify the node at + // all. This avoids watch churn, useless writes, and needless modify + // index bumps on the node. + { + existing, err := tx.First("nodes", "id", node.Node) + if err != nil { + return fmt.Errorf("node lookup failed: %s", err) + } + if existing == nil || req.ChangesNode(existing.(*structs.Node)) { + if err := s.ensureNodeTxn(tx, idx, node); err != nil { + return fmt.Errorf("failed inserting node: %s", err) + } + } } - // Add the service, if any. + // Add the service, if any. We perform a similar check as we do for the + // node info above to make sure we actually need to update the service + // definition in order to prevent useless churn if nothing has changed. if req.Service != nil { - if err := s.ensureServiceTxn(tx, idx, watches, req.Node, req.Service); err != nil { - return fmt.Errorf("failed inserting service: %s", err) + existing, err := tx.First("services", "id", req.Node, req.Service.ID) + if err != nil { + return fmt.Errorf("failed service lookup: %s", err) + } + if existing == nil || !(existing.(*structs.ServiceNode).ToNodeService()).IsSame(req.Service) { + if err := s.ensureServiceTxn(tx, idx, req.Node, req.Service); err != nil { + return fmt.Errorf("failed inserting service: %s", err) + + } } } @@ -97,12 +117,12 @@ func (s *StateStore) ensureRegistrationTxn(tx *memdb.Txn, idx uint64, watches *D // Add the checks, if any. if req.Check != nil { - if err := s.ensureCheckTxn(tx, idx, watches, req.Check); err != nil { + if err := s.ensureCheckTxn(tx, idx, req.Check); err != nil { return fmt.Errorf("failed inserting check: %s", err) } } for _, check := range req.Checks { - if err := s.ensureCheckTxn(tx, idx, watches, check); err != nil { + if err := s.ensureCheckTxn(tx, idx, check); err != nil { return fmt.Errorf("failed inserting check: %s", err) } } @@ -116,12 +136,10 @@ func (s *StateStore) EnsureNode(idx uint64, node *structs.Node) error { defer tx.Abort() // Call the node upsert - watches := NewDumbWatchManager(s.tableWatches) - if err := s.ensureNodeTxn(tx, idx, watches, node); err != nil { + if err := s.ensureNodeTxn(tx, idx, node); err != nil { return err } - tx.Defer(func() { watches.Notify() }) tx.Commit() return nil } @@ -129,8 +147,7 @@ func (s *StateStore) EnsureNode(idx uint64, node *structs.Node) error { // ensureNodeTxn is the inner function called to actually create a node // registration or modify an existing one in the state store. It allows // passing in a memdb transaction so it may be part of a larger txn.
-func (s *StateStore) ensureNodeTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, - node *structs.Node) error { +func (s *StateStore) ensureNodeTxn(tx *memdb.Txn, idx uint64, node *structs.Node) error { // Check for an existing node existing, err := tx.First("nodes", "id", node.Node) if err != nil { @@ -154,7 +171,6 @@ func (s *StateStore) ensureNodeTxn(tx *memdb.Txn, idx uint64, watches *DumbWatch return fmt.Errorf("failed updating index: %s", err) } - watches.Arm("nodes") return nil } @@ -164,7 +180,7 @@ func (s *StateStore) GetNode(id string) (uint64, *structs.Node, error) { defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("GetNode")...) + idx := maxIndexTxn(tx, "nodes") // Retrieve the node from the state store node, err := tx.First("nodes", "id", id) @@ -178,18 +194,19 @@ func (s *StateStore) GetNode(id string) (uint64, *structs.Node, error) { } // Nodes is used to return all of the known nodes. -func (s *StateStore) Nodes() (uint64, structs.Nodes, error) { +func (s *StateStore) Nodes(ws memdb.WatchSet) (uint64, structs.Nodes, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("Nodes")...) + idx := maxIndexTxn(tx, "nodes") // Retrieve all of the nodes nodes, err := tx.Get("nodes", "id") if err != nil { return 0, nil, fmt.Errorf("failed nodes lookup: %s", err) } + ws.Add(nodes.WatchCh()) // Create and return the nodes list. var results structs.Nodes @@ -200,12 +217,12 @@ func (s *StateStore) Nodes() (uint64, structs.Nodes, error) { } // NodesByMeta is used to return all nodes with the given metadata key/value pairs. -func (s *StateStore) NodesByMeta(filters map[string]string) (uint64, structs.Nodes, error) { +func (s *StateStore) NodesByMeta(ws memdb.WatchSet, filters map[string]string) (uint64, structs.Nodes, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("Nodes")...) + idx := maxIndexTxn(tx, "nodes") // Retrieve all of the nodes var args []interface{} @@ -217,6 +234,7 @@ func (s *StateStore) NodesByMeta(filters map[string]string) (uint64, structs.Nod if err != nil { return 0, nil, fmt.Errorf("failed nodes lookup: %s", err) } + ws.Add(nodes.WatchCh()) // Create and return the nodes list. var results structs.Nodes @@ -255,10 +273,6 @@ func (s *StateStore) deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeName string) e return nil } - // Use a watch manager since the inner functions can perform multiple - // ops per table. - watches := NewDumbWatchManager(s.tableWatches) - // Delete all services associated with the node and update the service index. services, err := tx.Get("services", "node", nodeName) if err != nil { @@ -271,7 +285,7 @@ func (s *StateStore) deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeName string) e // Do the delete in a separate loop so we don't trash the iterator. for _, sid := range sids { - if err := s.deleteServiceTxn(tx, idx, watches, nodeName, sid); err != nil { + if err := s.deleteServiceTxn(tx, idx, nodeName, sid); err != nil { return err } } @@ -289,7 +303,7 @@ func (s *StateStore) deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeName string) e // Do the delete in a separate loop so we don't trash the iterator. 
for _, cid := range cids { - if err := s.deleteCheckTxn(tx, idx, watches, nodeName, cid); err != nil { + if err := s.deleteCheckTxn(tx, idx, nodeName, cid); err != nil { return err } } @@ -306,7 +320,6 @@ func (s *StateStore) deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeName string) e if err := tx.Insert("index", &IndexEntry{"coordinates", idx}); err != nil { return fmt.Errorf("failed updating index: %s", err) } - watches.Arm("coordinates") } // Delete the node and update the index. @@ -329,13 +342,11 @@ func (s *StateStore) deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeName string) e // Do the delete in a separate loop so we don't trash the iterator. for _, id := range ids { - if err := s.deleteSessionTxn(tx, idx, watches, id); err != nil { + if err := s.deleteSessionTxn(tx, idx, id); err != nil { return fmt.Errorf("failed session delete: %s", err) } } - watches.Arm("nodes") - tx.Defer(func() { watches.Notify() }) return nil } @@ -345,20 +356,17 @@ func (s *StateStore) EnsureService(idx uint64, node string, svc *structs.NodeSer defer tx.Abort() // Call the service registration upsert - watches := NewDumbWatchManager(s.tableWatches) - if err := s.ensureServiceTxn(tx, idx, watches, node, svc); err != nil { + if err := s.ensureServiceTxn(tx, idx, node, svc); err != nil { return err } - tx.Defer(func() { watches.Notify() }) tx.Commit() return nil } // ensureServiceTxn is used to upsert a service registration within an // existing memdb transaction. -func (s *StateStore) ensureServiceTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, - node string, svc *structs.NodeService) error { +func (s *StateStore) ensureServiceTxn(tx *memdb.Txn, idx uint64, node string, svc *structs.NodeService) error { // Check for existing service existing, err := tx.First("services", "id", node, svc.ID) if err != nil { @@ -394,23 +402,23 @@ func (s *StateStore) ensureServiceTxn(tx *memdb.Txn, idx uint64, watches *DumbWa return fmt.Errorf("failed updating index: %s", err) } - watches.Arm("services") return nil } // Services returns all services along with a list of associated tags. -func (s *StateStore) Services() (uint64, structs.Services, error) { +func (s *StateStore) Services(ws memdb.WatchSet) (uint64, structs.Services, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("Services")...) + idx := maxIndexTxn(tx, "services") // List all the services. services, err := tx.Get("services", "id") if err != nil { return 0, nil, fmt.Errorf("failed querying services: %s", err) } + ws.Add(services.WatchCh()) // Rip through the services and enumerate them and their unique set of // tags. @@ -439,12 +447,12 @@ func (s *StateStore) Services() (uint64, structs.Services, error) { } // ServicesByNodeMeta returns all services, filtered by the given node metadata. -func (s *StateStore) ServicesByNodeMeta(filters map[string]string) (uint64, structs.Services, error) { +func (s *StateStore) ServicesByNodeMeta(ws memdb.WatchSet, filters map[string]string) (uint64, structs.Services, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("ServiceNodes")...) 
+ idx := maxIndexTxn(tx, "services", "nodes") // Retrieve all of the nodes with the meta k/v pair var args []interface{} @@ -456,6 +464,15 @@ func (s *StateStore) ServicesByNodeMeta(filters map[string]string) (uint64, stru if err != nil { return 0, nil, fmt.Errorf("failed nodes lookup: %s", err) } + ws.Add(nodes.WatchCh()) + + // We don't want to track an unlimited number of services, so we pull a + // top-level watch to use as a fallback. + allServices, err := tx.Get("services", "id") + if err != nil { + return 0, nil, fmt.Errorf("failed services lookup: %s", err) + } + allServicesCh := allServices.WatchCh() // Populate the services map unique := make(map[string]map[string]struct{}) @@ -464,11 +481,13 @@ func (s *StateStore) ServicesByNodeMeta(filters map[string]string) (uint64, stru if len(filters) > 1 && !structs.SatisfiesMetaFilters(n.Meta, filters) { continue } + // List all the services on the node services, err := tx.Get("services", "node", n.Node) if err != nil { return 0, nil, fmt.Errorf("failed querying services: %s", err) } + ws.AddWithLimit(watchLimit, services.WatchCh(), allServicesCh) // Rip through the services and enumerate them and their unique set of // tags. @@ -497,25 +516,27 @@ func (s *StateStore) ServicesByNodeMeta(filters map[string]string) (uint64, stru } // ServiceNodes returns the nodes associated with a given service name. -func (s *StateStore) ServiceNodes(serviceName string) (uint64, structs.ServiceNodes, error) { +func (s *StateStore) ServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.ServiceNodes, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("ServiceNodes")...) + idx := maxIndexTxn(tx, "nodes", "services") // List all the services. services, err := tx.Get("services", "service", serviceName) if err != nil { return 0, nil, fmt.Errorf("failed service lookup: %s", err) } + ws.Add(services.WatchCh()) + var results structs.ServiceNodes for service := services.Next(); service != nil; service = services.Next() { results = append(results, service.(*structs.ServiceNode)) } - // Fill in the address details. - results, err = s.parseServiceNodes(tx, results) + // Fill in the node details. + results, err = s.parseServiceNodes(tx, ws, results) if err != nil { return 0, nil, fmt.Errorf("failed parsing service nodes: %s", err) } @@ -524,18 +545,19 @@ func (s *StateStore) ServiceNodes(serviceName string) (uint64, structs.ServiceNo // ServiceTagNodes returns the nodes associated with a given service, filtering // out services that don't contain the given tag. -func (s *StateStore) ServiceTagNodes(service, tag string) (uint64, structs.ServiceNodes, error) { +func (s *StateStore) ServiceTagNodes(ws memdb.WatchSet, service string, tag string) (uint64, structs.ServiceNodes, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("ServiceNodes")...) + idx := maxIndexTxn(tx, "nodes", "services") // List all the services. services, err := tx.Get("services", "service", service) if err != nil { return 0, nil, fmt.Errorf("failed service lookup: %s", err) } + ws.Add(services.WatchCh()) // Gather all the services and apply the tag filter. var results structs.ServiceNodes @@ -546,8 +568,8 @@ func (s *StateStore) ServiceTagNodes(service, tag string) (uint64, structs.Servi } } - // Fill in the address details. - results, err = s.parseServiceNodes(tx, results) + // Fill in the node details. 
+ results, err = s.parseServiceNodes(tx, ws, results) if err != nil { return 0, nil, fmt.Errorf("failed parsing service nodes: %s", err) } @@ -572,7 +594,16 @@ func serviceTagFilter(sn *structs.ServiceNode, tag string) bool { // parseServiceNodes iterates over a services query and fills in the node details, // returning a ServiceNodes slice. -func (s *StateStore) parseServiceNodes(tx *memdb.Txn, services structs.ServiceNodes) (structs.ServiceNodes, error) { +func (s *StateStore) parseServiceNodes(tx *memdb.Txn, ws memdb.WatchSet, services structs.ServiceNodes) (structs.ServiceNodes, error) { + // We don't want to track an unlimited number of nodes, so we pull a + // top-level watch to use as a fallback. + allNodes, err := tx.Get("nodes", "id") + if err != nil { + return nil, fmt.Errorf("failed nodes lookup: %s", err) + } + allNodesCh := allNodes.WatchCh() + + // Fill in the node data for each service instance. var results structs.ServiceNodes for _, sn := range services { // Note that we have to clone here because we don't want to @@ -581,10 +612,11 @@ func (s *StateStore) parseServiceNodes(tx *memdb.Txn, services structs.ServiceNo s := sn.PartialClone() // Grab the corresponding node record. - n, err := tx.First("nodes", "id", sn.Node) + watchCh, n, err := tx.FirstWatch("nodes", "id", sn.Node) if err != nil { return nil, fmt.Errorf("failed node lookup: %s", err) } + ws.AddWithLimit(watchLimit, watchCh, allNodesCh) // Populate the node-related fields. The tagged addresses may be // used by agents to perform address translation if they are @@ -607,7 +639,7 @@ func (s *StateStore) NodeService(nodeName string, serviceID string) (uint64, *st defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("NodeService")...) + idx := maxIndexTxn(tx, "services") // Query the service service, err := tx.First("services", "id", nodeName, serviceID) @@ -623,18 +655,19 @@ func (s *StateStore) NodeService(nodeName string, serviceID string) (uint64, *st } // NodeServices is used to query service registrations by node ID. -func (s *StateStore) NodeServices(nodeName string) (uint64, *structs.NodeServices, error) { +func (s *StateStore) NodeServices(ws memdb.WatchSet, nodeName string) (uint64, *structs.NodeServices, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("NodeServices")...) + idx := maxIndexTxn(tx, "nodes", "services") // Query the node - n, err := tx.First("nodes", "id", nodeName) + watchCh, n, err := tx.FirstWatch("nodes", "id", nodeName) if err != nil { return 0, nil, fmt.Errorf("node lookup failed: %s", err) } + ws.Add(watchCh) if n == nil { return 0, nil, nil } @@ -645,6 +678,7 @@ func (s *StateStore) NodeServices(nodeName string) (uint64, *structs.NodeService if err != nil { return 0, nil, fmt.Errorf("failed querying services for node %q: %s", nodeName, err) } + ws.Add(services.WatchCh()) // Initialize the node services struct ns := &structs.NodeServices{ @@ -667,19 +701,17 @@ func (s *StateStore) DeleteService(idx uint64, nodeName, serviceID string) error defer tx.Abort() // Call the service deletion - watches := NewDumbWatchManager(s.tableWatches) - if err := s.deleteServiceTxn(tx, idx, watches, nodeName, serviceID); err != nil { + if err := s.deleteServiceTxn(tx, idx, nodeName, serviceID); err != nil { return err } - tx.Defer(func() { watches.Notify() }) tx.Commit() return nil } // deleteServiceTxn is the inner method called to remove a service // registration within an existing transaction. 
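Two watch mechanics show up repeatedly above: tx.FirstWatch returns a channel tied to a single record, and ws.AddWithLimit registers such fine-grained channels while falling back to a coarse table-wide channel once too many are being tracked. A rough sketch of both calls together; watchLimit is assumed here since the real soft limit is defined once elsewhere in this package:

package catalogsketch

import (
	"fmt"

	"github.com/hashicorp/consul/consul/structs"
	"github.com/hashicorp/go-memdb"
)

// Assumed soft limit on per-record watch channels.
const watchLimit = 2048

func nodeWithWatch(tx *memdb.Txn, ws memdb.WatchSet, name string) (*structs.Node, error) {
	// Table-wide iterator whose channel serves as the coarse fallback.
	all, err := tx.Get("nodes", "id")
	if err != nil {
		return nil, fmt.Errorf("failed nodes lookup: %s", err)
	}

	// FirstWatch is like First but also hands back a channel that fires
	// when this particular record is created, modified, or deleted.
	watchCh, raw, err := tx.FirstWatch("nodes", "id", name)
	if err != nil {
		return nil, fmt.Errorf("failed node lookup: %s", err)
	}

	// Past watchLimit channels the set swaps in the fallback, so a query
	// touching thousands of nodes degrades to one coarse watch rather than
	// tracking a channel per row.
	ws.AddWithLimit(watchLimit, watchCh, all.WatchCh())
	if raw == nil {
		return nil, nil
	}
	return raw.(*structs.Node), nil
}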
-func (s *StateStore) deleteServiceTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, nodeName, serviceID string) error { +func (s *StateStore) deleteServiceTxn(tx *memdb.Txn, idx uint64, nodeName, serviceID string) error { // Look up the service. service, err := tx.First("services", "id", nodeName, serviceID) if err != nil { @@ -702,7 +734,7 @@ func (s *StateStore) deleteServiceTxn(tx *memdb.Txn, idx uint64, watches *DumbWa // Do the delete in a separate loop so we don't trash the iterator. for _, cid := range cids { - if err := s.deleteCheckTxn(tx, idx, watches, nodeName, cid); err != nil { + if err := s.deleteCheckTxn(tx, idx, nodeName, cid); err != nil { return err } } @@ -720,7 +752,6 @@ func (s *StateStore) deleteServiceTxn(tx *memdb.Txn, idx uint64, watches *DumbWa return fmt.Errorf("failed updating index: %s", err) } - watches.Arm("services") return nil } @@ -730,12 +761,10 @@ func (s *StateStore) EnsureCheck(idx uint64, hc *structs.HealthCheck) error { defer tx.Abort() // Call the check registration - watches := NewDumbWatchManager(s.tableWatches) - if err := s.ensureCheckTxn(tx, idx, watches, hc); err != nil { + if err := s.ensureCheckTxn(tx, idx, hc); err != nil { return err } - tx.Defer(func() { watches.Notify() }) tx.Commit() return nil } @@ -743,8 +772,7 @@ func (s *StateStore) EnsureCheck(idx uint64, hc *structs.HealthCheck) error { // ensureCheckTransaction is used as the inner method to handle inserting // a health check into the state store. It ensures safety against inserting // checks with no matching node or service. -func (s *StateStore) ensureCheckTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, - hc *structs.HealthCheck) error { +func (s *StateStore) ensureCheckTxn(tx *memdb.Txn, idx uint64, hc *structs.HealthCheck) error { // Check if we have an existing health check existing, err := tx.First("checks", "id", hc.Node, string(hc.CheckID)) if err != nil { @@ -803,13 +831,11 @@ func (s *StateStore) ensureCheckTxn(tx *memdb.Txn, idx uint64, watches *DumbWatc // Delete the session in a separate loop so we don't trash the // iterator. - watches := NewDumbWatchManager(s.tableWatches) for _, id := range ids { - if err := s.deleteSessionTxn(tx, idx, watches, id); err != nil { + if err := s.deleteSessionTxn(tx, idx, id); err != nil { return fmt.Errorf("failed deleting session: %s", err) } } - tx.Defer(func() { watches.Notify() }) } // Persist the check registration in the db. @@ -820,7 +846,6 @@ func (s *StateStore) ensureCheckTxn(tx *memdb.Txn, idx uint64, watches *DumbWatc return fmt.Errorf("failed updating index: %s", err) } - watches.Arm("checks") return nil } @@ -831,13 +856,14 @@ func (s *StateStore) NodeCheck(nodeName string, checkID types.CheckID) (uint64, defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("NodeCheck")...) + idx := maxIndexTxn(tx, "checks") // Return the check. check, err := tx.First("checks", "id", nodeName, string(checkID)) if err != nil { return 0, nil, fmt.Errorf("failed check lookup: %s", err) } + if check != nil { return idx, check.(*structs.HealthCheck), nil } else { @@ -847,115 +873,20 @@ func (s *StateStore) NodeCheck(nodeName string, checkID types.CheckID) (uint64, // NodeChecks is used to retrieve checks associated with the // given node from the state store. 
-func (s *StateStore) NodeChecks(nodeName string) (uint64, structs.HealthChecks, error) { +func (s *StateStore) NodeChecks(ws memdb.WatchSet, nodeName string) (uint64, structs.HealthChecks, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("NodeChecks")...) + idx := maxIndexTxn(tx, "checks") // Return the checks. - checks, err := tx.Get("checks", "node", nodeName) + iter, err := tx.Get("checks", "node", nodeName) if err != nil { return 0, nil, fmt.Errorf("failed check lookup: %s", err) } - return s.parseChecks(idx, checks) -} + ws.Add(iter.WatchCh()) -// ServiceChecks is used to get all checks associated with a -// given service ID. The query is performed against a service -// _name_ instead of a service ID. -func (s *StateStore) ServiceChecks(serviceName string) (uint64, structs.HealthChecks, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("ServiceChecks")...) - - // Return the checks. - checks, err := tx.Get("checks", "service", serviceName) - if err != nil { - return 0, nil, fmt.Errorf("failed check lookup: %s", err) - } - return s.parseChecks(idx, checks) -} - -// ServiceChecksByNodeMeta is used to get all checks associated with a -// given service ID, filtered by the given node metadata values. The query -// is performed against a service _name_ instead of a service ID. -func (s *StateStore) ServiceChecksByNodeMeta(serviceName string, filters map[string]string) (uint64, structs.HealthChecks, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("ServiceChecksByNodeMeta")...) - - // Return the checks. - checks, err := tx.Get("checks", "service", serviceName) - if err != nil { - return 0, nil, fmt.Errorf("failed check lookup: %s", err) - } - return s.parseChecksByNodeMeta(idx, checks, tx, filters) -} - -// ChecksInState is used to query the state store for all checks -// which are in the provided state. -func (s *StateStore) ChecksInState(state string) (uint64, structs.HealthChecks, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("ChecksInState")...) - - // Query all checks if HealthAny is passed - if state == structs.HealthAny { - checks, err := tx.Get("checks", "status") - if err != nil { - return 0, nil, fmt.Errorf("failed check lookup: %s", err) - } - return s.parseChecks(idx, checks) - } - - // Any other state we need to query for explicitly - checks, err := tx.Get("checks", "status", state) - if err != nil { - return 0, nil, fmt.Errorf("failed check lookup: %s", err) - } - return s.parseChecks(idx, checks) -} - -// ChecksInStateByNodeMeta is used to query the state store for all checks -// which are in the provided state, filtered by the given node metadata values. -func (s *StateStore) ChecksInStateByNodeMeta(state string, filters map[string]string) (uint64, structs.HealthChecks, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("ChecksInStateByNodeMeta")...) 
- - // Query all checks if HealthAny is passed - var checks memdb.ResultIterator - var err error - if state == structs.HealthAny { - checks, err = tx.Get("checks", "status") - if err != nil { - return 0, nil, fmt.Errorf("failed check lookup: %s", err) - } - } else { - // Any other state we need to query for explicitly - checks, err = tx.Get("checks", "status", state) - if err != nil { - return 0, nil, fmt.Errorf("failed check lookup: %s", err) - } - } - - return s.parseChecksByNodeMeta(idx, checks, tx, filters) -} - -// parseChecks is a helper function used to deduplicate some -// repetitive code for returning health checks. -func (s *StateStore) parseChecks(idx uint64, iter memdb.ResultIterator) (uint64, structs.HealthChecks, error) { - // Gather the health checks and return them properly type casted. var results structs.HealthChecks for check := iter.Next(); check != nil; check = iter.Next() { results = append(results, check.(*structs.HealthCheck)) @@ -963,20 +894,140 @@ func (s *StateStore) parseChecks(idx uint64, iter memdb.ResultIterator) (uint64, return idx, results, nil } +// ServiceChecks is used to get all checks associated with a +// given service ID. The query is performed against a service +// _name_ instead of a service ID. +func (s *StateStore) ServiceChecks(ws memdb.WatchSet, serviceName string) (uint64, structs.HealthChecks, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, "checks") + + // Return the checks. + iter, err := tx.Get("checks", "service", serviceName) + if err != nil { + return 0, nil, fmt.Errorf("failed check lookup: %s", err) + } + ws.Add(iter.WatchCh()) + + var results structs.HealthChecks + for check := iter.Next(); check != nil; check = iter.Next() { + results = append(results, check.(*structs.HealthCheck)) + } + return idx, results, nil +} + +// ServiceChecksByNodeMeta is used to get all checks associated with a +// given service ID, filtered by the given node metadata values. The query +// is performed against a service _name_ instead of a service ID. +func (s *StateStore) ServiceChecksByNodeMeta(ws memdb.WatchSet, serviceName string, + filters map[string]string) (uint64, structs.HealthChecks, error) { + + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, "nodes", "checks") + + // Return the checks. + iter, err := tx.Get("checks", "service", serviceName) + if err != nil { + return 0, nil, fmt.Errorf("failed check lookup: %s", err) + } + ws.Add(iter.WatchCh()) + + return s.parseChecksByNodeMeta(tx, ws, idx, iter, filters) +} + +// ChecksInState is used to query the state store for all checks +// which are in the provided state. +func (s *StateStore) ChecksInState(ws memdb.WatchSet, state string) (uint64, structs.HealthChecks, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, "checks") + + // Query all checks if HealthAny is passed, otherwise use the index. 
+ var iter memdb.ResultIterator + var err error + if state == structs.HealthAny { + iter, err = tx.Get("checks", "status") + if err != nil { + return 0, nil, fmt.Errorf("failed check lookup: %s", err) + } + } else { + iter, err = tx.Get("checks", "status", state) + if err != nil { + return 0, nil, fmt.Errorf("failed check lookup: %s", err) + } + } + ws.Add(iter.WatchCh()) + + var results structs.HealthChecks + for check := iter.Next(); check != nil; check = iter.Next() { + results = append(results, check.(*structs.HealthCheck)) + } + return idx, results, nil +} + +// ChecksInStateByNodeMeta is used to query the state store for all checks +// which are in the provided state, filtered by the given node metadata values. +func (s *StateStore) ChecksInStateByNodeMeta(ws memdb.WatchSet, state string, filters map[string]string) (uint64, structs.HealthChecks, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, "nodes", "checks") + + // Query all checks if HealthAny is passed, otherwise use the index. + var iter memdb.ResultIterator + var err error + if state == structs.HealthAny { + iter, err = tx.Get("checks", "status") + if err != nil { + return 0, nil, fmt.Errorf("failed check lookup: %s", err) + } + } else { + iter, err = tx.Get("checks", "status", state) + if err != nil { + return 0, nil, fmt.Errorf("failed check lookup: %s", err) + } + } + ws.Add(iter.WatchCh()) + + return s.parseChecksByNodeMeta(tx, ws, idx, iter, filters) +} + // parseChecksByNodeMeta is a helper function used to deduplicate some // repetitive code for returning health checks filtered by node metadata fields. -func (s *StateStore) parseChecksByNodeMeta(idx uint64, iter memdb.ResultIterator, tx *memdb.Txn, - filters map[string]string) (uint64, structs.HealthChecks, error) { +func (s *StateStore) parseChecksByNodeMeta(tx *memdb.Txn, ws memdb.WatchSet, + idx uint64, iter memdb.ResultIterator, filters map[string]string) (uint64, structs.HealthChecks, error) { + + // We don't want to track an unlimited number of nodes, so we pull a + // top-level watch to use as a fallback. + allNodes, err := tx.Get("nodes", "id") + if err != nil { + return 0, nil, fmt.Errorf("failed nodes lookup: %s", err) + } + allNodesCh := allNodes.WatchCh() + + // Only take results for nodes that satisfy the node metadata filters. var results structs.HealthChecks for check := iter.Next(); check != nil; check = iter.Next() { healthCheck := check.(*structs.HealthCheck) - node, err := tx.First("nodes", "id", healthCheck.Node) + watchCh, node, err := tx.FirstWatch("nodes", "id", healthCheck.Node) if err != nil { return 0, nil, fmt.Errorf("failed node lookup: %s", err) } if node == nil { return 0, nil, ErrMissingNode } + + // Add even the filtered nodes so we wake up if the node metadata + // changes. + ws.AddWithLimit(watchLimit, watchCh, allNodesCh) if structs.SatisfiesMetaFilters(node.(*structs.Node).Meta, filters) { results = append(results, healthCheck) } @@ -990,19 +1041,17 @@ func (s *StateStore) DeleteCheck(idx uint64, node string, checkID types.CheckID) defer tx.Abort() // Call the check deletion - watches := NewDumbWatchManager(s.tableWatches) - if err := s.deleteCheckTxn(tx, idx, watches, node, checkID); err != nil { + if err := s.deleteCheckTxn(tx, idx, node, checkID); err != nil { return err } - tx.Defer(func() { watches.Notify() }) tx.Commit() return nil } // deleteCheckTxn is the inner method used to call a health // check deletion within an existing transaction. 
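These readers no longer block themselves; they only record what the query touched in the caller's WatchSet. A rough illustration of that contract from the calling side, where the retry loop and the ten second timeout are illustrative rather than the server's actual blocking-query plumbing:

package catalogsketch

import (
	"time"

	"github.com/hashicorp/consul/consul/state"
	"github.com/hashicorp/consul/consul/structs"
	"github.com/hashicorp/go-memdb"
)

func pollChecksInState(s *state.StateStore, minIndex uint64, status string) (uint64, structs.HealthChecks, error) {
	for {
		// A fresh watch set per attempt; the reader fills it with channels
		// for every record and table it consulted.
		ws := memdb.NewWatchSet()
		idx, checks, err := s.ChecksInState(ws, status)
		if err != nil {
			return 0, nil, err
		}
		// Return as soon as the result is newer than what the caller has.
		if idx > minIndex {
			return idx, checks, nil
		}
		// Otherwise sleep until something watched changes or we give up.
		if timedOut := ws.Watch(time.After(10 * time.Second)); timedOut {
			return idx, checks, nil
		}
	}
}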
-func (s *StateStore) deleteCheckTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, node string, checkID types.CheckID) error { +func (s *StateStore) deleteCheckTxn(tx *memdb.Txn, idx uint64, node string, checkID types.CheckID) error { // Try to retrieve the existing health check. hc, err := tx.First("checks", "id", node, string(checkID)) if err != nil { @@ -1032,73 +1081,70 @@ func (s *StateStore) deleteCheckTxn(tx *memdb.Txn, idx uint64, watches *DumbWatc // Do the delete in a separate loop so we don't trash the iterator. for _, id := range ids { - if err := s.deleteSessionTxn(tx, idx, watches, id); err != nil { + if err := s.deleteSessionTxn(tx, idx, id); err != nil { return fmt.Errorf("failed deleting session: %s", err) } } - watches.Arm("checks") return nil } -// CheckServiceNodes is used to query all nodes and checks for a given service -// The results are compounded into a CheckServiceNodes, and the index returned -// is the maximum index observed over any node, check, or service in the result -// set. -func (s *StateStore) CheckServiceNodes(serviceName string) (uint64, structs.CheckServiceNodes, error) { +// CheckServiceNodes is used to query all nodes and checks for a given service. +func (s *StateStore) CheckServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.CheckServiceNodes, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("CheckServiceNodes")...) + idx := maxIndexTxn(tx, "nodes", "services", "checks") // Query the state store for the service. - services, err := tx.Get("services", "service", serviceName) + iter, err := tx.Get("services", "service", serviceName) if err != nil { return 0, nil, fmt.Errorf("failed service lookup: %s", err) } + ws.Add(iter.WatchCh()) // Return the results. var results structs.ServiceNodes - for service := services.Next(); service != nil; service = services.Next() { + for service := iter.Next(); service != nil; service = iter.Next() { results = append(results, service.(*structs.ServiceNode)) } - return s.parseCheckServiceNodes(tx, idx, results, err) + return s.parseCheckServiceNodes(tx, ws, idx, serviceName, results, err) } // CheckServiceTagNodes is used to query all nodes and checks for a given -// service, filtering out services that don't contain the given tag. The results -// are compounded into a CheckServiceNodes, and the index returned is the maximum -// index observed over any node, check, or service in the result set. -func (s *StateStore) CheckServiceTagNodes(serviceName, tag string) (uint64, structs.CheckServiceNodes, error) { +// service, filtering out services that don't contain the given tag. +func (s *StateStore) CheckServiceTagNodes(ws memdb.WatchSet, serviceName, tag string) (uint64, structs.CheckServiceNodes, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("CheckServiceNodes")...) + idx := maxIndexTxn(tx, "nodes", "services", "checks") // Query the state store for the service. - services, err := tx.Get("services", "service", serviceName) + iter, err := tx.Get("services", "service", serviceName) if err != nil { return 0, nil, fmt.Errorf("failed service lookup: %s", err) } + ws.Add(iter.WatchCh()) // Return the results, filtering by tag. 
var results structs.ServiceNodes - for service := services.Next(); service != nil; service = services.Next() { + for service := iter.Next(); service != nil; service = iter.Next() { svc := service.(*structs.ServiceNode) if !serviceTagFilter(svc, tag) { results = append(results, svc) } } - return s.parseCheckServiceNodes(tx, idx, results, err) + return s.parseCheckServiceNodes(tx, ws, idx, serviceName, results, err) } // parseCheckServiceNodes is used to parse through a given set of services, // and query for an associated node and a set of checks. This is the inner // method used to return a rich set of results from a more simple query. func (s *StateStore) parseCheckServiceNodes( - tx *memdb.Txn, idx uint64, services structs.ServiceNodes, + tx *memdb.Txn, ws memdb.WatchSet, idx uint64, + serviceName string, services structs.ServiceNodes, err error) (uint64, structs.CheckServiceNodes, error) { if err != nil { return 0, nil, err @@ -1110,32 +1156,57 @@ func (s *StateStore) parseCheckServiceNodes( return idx, nil, nil } + // We don't want to track an unlimited number of nodes, so we pull a + // top-level watch to use as a fallback. + allNodes, err := tx.Get("nodes", "id") + if err != nil { + return 0, nil, fmt.Errorf("failed nodes lookup: %s", err) + } + allNodesCh := allNodes.WatchCh() + + // We need a similar fallback for checks. Since services need the + // status of node + service-specific checks, we pull in a top-level + // watch over all checks. + allChecks, err := tx.Get("checks", "id") + if err != nil { + return 0, nil, fmt.Errorf("failed checks lookup: %s", err) + } + allChecksCh := allChecks.WatchCh() + results := make(structs.CheckServiceNodes, 0, len(services)) for _, sn := range services { // Retrieve the node. - n, err := tx.First("nodes", "id", sn.Node) + watchCh, n, err := tx.FirstWatch("nodes", "id", sn.Node) if err != nil { return 0, nil, fmt.Errorf("failed node lookup: %s", err) } + ws.AddWithLimit(watchLimit, watchCh, allNodesCh) + if n == nil { return 0, nil, ErrMissingNode } node := n.(*structs.Node) - // We need to return the checks specific to the given service - // as well as the node itself. Unfortunately, memdb won't let - // us use the index to do the latter query so we have to pull - // them all and filter. + // First add the node-level checks. These always apply to any + // service on the node. var checks structs.HealthChecks - iter, err := tx.Get("checks", "node", sn.Node) + iter, err := tx.Get("checks", "node_service_check", sn.Node, false) if err != nil { return 0, nil, err } + ws.AddWithLimit(watchLimit, iter.WatchCh(), allChecksCh) for check := iter.Next(); check != nil; check = iter.Next() { - hc := check.(*structs.HealthCheck) - if hc.ServiceID == "" || hc.ServiceID == sn.ServiceID { - checks = append(checks, hc) - } + checks = append(checks, check.(*structs.HealthCheck)) + } + + // Now add the service-specific checks. + iter, err = tx.Get("checks", "node_service", sn.Node, sn.ServiceID) + if err != nil { + return 0, nil, err + } + ws.AddWithLimit(watchLimit, iter.WatchCh(), allChecksCh) + for check := iter.Next(); check != nil; check = iter.Next() { + checks = append(checks, check.(*structs.HealthCheck)) } // Append to the results. @@ -1151,45 +1222,62 @@ func (s *StateStore) parseCheckServiceNodes( // NodeInfo is used to generate a dump of a single node. The dump includes // all services and checks which are registered against the node. 
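parseCheckServiceNodes now pulls node-level checks and service-specific checks through two narrow indexes instead of scanning every check on the node and filtering by ServiceID in Go. For the (node, false) lookup above to work, the checks table presumably carries a compound index over the node name plus whether ServiceID is set, roughly along these lines (only the two indexes used above are sketched; the real schema is defined elsewhere and has more):

package catalogsketch

import "github.com/hashicorp/go-memdb"

// checksIndexes sketches the two lookups used by parseCheckServiceNodes.
func checksIndexes() map[string]*memdb.IndexSchema {
	return map[string]*memdb.IndexSchema{
		// (node, false) selects checks on a node with no ServiceID set,
		// i.e. the node-level checks that apply to every service.
		"node_service_check": &memdb.IndexSchema{
			Name:         "node_service_check",
			AllowMissing: true,
			Indexer: &memdb.CompoundIndex{
				Indexes: []memdb.Indexer{
					&memdb.StringFieldIndex{Field: "Node", Lowercase: true},
					&memdb.FieldSetIndex{Field: "ServiceID"},
				},
			},
		},
		// (node, serviceID) narrows to the checks tied to one service
		// instance on one node.
		"node_service": &memdb.IndexSchema{
			Name:         "node_service",
			AllowMissing: true,
			Indexer: &memdb.CompoundIndex{
				Indexes: []memdb.Indexer{
					&memdb.StringFieldIndex{Field: "Node", Lowercase: true},
					&memdb.StringFieldIndex{Field: "ServiceID", Lowercase: true},
				},
			},
		},
	}
}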
-func (s *StateStore) NodeInfo(node string) (uint64, structs.NodeDump, error) { +func (s *StateStore) NodeInfo(ws memdb.WatchSet, node string) (uint64, structs.NodeDump, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("NodeInfo")...) + idx := maxIndexTxn(tx, "nodes", "services", "checks") // Query the node by the passed node nodes, err := tx.Get("nodes", "id", node) if err != nil { return 0, nil, fmt.Errorf("failed node lookup: %s", err) } - return s.parseNodes(tx, idx, nodes) + ws.Add(nodes.WatchCh()) + return s.parseNodes(tx, ws, idx, nodes) } // NodeDump is used to generate a dump of all nodes. This call is expensive // as it has to query every node, service, and check. The response can also // be quite large since there is currently no filtering applied. -func (s *StateStore) NodeDump() (uint64, structs.NodeDump, error) { +func (s *StateStore) NodeDump(ws memdb.WatchSet) (uint64, structs.NodeDump, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("NodeDump")...) + idx := maxIndexTxn(tx, "nodes", "services", "checks") // Fetch all of the registered nodes nodes, err := tx.Get("nodes", "id") if err != nil { return 0, nil, fmt.Errorf("failed node lookup: %s", err) } - return s.parseNodes(tx, idx, nodes) + ws.Add(nodes.WatchCh()) + return s.parseNodes(tx, ws, idx, nodes) } // parseNodes takes an iterator over a set of nodes and returns a struct // containing the nodes along with all of their associated services // and/or health checks. -func (s *StateStore) parseNodes(tx *memdb.Txn, idx uint64, +func (s *StateStore) parseNodes(tx *memdb.Txn, ws memdb.WatchSet, idx uint64, iter memdb.ResultIterator) (uint64, structs.NodeDump, error) { + // We don't want to track an unlimited number of services, so we pull a + // top-level watch to use as a fallback. + allServices, err := tx.Get("services", "id") + if err != nil { + return 0, nil, fmt.Errorf("failed services lookup: %s", err) + } + allServicesCh := allServices.WatchCh() + + // We need a similar fallback for checks. 
+ allChecks, err := tx.Get("checks", "id") + if err != nil { + return 0, nil, fmt.Errorf("failed checks lookup: %s", err) + } + allChecksCh := allChecks.WatchCh() + var results structs.NodeDump for n := iter.Next(); n != nil; n = iter.Next() { node := n.(*structs.Node) @@ -1208,6 +1296,7 @@ func (s *StateStore) parseNodes(tx *memdb.Txn, idx uint64, if err != nil { return 0, nil, fmt.Errorf("failed services lookup: %s", err) } + ws.AddWithLimit(watchLimit, services.WatchCh(), allServicesCh) for service := services.Next(); service != nil; service = services.Next() { ns := service.(*structs.ServiceNode).ToNodeService() dump.Services = append(dump.Services, ns) @@ -1218,6 +1307,7 @@ func (s *StateStore) parseNodes(tx *memdb.Txn, idx uint64, if err != nil { return 0, nil, fmt.Errorf("failed node lookup: %s", err) } + ws.AddWithLimit(watchLimit, checks.WatchCh(), allChecksCh) for check := checks.Next(); check != nil; check = checks.Next() { hc := check.(*structs.HealthCheck) dump.Checks = append(dump.Checks, hc) diff --git a/consul/state/catalog_test.go b/consul/state/catalog_test.go index 71d1e45143..eaecfcef71 100644 --- a/consul/state/catalog_test.go +++ b/consul/state/catalog_test.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/types" + "github.com/hashicorp/go-memdb" ) func TestStateStore_EnsureRegistration(t *testing.T) { @@ -31,7 +32,7 @@ func TestStateStore_EnsureRegistration(t *testing.T) { } // Retrieve the node and verify its contents. - verifyNode := func(created, modified uint64) { + verifyNode := func() { _, out, err := s.GetNode("node1") if err != nil { t.Fatalf("err: %s", err) @@ -41,11 +42,11 @@ func TestStateStore_EnsureRegistration(t *testing.T) { len(out.TaggedAddresses) != 1 || out.TaggedAddresses["hello"] != "world" || out.Meta["somekey"] != "somevalue" || - out.CreateIndex != created || out.ModifyIndex != modified { + out.CreateIndex != 1 || out.ModifyIndex != 1 { t.Fatalf("bad node returned: %#v", out) } } - verifyNode(1, 1) + verifyNode() // Add in a service definition. req.Service = &structs.NodeService{ @@ -59,12 +60,12 @@ func TestStateStore_EnsureRegistration(t *testing.T) { } // Verify that the service got registered. - verifyService := func(created, modified uint64) { - idx, out, err := s.NodeServices("node1") + verifyService := func() { + idx, out, err := s.NodeServices(nil, "node1") if err != nil { t.Fatalf("err: %s", err) } - if idx != modified { + if idx != 2 { t.Fatalf("bad index: %d", idx) } if len(out.Services) != 1 { @@ -73,7 +74,7 @@ func TestStateStore_EnsureRegistration(t *testing.T) { r := out.Services["redis1"] if r == nil || r.ID != "redis1" || r.Service != "redis" || r.Address != "1.1.1.1" || r.Port != 8080 || - r.CreateIndex != created || r.ModifyIndex != modified { + r.CreateIndex != 2 || r.ModifyIndex != 2 { t.Fatalf("bad service returned: %#v", r) } @@ -81,17 +82,17 @@ func TestStateStore_EnsureRegistration(t *testing.T) { if err != nil { t.Fatalf("err: %s", err) } - if idx != modified { + if idx != 2 { t.Fatalf("bad index: %d", idx) } if r == nil || r.ID != "redis1" || r.Service != "redis" || r.Address != "1.1.1.1" || r.Port != 8080 || - r.CreateIndex != created || r.ModifyIndex != modified { + r.CreateIndex != 2 || r.ModifyIndex != 2 { t.Fatalf("bad service returned: %#v", r) } } - verifyNode(1, 2) - verifyService(2, 2) + verifyNode() + verifyService() // Add in a top-level check. 
req.Check = &structs.HealthCheck{ @@ -104,12 +105,12 @@ func TestStateStore_EnsureRegistration(t *testing.T) { } // Verify that the check got registered. - verifyCheck := func(created, modified uint64) { - idx, out, err := s.NodeChecks("node1") + verifyCheck := func() { + idx, out, err := s.NodeChecks(nil, "node1") if err != nil { t.Fatalf("err: %s", err) } - if idx != modified { + if idx != 3 { t.Fatalf("bad index: %d", idx) } if len(out) != 1 { @@ -117,7 +118,7 @@ func TestStateStore_EnsureRegistration(t *testing.T) { } c := out[0] if c.Node != "node1" || c.CheckID != "check1" || c.Name != "check" || - c.CreateIndex != created || c.ModifyIndex != modified { + c.CreateIndex != 3 || c.ModifyIndex != 3 { t.Fatalf("bad check returned: %#v", c) } @@ -125,17 +126,17 @@ func TestStateStore_EnsureRegistration(t *testing.T) { if err != nil { t.Fatalf("err: %s", err) } - if idx != modified { + if idx != 3 { t.Fatalf("bad index: %d", idx) } if c.Node != "node1" || c.CheckID != "check1" || c.Name != "check" || - c.CreateIndex != created || c.ModifyIndex != modified { + c.CreateIndex != 3 || c.ModifyIndex != 3 { t.Fatalf("bad check returned: %#v", c) } } - verifyNode(1, 3) - verifyService(2, 3) - verifyCheck(3, 3) + verifyNode() + verifyService() + verifyCheck() // Add in another check via the slice. req.Checks = structs.HealthChecks{ @@ -150,10 +151,10 @@ func TestStateStore_EnsureRegistration(t *testing.T) { } // Verify that the additional check got registered. - verifyNode(1, 4) - verifyService(2, 4) - func() { - idx, out, err := s.NodeChecks("node1") + verifyNode() + verifyService() + { + idx, out, err := s.NodeChecks(nil, "node1") if err != nil { t.Fatalf("err: %s", err) } @@ -174,7 +175,7 @@ func TestStateStore_EnsureRegistration(t *testing.T) { c2.CreateIndex != 4 || c2.ModifyIndex != 4 { t.Fatalf("bad check returned: %#v", c2) } - }() + } } func TestStateStore_EnsureRegistration_Restore(t *testing.T) { @@ -192,17 +193,17 @@ func TestStateStore_EnsureRegistration_Restore(t *testing.T) { restore.Commit() // Retrieve the node and verify its contents. - verifyNode := func(created, modified uint64) { + verifyNode := func() { _, out, err := s.GetNode("node1") if err != nil { t.Fatalf("err: %s", err) } if out.Node != "node1" || out.Address != "1.2.3.4" || - out.CreateIndex != created || out.ModifyIndex != modified { + out.CreateIndex != 1 || out.ModifyIndex != 1 { t.Fatalf("bad node returned: %#v", out) } } - verifyNode(1, 1) + verifyNode() // Add in a service definition. req.Service = &structs.NodeService{ @@ -218,12 +219,12 @@ func TestStateStore_EnsureRegistration_Restore(t *testing.T) { restore.Commit() // Verify that the service got registered. - verifyService := func(created, modified uint64) { - idx, out, err := s.NodeServices("node1") + verifyService := func() { + idx, out, err := s.NodeServices(nil, "node1") if err != nil { t.Fatalf("err: %s", err) } - if idx != modified { + if idx != 2 { t.Fatalf("bad index: %d", idx) } if len(out.Services) != 1 { @@ -232,12 +233,10 @@ func TestStateStore_EnsureRegistration_Restore(t *testing.T) { s := out.Services["redis1"] if s.ID != "redis1" || s.Service != "redis" || s.Address != "1.1.1.1" || s.Port != 8080 || - s.CreateIndex != created || s.ModifyIndex != modified { + s.CreateIndex != 2 || s.ModifyIndex != 2 { t.Fatalf("bad service returned: %#v", s) } } - verifyNode(1, 2) - verifyService(2, 2) // Add in a top-level check. 
req.Check = &structs.HealthCheck{ @@ -252,12 +251,12 @@ func TestStateStore_EnsureRegistration_Restore(t *testing.T) { restore.Commit() // Verify that the check got registered. - verifyCheck := func(created, modified uint64) { - idx, out, err := s.NodeChecks("node1") + verifyCheck := func() { + idx, out, err := s.NodeChecks(nil, "node1") if err != nil { t.Fatalf("err: %s", err) } - if idx != modified { + if idx != 3 { t.Fatalf("bad index: %d", idx) } if len(out) != 1 { @@ -265,13 +264,13 @@ func TestStateStore_EnsureRegistration_Restore(t *testing.T) { } c := out[0] if c.Node != "node1" || c.CheckID != "check1" || c.Name != "check" || - c.CreateIndex != created || c.ModifyIndex != modified { + c.CreateIndex != 3 || c.ModifyIndex != 3 { t.Fatalf("bad check returned: %#v", c) } } - verifyNode(1, 3) - verifyService(2, 3) - verifyCheck(3, 3) + verifyNode() + verifyService() + verifyCheck() // Add in another check via the slice. req.Checks = structs.HealthChecks{ @@ -288,10 +287,10 @@ func TestStateStore_EnsureRegistration_Restore(t *testing.T) { restore.Commit() // Verify that the additional check got registered. - verifyNode(1, 4) - verifyService(2, 4) + verifyNode() + verifyService() func() { - idx, out, err := s.NodeChecks("node1") + idx, out, err := s.NodeChecks(nil, "node1") if err != nil { t.Fatalf("err: %s", err) } @@ -315,94 +314,6 @@ func TestStateStore_EnsureRegistration_Restore(t *testing.T) { }() } -func TestStateStore_EnsureRegistration_Watches(t *testing.T) { - s := testStateStore(t) - - req := &structs.RegisterRequest{ - Node: "node1", - Address: "1.2.3.4", - } - - // The nodes watch should fire for this one. - verifyWatch(t, s.getTableWatch("nodes"), func() { - verifyNoWatch(t, s.getTableWatch("services"), func() { - verifyNoWatch(t, s.getTableWatch("checks"), func() { - if err := s.EnsureRegistration(1, req); err != nil { - t.Fatalf("err: %s", err) - } - }) - }) - }) - // The nodes watch should fire for this one. - verifyWatch(t, s.getTableWatch("nodes"), func() { - verifyNoWatch(t, s.getTableWatch("services"), func() { - verifyNoWatch(t, s.getTableWatch("checks"), func() { - restore := s.Restore() - if err := restore.Registration(1, req); err != nil { - t.Fatalf("err: %s", err) - } - restore.Commit() - }) - }) - }) - - // With a service definition added it should fire nodes and - // services. - req.Service = &structs.NodeService{ - ID: "redis1", - Service: "redis", - Address: "1.1.1.1", - Port: 8080, - } - verifyWatch(t, s.getTableWatch("nodes"), func() { - verifyWatch(t, s.getTableWatch("services"), func() { - verifyNoWatch(t, s.getTableWatch("checks"), func() { - if err := s.EnsureRegistration(2, req); err != nil { - t.Fatalf("err: %s", err) - } - }) - }) - }) - verifyWatch(t, s.getTableWatch("nodes"), func() { - verifyWatch(t, s.getTableWatch("services"), func() { - verifyNoWatch(t, s.getTableWatch("checks"), func() { - restore := s.Restore() - if err := restore.Registration(2, req); err != nil { - t.Fatalf("err: %s", err) - } - restore.Commit() - }) - }) - }) - - // Now with a check it should hit all three. 
- req.Check = &structs.HealthCheck{ - Node: "node1", - CheckID: "check1", - Name: "check", - } - verifyWatch(t, s.getTableWatch("nodes"), func() { - verifyWatch(t, s.getTableWatch("services"), func() { - verifyWatch(t, s.getTableWatch("checks"), func() { - if err := s.EnsureRegistration(3, req); err != nil { - t.Fatalf("err: %s", err) - } - }) - }) - }) - verifyWatch(t, s.getTableWatch("nodes"), func() { - verifyWatch(t, s.getTableWatch("services"), func() { - verifyWatch(t, s.getTableWatch("checks"), func() { - restore := s.Restore() - if err := restore.Registration(3, req); err != nil { - t.Fatalf("err: %s", err) - } - restore.Commit() - }) - }) - }) -} - func TestStateStore_EnsureNode(t *testing.T) { s := testStateStore(t) @@ -480,34 +391,39 @@ func TestStateStore_EnsureNode(t *testing.T) { func TestStateStore_GetNodes(t *testing.T) { s := testStateStore(t) - // Listing with no results returns nil - idx, res, err := s.Nodes() + // Listing with no results returns nil. + ws := memdb.NewWatchSet() + idx, res, err := s.Nodes(ws) if idx != 0 || res != nil || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) } - // Create some nodes in the state store + // Create some nodes in the state store. testRegisterNode(t, s, 0, "node0") testRegisterNode(t, s, 1, "node1") testRegisterNode(t, s, 2, "node2") + if !watchFired(ws) { + t.Fatalf("bad") + } - // Retrieve the nodes - idx, nodes, err := s.Nodes() + // Retrieve the nodes. + ws = memdb.NewWatchSet() + idx, nodes, err := s.Nodes(ws) if err != nil { t.Fatalf("err: %s", err) } - // Highest index was returned + // Highest index was returned. if idx != 2 { t.Fatalf("bad index: %d", idx) } - // All nodes were returned + // All nodes were returned. if n := len(nodes); n != 3 { t.Fatalf("bad node count: %d", n) } - // Make sure the nodes match + // Make sure the nodes match. for i, node := range nodes { if node.CreateIndex != uint64(i) || node.ModifyIndex != uint64(i) { t.Fatalf("bad node index: %d, %d", node.CreateIndex, node.ModifyIndex) @@ -517,6 +433,17 @@ func TestStateStore_GetNodes(t *testing.T) { t.Fatalf("bad: %#v", node) } } + + // Make sure a node delete fires the watch. + if watchFired(ws) { + t.Fatalf("bad") + } + if err := s.DeleteNode(3, "node1"); err != nil { + t.Fatalf("err: %s", err) + } + if !watchFired(ws) { + t.Fatalf("bad") + } } func BenchmarkGetNodes(b *testing.B) { @@ -532,8 +459,9 @@ func BenchmarkGetNodes(b *testing.B) { b.Fatalf("err: %v", err) } + ws := memdb.NewWatchSet() for i := 0; i < b.N; i++ { - s.Nodes() + s.Nodes(ws) } } @@ -541,15 +469,19 @@ func TestStateStore_GetNodesByMeta(t *testing.T) { s := testStateStore(t) // Listing with no results returns nil - idx, res, err := s.NodesByMeta(map[string]string{"somekey": "somevalue"}) + ws := memdb.NewWatchSet() + idx, res, err := s.NodesByMeta(ws, map[string]string{"somekey": "somevalue"}) if idx != 0 || res != nil || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) } - // Create some nodes in the state store + // Create some nodes in the state store. 
testRegisterNodeWithMeta(t, s, 0, "node0", map[string]string{"role": "client"}) testRegisterNodeWithMeta(t, s, 1, "node1", map[string]string{"role": "client", "common": "1"}) testRegisterNodeWithMeta(t, s, 2, "node2", map[string]string{"role": "server", "common": "1"}) + if !watchFired(ws) { + t.Fatalf("bad") + } cases := []struct { filters map[string]string @@ -578,7 +510,7 @@ func TestStateStore_GetNodesByMeta(t *testing.T) { } for _, tc := range cases { - _, result, err := s.NodesByMeta(tc.filters) + _, result, err := s.NodesByMeta(nil, tc.filters) if err != nil { t.Fatalf("bad: %v", err) } @@ -593,23 +525,24 @@ func TestStateStore_GetNodesByMeta(t *testing.T) { } } } -} -func BenchmarkGetNodesByMeta(b *testing.B) { - s, err := NewStateStore(nil) + // Set up a watch. + ws = memdb.NewWatchSet() + _, _, err = s.NodesByMeta(ws, map[string]string{"role": "client"}) if err != nil { - b.Fatalf("err: %s", err) + t.Fatalf("err: %v", err) } - if err := s.EnsureNode(100, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { - b.Fatalf("err: %v", err) - } - if err := s.EnsureNode(101, &structs.Node{Node: "bar", Address: "127.0.0.2"}); err != nil { - b.Fatalf("err: %v", err) + // Make an unrelated modification and make sure the watch doesn't fire. + testRegisterNodeWithMeta(t, s, 3, "node3", map[string]string{"foo": "bar"}) + if watchFired(ws) { + t.Fatalf("bad") } - for i := 0; i < b.N; i++ { - s.Nodes() + // Change a watched key and make sure it fires. + testRegisterNodeWithMeta(t, s, 4, "node0", map[string]string{"role": "different"}) + if !watchFired(ws) { + t.Fatalf("bad") } } @@ -710,68 +643,17 @@ func TestStateStore_Node_Snapshot(t *testing.T) { } } -func TestStateStore_Node_Watches(t *testing.T) { - s := testStateStore(t) - - // Call functions that update the nodes table and make sure a watch fires - // each time. - verifyWatch(t, s.getTableWatch("nodes"), func() { - req := &structs.RegisterRequest{ - Node: "node1", - } - if err := s.EnsureRegistration(1, req); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("nodes"), func() { - node := &structs.Node{Node: "node2"} - if err := s.EnsureNode(2, node); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("nodes"), func() { - if err := s.DeleteNode(3, "node2"); err != nil { - t.Fatalf("err: %s", err) - } - }) - - // Check that a delete of a node + service + check + coordinate triggers - // all tables in one shot. - testRegisterNode(t, s, 4, "node1") - testRegisterService(t, s, 5, "node1", "service1") - testRegisterCheck(t, s, 6, "node1", "service1", "check3", structs.HealthPassing) - updates := structs.Coordinates{ - &structs.Coordinate{ - Node: "node1", - Coord: generateRandomCoordinate(), - }, - } - if err := s.CoordinateBatchUpdate(7, updates); err != nil { - t.Fatalf("err: %s", err) - } - verifyWatch(t, s.getTableWatch("nodes"), func() { - verifyWatch(t, s.getTableWatch("services"), func() { - verifyWatch(t, s.getTableWatch("checks"), func() { - verifyWatch(t, s.getTableWatch("coordinates"), func() { - if err := s.DeleteNode(7, "node1"); err != nil { - t.Fatalf("err: %s", err) - } - }) - }) - }) - }) -} - func TestStateStore_EnsureService(t *testing.T) { s := testStateStore(t) - // Fetching services for a node with none returns nil - idx, res, err := s.NodeServices("node1") + // Fetching services for a node with none returns nil. 
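The rewritten tests hand a fresh memdb.NewWatchSet() to each read and then assert on watchFired after mutating the store. That helper lives in the package's test support code; it presumably amounts to something like the following, with the 50ms grace period being an assumption:

package catalogsketch

import (
	"time"

	"github.com/hashicorp/go-memdb"
)

// watchFired reports whether anything in the watch set has been notified,
// allowing a short grace period instead of blocking indefinitely.
func watchFired(ws memdb.WatchSet) bool {
	// Watch returns true on timeout, so "fired" is the negation.
	timedOut := ws.Watch(time.After(50 * time.Millisecond))
	return !timedOut
}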
+ ws := memdb.NewWatchSet() + idx, res, err := s.NodeServices(ws, "node1") if err != nil || res != nil || idx != 0 { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) } - // Create the service registration + // Create the service registration. ns1 := &structs.NodeService{ ID: "service1", Service: "redis", @@ -780,21 +662,35 @@ func TestStateStore_EnsureService(t *testing.T) { Port: 1111, } - // Creating a service without a node returns an error + // Creating a service without a node returns an error. if err := s.EnsureService(1, "node1", ns1); err != ErrMissingNode { t.Fatalf("expected %#v, got: %#v", ErrMissingNode, err) } + if watchFired(ws) { + t.Fatalf("bad") + } - // Register the nodes + // Register the nodes. testRegisterNode(t, s, 0, "node1") testRegisterNode(t, s, 1, "node2") + if !watchFired(ws) { + t.Fatalf("bad") + } - // Service successfully registers into the state store + // Service successfully registers into the state store. + ws = memdb.NewWatchSet() + _, _, err = s.NodeServices(ws, "node1") + if err != nil { + t.Fatalf("err: %v", err) + } if err = s.EnsureService(10, "node1", ns1); err != nil { t.Fatalf("err: %s", err) } + if !watchFired(ws) { + t.Fatalf("bad") + } - // Register a similar service against both nodes + // Register a similar service against both nodes. ns2 := *ns1 ns2.ID = "service2" for _, n := range []string{"node1", "node2"} { @@ -803,15 +699,24 @@ func TestStateStore_EnsureService(t *testing.T) { } } - // Register a different service on the bad node + // Register a different service on the bad node. + ws = memdb.NewWatchSet() + _, _, err = s.NodeServices(ws, "node1") + if err != nil { + t.Fatalf("err: %v", err) + } ns3 := *ns1 ns3.ID = "service3" if err := s.EnsureService(30, "node2", &ns3); err != nil { t.Fatalf("err: %s", err) } + if watchFired(ws) { + t.Fatalf("bad") + } - // Retrieve the services - idx, out, err := s.NodeServices("node1") + // Retrieve the services. + ws = memdb.NewWatchSet() + idx, out, err := s.NodeServices(ws, "node1") if err != nil { t.Fatalf("err: %s", err) } @@ -819,12 +724,12 @@ func TestStateStore_EnsureService(t *testing.T) { t.Fatalf("bad index: %d", idx) } - // Only the services for the requested node are returned + // Only the services for the requested node are returned. if out == nil || len(out.Services) != 2 { t.Fatalf("bad services: %#v", out) } - // Results match the inserted services and have the proper indexes set + // Results match the inserted services and have the proper indexes set. expect1 := *ns1 expect1.CreateIndex, expect1.ModifyIndex = 10, 10 if svc := out.Services["service1"]; !reflect.DeepEqual(&expect1, svc) { @@ -837,19 +742,22 @@ func TestStateStore_EnsureService(t *testing.T) { t.Fatalf("bad: %#v %#v", ns2, svc) } - // Index tables were updated + // Index tables were updated. if idx := s.maxIndex("services"); idx != 30 { t.Fatalf("bad index: %d", idx) } - // Update a service registration + // Update a service registration. ns1.Address = "1.1.1.2" if err := s.EnsureService(40, "node1", ns1); err != nil { t.Fatalf("err: %s", err) } + if !watchFired(ws) { + t.Fatalf("bad") + } - // Retrieve the service again and ensure it matches - idx, out, err = s.NodeServices("node1") + // Retrieve the service again and ensure it matches.. + idx, out, err = s.NodeServices(nil, "node1") if err != nil { t.Fatalf("err: %s", err) } @@ -865,7 +773,7 @@ func TestStateStore_EnsureService(t *testing.T) { t.Fatalf("bad: %#v", svc) } - // Index tables were updated + // Index tables were updated. 
if idx := s.maxIndex("services"); idx != 40 { t.Fatalf("bad index: %d", idx) } @@ -874,6 +782,19 @@ func TestStateStore_EnsureService(t *testing.T) { func TestStateStore_Services(t *testing.T) { s := testStateStore(t) + // Listing with no results returns an empty list. + ws := memdb.NewWatchSet() + idx, services, err := s.Services(ws) + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 0 { + t.Fatalf("bad: %d", idx) + } + if len(services) != 0 { + t.Fatalf("bad: %v", services) + } + // Register several nodes and services. testRegisterNode(t, s, 1, "node1") ns1 := &structs.NodeService{ @@ -898,9 +819,13 @@ func TestStateStore_Services(t *testing.T) { if err := s.EnsureService(5, "node2", ns2); err != nil { t.Fatalf("err: %s", err) } + if !watchFired(ws) { + t.Fatalf("bad") + } // Pull all the services. - idx, services, err := s.Services() + ws = memdb.NewWatchSet() + idx, services, err = s.Services(ws) if err != nil { t.Fatalf("err: %s", err) } @@ -921,18 +846,27 @@ func TestStateStore_Services(t *testing.T) { if !reflect.DeepEqual(expected, services) { t.Fatalf("bad: %#v", services) } + + // Deleting a node with a service should fire the watch. + if err := s.DeleteNode(6, "node1"); err != nil { + t.Fatalf("err: %s", err) + } + if !watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_ServicesByNodeMeta(t *testing.T) { s := testStateStore(t) - // Listing with no results returns nil - idx, res, err := s.ServicesByNodeMeta(map[string]string{"somekey": "somevalue"}) + // Listing with no results returns nil. + ws := memdb.NewWatchSet() + idx, res, err := s.ServicesByNodeMeta(ws, map[string]string{"somekey": "somevalue"}) if idx != 0 || len(res) != 0 || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) } - // Create some nodes and services in the state store + // Create some nodes and services in the state store. node0 := &structs.Node{Node: "node0", Address: "127.0.0.1", Meta: map[string]string{"role": "client", "common": "1"}} if err := s.EnsureNode(0, node0); err != nil { t.Fatalf("err: %v", err) @@ -961,9 +895,13 @@ func TestStateStore_ServicesByNodeMeta(t *testing.T) { if err := s.EnsureService(3, "node1", ns2); err != nil { t.Fatalf("err: %s", err) } + if !watchFired(ws) { + t.Fatalf("bad") + } - // Filter the services by the first node's meta value - _, res, err = s.ServicesByNodeMeta(map[string]string{"role": "client"}) + // Filter the services by the first node's meta value. 
+ ws = memdb.NewWatchSet() + _, res, err = s.ServicesByNodeMeta(ws, map[string]string{"role": "client"}) if err != nil { t.Fatalf("err: %s", err) } @@ -976,7 +914,7 @@ func TestStateStore_ServicesByNodeMeta(t *testing.T) { } // Get all services using the common meta value - _, res, err = s.ServicesByNodeMeta(map[string]string{"common": "1"}) + _, res, err = s.ServicesByNodeMeta(ws, map[string]string{"common": "1"}) if err != nil { t.Fatalf("err: %s", err) } @@ -989,7 +927,7 @@ func TestStateStore_ServicesByNodeMeta(t *testing.T) { } // Get an empty list for an invalid meta value - _, res, err = s.ServicesByNodeMeta(map[string]string{"invalid": "nope"}) + _, res, err = s.ServicesByNodeMeta(ws, map[string]string{"invalid": "nope"}) if err != nil { t.Fatalf("err: %s", err) } @@ -999,7 +937,7 @@ func TestStateStore_ServicesByNodeMeta(t *testing.T) { } // Get the first node's service instance using multiple meta filters - _, res, err = s.ServicesByNodeMeta(map[string]string{"role": "client", "common": "1"}) + _, res, err = s.ServicesByNodeMeta(ws, map[string]string{"role": "client", "common": "1"}) if err != nil { t.Fatalf("err: %s", err) } @@ -1010,45 +948,94 @@ func TestStateStore_ServicesByNodeMeta(t *testing.T) { if !reflect.DeepEqual(res, expected) { t.Fatalf("bad: %v %v", res, expected) } + + // Sanity check the watch before we proceed. + if watchFired(ws) { + t.Fatalf("bad") + } + + // Registering some unrelated node + service should not fire the watch. + testRegisterNode(t, s, 4, "nope") + testRegisterService(t, s, 5, "nope", "nope") + if watchFired(ws) { + t.Fatalf("bad") + } + + // Overwhelm the service tracking. + idx = 6 + for i := 0; i < 2*watchLimit; i++ { + node := fmt.Sprintf("many%d", i) + testRegisterNodeWithMeta(t, s, idx, node, map[string]string{"common": "1"}) + idx++ + testRegisterService(t, s, idx, node, "nope") + idx++ + } + + // Now get a fresh watch, which will be forced to watch the whole + // service table. + ws = memdb.NewWatchSet() + _, _, err = s.ServicesByNodeMeta(ws, map[string]string{"common": "1"}) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Registering some unrelated node + service should not fire the watch. + testRegisterService(t, s, idx, "nope", "more-nope") + if !watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_ServiceNodes(t *testing.T) { s := testStateStore(t) + // Listing with no results returns an empty list. + ws := memdb.NewWatchSet() + idx, nodes, err := s.ServiceNodes(ws, "db") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 0 { + t.Fatalf("bad: %d", idx) + } + if len(nodes) != 0 { + t.Fatalf("bad: %v", nodes) + } + + // Create some nodes and services. 
if err := s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { t.Fatalf("err: %v", err) } - if err := s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}); err != nil { t.Fatalf("err: %v", err) } - if err := s.EnsureService(12, "foo", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}); err != nil { t.Fatalf("err: %v", err) } - if err := s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}); err != nil { t.Fatalf("err: %v", err) } - if err := s.EnsureService(14, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"master"}, Address: "", Port: 8000}); err != nil { t.Fatalf("err: %v", err) } - if err := s.EnsureService(15, "bar", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8000}); err != nil { t.Fatalf("err: %v", err) } - if err := s.EnsureService(16, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8001}); err != nil { t.Fatalf("err: %v", err) } + if !watchFired(ws) { + t.Fatalf("bad") + } - idx, nodes, err := s.ServiceNodes("db") + // Read everything back. + ws = memdb.NewWatchSet() + idx, nodes, err = s.ServiceNodes(ws, "db") if err != nil { t.Fatalf("err: %s", err) } if idx != 16 { - t.Fatalf("bad: %v", 16) + t.Fatalf("bad: %d", idx) } if len(nodes) != 3 { t.Fatalf("bad: %v", nodes) @@ -1068,7 +1055,6 @@ func TestStateStore_ServiceNodes(t *testing.T) { if nodes[0].ServicePort != 8000 { t.Fatalf("bad: %v", nodes) } - if nodes[1].Node != "bar" { t.Fatalf("bad: %v", nodes) } @@ -1084,7 +1070,6 @@ func TestStateStore_ServiceNodes(t *testing.T) { if nodes[1].ServicePort != 8001 { t.Fatalf("bad: %v", nodes) } - if nodes[2].Node != "foo" { t.Fatalf("bad: %v", nodes) } @@ -1100,32 +1085,88 @@ func TestStateStore_ServiceNodes(t *testing.T) { if nodes[2].ServicePort != 8000 { t.Fatalf("bad: %v", nodes) } + + // Registering some unrelated node should not fire the watch. + testRegisterNode(t, s, 17, "nope") + if watchFired(ws) { + t.Fatalf("bad") + } + + // But removing a node with the "db" service should fire the watch. + if err := s.DeleteNode(18, "bar"); err != nil { + t.Fatalf("err: %s", err) + } + if !watchFired(ws) { + t.Fatalf("bad") + } + + // Overwhelm the node tracking. + idx = 19 + for i := 0; i < 2*watchLimit; i++ { + node := fmt.Sprintf("many%d", i) + if err := s.EnsureNode(idx, &structs.Node{Node: node, Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + if err := s.EnsureService(idx, node, &structs.NodeService{ID: "db", Service: "db", Port: 8000}); err != nil { + t.Fatalf("err: %v", err) + } + idx++ + } + + // Now get a fresh watch, which will be forced to watch the whole nodes + // table. + ws = memdb.NewWatchSet() + _, _, err = s.ServiceNodes(ws, "db") + if err != nil { + t.Fatalf("err: %s", err) + } + + // Registering some unrelated node should fire the watch now. + testRegisterNode(t, s, idx, "more-nope") + if !watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_ServiceTagNodes(t *testing.T) { s := testStateStore(t) + // Listing with no results returns an empty list. + ws := memdb.NewWatchSet() + idx, nodes, err := s.ServiceTagNodes(ws, "db", "master") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 0 { + t.Fatalf("bad: %d", idx) + } + if len(nodes) != 0 { + t.Fatalf("bad: %v", nodes) + } + + // Create some nodes and services. 
if err := s.EnsureNode(15, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { t.Fatalf("err: %v", err) } - if err := s.EnsureNode(16, &structs.Node{Node: "bar", Address: "127.0.0.2"}); err != nil { t.Fatalf("err: %v", err) } - if err := s.EnsureService(17, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"master"}, Address: "", Port: 8000}); err != nil { t.Fatalf("err: %v", err) } - if err := s.EnsureService(18, "foo", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8001}); err != nil { t.Fatalf("err: %v", err) } - if err := s.EnsureService(19, "bar", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8000}); err != nil { t.Fatalf("err: %v", err) } + if !watchFired(ws) { + t.Fatalf("bad") + } - idx, nodes, err := s.ServiceTagNodes("db", "master") + // Read everything back. + ws = memdb.NewWatchSet() + idx, nodes, err = s.ServiceTagNodes(ws, "db", "master") if err != nil { t.Fatalf("err: %s", err) } @@ -1147,6 +1188,20 @@ func TestStateStore_ServiceTagNodes(t *testing.T) { if nodes[0].ServicePort != 8000 { t.Fatalf("bad: %v", nodes) } + + // Registering some unrelated node should not fire the watch. + testRegisterNode(t, s, 20, "nope") + if watchFired(ws) { + t.Fatalf("bad") + } + + // But removing a node with the "db:master" service should fire the watch. + if err := s.DeleteNode(21, "foo"); err != nil { + t.Fatalf("err: %s", err) + } + if !watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_ServiceTagNodes_MultipleTags(t *testing.T) { @@ -1172,7 +1227,7 @@ func TestStateStore_ServiceTagNodes_MultipleTags(t *testing.T) { t.Fatalf("err: %v", err) } - idx, nodes, err := s.ServiceTagNodes("db", "master") + idx, nodes, err := s.ServiceTagNodes(nil, "db", "master") if err != nil { t.Fatalf("err: %s", err) } @@ -1195,7 +1250,7 @@ func TestStateStore_ServiceTagNodes_MultipleTags(t *testing.T) { t.Fatalf("bad: %v", nodes) } - idx, nodes, err = s.ServiceTagNodes("db", "v2") + idx, nodes, err = s.ServiceTagNodes(nil, "db", "v2") if err != nil { t.Fatalf("err: %s", err) } @@ -1206,7 +1261,7 @@ func TestStateStore_ServiceTagNodes_MultipleTags(t *testing.T) { t.Fatalf("bad: %v", nodes) } - idx, nodes, err = s.ServiceTagNodes("db", "dev") + idx, nodes, err = s.ServiceTagNodes(nil, "db", "dev") if err != nil { t.Fatalf("err: %s", err) } @@ -1233,18 +1288,24 @@ func TestStateStore_ServiceTagNodes_MultipleTags(t *testing.T) { func TestStateStore_DeleteService(t *testing.T) { s := testStateStore(t) - // Register a node with one service and a check + // Register a node with one service and a check. testRegisterNode(t, s, 1, "node1") testRegisterService(t, s, 2, "node1", "service1") testRegisterCheck(t, s, 3, "node1", "service1", "check1", structs.HealthPassing) - // Delete the service + // Delete the service. + ws := memdb.NewWatchSet() + _, _, err := s.NodeServices(ws, "node1") if err := s.DeleteService(4, "node1", "service1"); err != nil { t.Fatalf("err: %s", err) } + if !watchFired(ws) { + t.Fatalf("bad") + } // Service doesn't exist. - _, ns, err := s.NodeServices("node1") + ws = memdb.NewWatchSet() + _, ns, err := s.NodeServices(ws, "node1") if err != nil || ns == nil || len(ns.Services) != 0 { t.Fatalf("bad: %#v (err: %#v)", ns, err) } @@ -1258,7 +1319,7 @@ func TestStateStore_DeleteService(t *testing.T) { t.Fatalf("bad: %#v (err: %s)", check, err) } - // Index tables were updated + // Index tables were updated. 
if idx := s.maxIndex("services"); idx != 4 { t.Fatalf("bad index: %d", idx) } @@ -1267,13 +1328,16 @@ func TestStateStore_DeleteService(t *testing.T) { } // Deleting a nonexistent service should be idempotent and not return an - // error + // error, nor fire a watch. if err := s.DeleteService(5, "node1", "service1"); err != nil { t.Fatalf("err: %s", err) } if idx := s.maxIndex("services"); idx != 4 { t.Fatalf("bad index: %d", idx) } + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_Service_Snapshot(t *testing.T) { @@ -1339,43 +1403,6 @@ func TestStateStore_Service_Snapshot(t *testing.T) { } } -func TestStateStore_Service_Watches(t *testing.T) { - s := testStateStore(t) - - testRegisterNode(t, s, 0, "node1") - ns := &structs.NodeService{ - ID: "service2", - Service: "nomad", - Address: "1.1.1.2", - Port: 8000, - } - - // Call functions that update the services table and make sure a watch - // fires each time. - verifyWatch(t, s.getTableWatch("services"), func() { - if err := s.EnsureService(2, "node1", ns); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("services"), func() { - if err := s.DeleteService(3, "node1", "service2"); err != nil { - t.Fatalf("err: %s", err) - } - }) - - // Check that a delete of a service + check triggers both tables in one - // shot. - testRegisterService(t, s, 4, "node1", "service1") - testRegisterCheck(t, s, 5, "node1", "service1", "check3", structs.HealthPassing) - verifyWatch(t, s.getTableWatch("services"), func() { - verifyWatch(t, s.getTableWatch("checks"), func() { - if err := s.DeleteService(6, "node1", "service1"); err != nil { - t.Fatalf("err: %s", err) - } - }) - }) -} - func TestStateStore_EnsureCheck(t *testing.T) { s := testStateStore(t) @@ -1413,7 +1440,7 @@ func TestStateStore_EnsureCheck(t *testing.T) { } // Retrieve the check and make sure it matches - idx, checks, err := s.NodeChecks("node1") + idx, checks, err := s.NodeChecks(nil, "node1") if err != nil { t.Fatalf("err: %s", err) } @@ -1434,7 +1461,7 @@ func TestStateStore_EnsureCheck(t *testing.T) { } // Check that we successfully updated - idx, checks, err = s.NodeChecks("node1") + idx, checks, err = s.NodeChecks(nil, "node1") if err != nil { t.Fatalf("err: %s", err) } @@ -1474,7 +1501,7 @@ func TestStateStore_EnsureCheck_defaultStatus(t *testing.T) { } // Get the check again - _, result, err := s.NodeChecks("node1") + _, result, err := s.NodeChecks(nil, "node1") if err != nil { t.Fatalf("err: %s", err) } @@ -1488,19 +1515,34 @@ func TestStateStore_EnsureCheck_defaultStatus(t *testing.T) { func TestStateStore_NodeChecks(t *testing.T) { s := testStateStore(t) - // Create the first node and service with some checks + // Do an initial query for a node that doesn't exist. + ws := memdb.NewWatchSet() + idx, checks, err := s.NodeChecks(ws, "node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 0 { + t.Fatalf("bad: %d", idx) + } + if len(checks) != 0 { + t.Fatalf("bad: %#v", checks) + } + + // Create some nodes and checks. 
testRegisterNode(t, s, 0, "node1") testRegisterService(t, s, 1, "node1", "service1") testRegisterCheck(t, s, 2, "node1", "service1", "check1", structs.HealthPassing) testRegisterCheck(t, s, 3, "node1", "service1", "check2", structs.HealthPassing) - - // Create a second node/service with a different set of checks testRegisterNode(t, s, 4, "node2") testRegisterService(t, s, 5, "node2", "service2") testRegisterCheck(t, s, 6, "node2", "service2", "check3", structs.HealthPassing) + if !watchFired(ws) { + t.Fatalf("bad") + } // Try querying for all checks associated with node1 - idx, checks, err := s.NodeChecks("node1") + ws = memdb.NewWatchSet() + idx, checks, err = s.NodeChecks(ws, "node1") if err != nil { t.Fatalf("err: %s", err) } @@ -1511,35 +1553,64 @@ func TestStateStore_NodeChecks(t *testing.T) { t.Fatalf("bad checks: %#v", checks) } + // Creating some unrelated node should not fire the watch. + testRegisterNode(t, s, 7, "node3") + testRegisterCheck(t, s, 8, "node3", "", "check1", structs.HealthPassing) + if watchFired(ws) { + t.Fatalf("bad") + } + // Try querying for all checks associated with node2 - idx, checks, err = s.NodeChecks("node2") + ws = memdb.NewWatchSet() + idx, checks, err = s.NodeChecks(ws, "node2") if err != nil { t.Fatalf("err: %s", err) } - if idx != 6 { + if idx != 8 { t.Fatalf("bad index: %d", idx) } if len(checks) != 1 || checks[0].CheckID != "check3" { t.Fatalf("bad checks: %#v", checks) } + + // Changing node2 should fire the watch. + testRegisterCheck(t, s, 9, "node2", "service2", "check3", structs.HealthCritical) + if !watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_ServiceChecks(t *testing.T) { s := testStateStore(t) - // Create the first node and service with some checks + // Do an initial query for a service that doesn't exist. + ws := memdb.NewWatchSet() + idx, checks, err := s.ServiceChecks(ws, "service1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 0 { + t.Fatalf("bad: %d", idx) + } + if len(checks) != 0 { + t.Fatalf("bad: %#v", checks) + } + + // Create some nodes and checks. testRegisterNode(t, s, 0, "node1") testRegisterService(t, s, 1, "node1", "service1") testRegisterCheck(t, s, 2, "node1", "service1", "check1", structs.HealthPassing) testRegisterCheck(t, s, 3, "node1", "service1", "check2", structs.HealthPassing) - - // Create a second node/service with a different set of checks testRegisterNode(t, s, 4, "node2") testRegisterService(t, s, 5, "node2", "service2") testRegisterCheck(t, s, 6, "node2", "service2", "check3", structs.HealthPassing) + if !watchFired(ws) { + t.Fatalf("bad") + } - // Try querying for all checks associated with service1 - idx, checks, err := s.ServiceChecks("service1") + // Try querying for all checks associated with service1. + ws = memdb.NewWatchSet() + idx, checks, err = s.ServiceChecks(ws, "service1") if err != nil { t.Fatalf("err: %s", err) } @@ -1549,21 +1620,48 @@ func TestStateStore_ServiceChecks(t *testing.T) { if len(checks) != 2 || checks[0].CheckID != "check1" || checks[1].CheckID != "check2" { t.Fatalf("bad checks: %#v", checks) } + + // Adding some unrelated service + check should not fire the watch. + testRegisterService(t, s, 7, "node1", "service3") + testRegisterCheck(t, s, 8, "node1", "service3", "check3", structs.HealthPassing) + if watchFired(ws) { + t.Fatalf("bad") + } + + // Updating a related check should fire the watch. 
+ testRegisterCheck(t, s, 9, "node1", "service1", "check2", structs.HealthCritical) + if !watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_ServiceChecksByNodeMeta(t *testing.T) { s := testStateStore(t) - // Create the first node and service with some checks + // Querying with no results returns nil. + ws := memdb.NewWatchSet() + idx, checks, err := s.ServiceChecksByNodeMeta(ws, "service1", nil) + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 0 { + t.Fatalf("bad: %d", idx) + } + if len(checks) != 0 { + t.Fatalf("bad: %#v", checks) + } + + // Create some nodes and checks. testRegisterNodeWithMeta(t, s, 0, "node1", map[string]string{"somekey": "somevalue", "common": "1"}) testRegisterService(t, s, 1, "node1", "service1") testRegisterCheck(t, s, 2, "node1", "service1", "check1", structs.HealthPassing) testRegisterCheck(t, s, 3, "node1", "service1", "check2", structs.HealthPassing) - - // Create a second node/service with a different set of checks testRegisterNodeWithMeta(t, s, 4, "node2", map[string]string{"common": "1"}) testRegisterService(t, s, 5, "node2", "service1") testRegisterCheck(t, s, 6, "node2", "service1", "check3", structs.HealthPassing) + if !watchFired(ws) { + t.Fatalf("bad") + } cases := []struct { filters map[string]string @@ -1591,9 +1689,11 @@ func TestStateStore_ServiceChecksByNodeMeta(t *testing.T) { }, } - // Try querying for all checks associated with service1 + // Try querying for all checks associated with service1. + idx = 7 for _, tc := range cases { - _, checks, err := s.ServiceChecksByNodeMeta("service1", tc.filters) + ws = memdb.NewWatchSet() + _, checks, err := s.ServiceChecksByNodeMeta(ws, "service1", tc.filters) if err != nil { t.Fatalf("err: %s", err) } @@ -1605,6 +1705,39 @@ func TestStateStore_ServiceChecksByNodeMeta(t *testing.T) { t.Fatalf("bad checks: %#v", checks) } } + + // Registering some unrelated node should not fire the watch. + testRegisterNode(t, s, idx, fmt.Sprintf("nope%d", idx)) + idx++ + if watchFired(ws) { + t.Fatalf("bad") + } + } + + // Overwhelm the node tracking. + for i := 0; i < 2*watchLimit; i++ { + node := fmt.Sprintf("many%d", idx) + testRegisterNodeWithMeta(t, s, idx, node, map[string]string{"common": "1"}) + idx++ + testRegisterService(t, s, idx, node, "service1") + idx++ + testRegisterCheck(t, s, idx, node, "service1", "check1", structs.HealthPassing) + idx++ + } + + // Now get a fresh watch, which will be forced to watch the whole + // node table. + ws = memdb.NewWatchSet() + _, _, err = s.ServiceChecksByNodeMeta(ws, "service1", + map[string]string{"common": "1"}) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Registering some unrelated node should now fire the watch. 
+ testRegisterNode(t, s, idx, "nope") + if !watchFired(ws) { + t.Fatalf("bad") } } @@ -1612,7 +1745,8 @@ func TestStateStore_ChecksInState(t *testing.T) { s := testStateStore(t) // Querying with no results returns nil - idx, res, err := s.ChecksInState(structs.HealthPassing) + ws := memdb.NewWatchSet() + idx, res, err := s.ChecksInState(ws, structs.HealthPassing) if idx != 0 || res != nil || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) } @@ -1622,9 +1756,13 @@ func TestStateStore_ChecksInState(t *testing.T) { testRegisterCheck(t, s, 1, "node1", "", "check1", structs.HealthPassing) testRegisterCheck(t, s, 2, "node1", "", "check2", structs.HealthCritical) testRegisterCheck(t, s, 3, "node1", "", "check3", structs.HealthPassing) + if !watchFired(ws) { + t.Fatalf("bad") + } // Query the state store for passing checks. - _, checks, err := s.ChecksInState(structs.HealthPassing) + ws = memdb.NewWatchSet() + _, checks, err := s.ChecksInState(ws, structs.HealthPassing) if err != nil { t.Fatalf("err: %s", err) } @@ -1636,33 +1774,55 @@ func TestStateStore_ChecksInState(t *testing.T) { if checks[0].CheckID != "check1" || checks[1].CheckID != "check3" { t.Fatalf("bad: %#v", checks) } + if watchFired(ws) { + t.Fatalf("bad") + } + + // Changing the state of a check should fire the watch. + testRegisterCheck(t, s, 4, "node1", "", "check1", structs.HealthCritical) + if !watchFired(ws) { + t.Fatalf("bad") + } // HealthAny just returns everything. - _, checks, err = s.ChecksInState(structs.HealthAny) + ws = memdb.NewWatchSet() + _, checks, err = s.ChecksInState(ws, structs.HealthAny) if err != nil { t.Fatalf("err: %s", err) } if n := len(checks); n != 3 { t.Fatalf("expected 3 checks, got: %d", n) } + if watchFired(ws) { + t.Fatalf("bad") + } + + // Adding a new check should fire the watch. + testRegisterCheck(t, s, 5, "node1", "", "check4", structs.HealthCritical) + if !watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_ChecksInStateByNodeMeta(t *testing.T) { s := testStateStore(t) - // Querying with no results returns nil - idx, res, err := s.ChecksInStateByNodeMeta(structs.HealthPassing, nil) + // Querying with no results returns nil. + ws := memdb.NewWatchSet() + idx, res, err := s.ChecksInStateByNodeMeta(ws, structs.HealthPassing, nil) if idx != 0 || res != nil || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) } - // Register a node with checks in varied states + // Register a node with checks in varied states. testRegisterNodeWithMeta(t, s, 0, "node1", map[string]string{"somekey": "somevalue", "common": "1"}) testRegisterCheck(t, s, 1, "node1", "", "check1", structs.HealthPassing) testRegisterCheck(t, s, 2, "node1", "", "check2", structs.HealthCritical) - testRegisterNodeWithMeta(t, s, 3, "node2", map[string]string{"common": "1"}) testRegisterCheck(t, s, 4, "node2", "", "check3", structs.HealthPassing) + if !watchFired(ws) { + t.Fatalf("bad") + } cases := []struct { filters map[string]string @@ -1712,9 +1872,11 @@ func TestStateStore_ChecksInStateByNodeMeta(t *testing.T) { }, } - // Try querying for all checks associated with service1 + // Try querying for all checks associated with service1. 
+ idx = 5 for _, tc := range cases { - _, checks, err := s.ChecksInStateByNodeMeta(tc.state, tc.filters) + ws = memdb.NewWatchSet() + _, checks, err := s.ChecksInStateByNodeMeta(ws, tc.state, tc.filters) if err != nil { t.Fatalf("err: %s", err) } @@ -1726,23 +1888,70 @@ func TestStateStore_ChecksInStateByNodeMeta(t *testing.T) { t.Fatalf("bad checks: %#v, %v", checks, tc.checks) } } + + // Registering some unrelated node should not fire the watch. + testRegisterNode(t, s, idx, fmt.Sprintf("nope%d", idx)) + idx++ + if watchFired(ws) { + t.Fatalf("bad") + } + } + + // Overwhelm the node tracking. + for i := 0; i < 2*watchLimit; i++ { + node := fmt.Sprintf("many%d", idx) + testRegisterNodeWithMeta(t, s, idx, node, map[string]string{"common": "1"}) + idx++ + testRegisterService(t, s, idx, node, "service1") + idx++ + testRegisterCheck(t, s, idx, node, "service1", "check1", structs.HealthPassing) + idx++ + } + + // Now get a fresh watch, which will be forced to watch the whole + // node table. + ws = memdb.NewWatchSet() + _, _, err = s.ChecksInStateByNodeMeta(ws, structs.HealthPassing, + map[string]string{"common": "1"}) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Registering some unrelated node should now fire the watch. + testRegisterNode(t, s, idx, "nope") + if !watchFired(ws) { + t.Fatalf("bad") } } func TestStateStore_DeleteCheck(t *testing.T) { s := testStateStore(t) - // Register a node and a node-level health check + // Register a node and a node-level health check. testRegisterNode(t, s, 1, "node1") testRegisterCheck(t, s, 2, "node1", "", "check1", structs.HealthPassing) - // Delete the check + // Make sure the check is there. + ws := memdb.NewWatchSet() + _, checks, err := s.NodeChecks(ws, "node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if len(checks) != 1 { + t.Fatalf("bad: %#v", checks) + } + + // Delete the check. if err := s.DeleteCheck(3, "node1", "check1"); err != nil { t.Fatalf("err: %s", err) } + if !watchFired(ws) { + t.Fatalf("bad") + } // Check is gone - _, checks, err := s.NodeChecks("node1") + ws = memdb.NewWatchSet() + _, checks, err = s.NodeChecks(ws, "node1") if err != nil { t.Fatalf("err: %s", err) } @@ -1750,50 +1959,59 @@ func TestStateStore_DeleteCheck(t *testing.T) { t.Fatalf("bad: %#v", checks) } - // Index tables were updated + // Index tables were updated. if idx := s.maxIndex("checks"); idx != 3 { t.Fatalf("bad index: %d", idx) } // Deleting a nonexistent check should be idempotent and not return an - // error + // error. if err := s.DeleteCheck(4, "node1", "check1"); err != nil { t.Fatalf("err: %s", err) } if idx := s.maxIndex("checks"); idx != 3 { t.Fatalf("bad index: %d", idx) } + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_CheckServiceNodes(t *testing.T) { s := testStateStore(t) // Querying with no matches gives an empty response - idx, res, err := s.CheckServiceNodes("service1") + ws := memdb.NewWatchSet() + idx, res, err := s.CheckServiceNodes(ws, "service1") if idx != 0 || res != nil || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) } - // Register some nodes + // Register some nodes. testRegisterNode(t, s, 0, "node1") testRegisterNode(t, s, 1, "node2") - // Register node-level checks. These should not be returned - // in the final result. + // Register node-level checks. These should be the final result. 
testRegisterCheck(t, s, 2, "node1", "", "check1", structs.HealthPassing) testRegisterCheck(t, s, 3, "node2", "", "check2", structs.HealthPassing) - // Register a service against the nodes + // Register a service against the nodes. testRegisterService(t, s, 4, "node1", "service1") testRegisterService(t, s, 5, "node2", "service2") - // Register checks against the services + // Register checks against the services. testRegisterCheck(t, s, 6, "node1", "service1", "check3", structs.HealthPassing) testRegisterCheck(t, s, 7, "node2", "service2", "check4", structs.HealthPassing) - // Query the state store for nodes and checks which - // have been registered with a specific service. - idx, results, err := s.CheckServiceNodes("service1") + // At this point all the changes should have fired the watch. + if !watchFired(ws) { + t.Fatalf("bad") + } + + // Query the state store for nodes and checks which have been registered + // with a specific service. + ws = memdb.NewWatchSet() + idx, results, err := s.CheckServiceNodes(ws, "service1") if err != nil { t.Fatalf("err: %s", err) } @@ -1801,18 +2019,24 @@ func TestStateStore_CheckServiceNodes(t *testing.T) { t.Fatalf("bad index: %d", idx) } - // Make sure we get the expected result (service check + node check) + // Make sure we get the expected result (service check + node check). if n := len(results); n != 1 { t.Fatalf("expected 1 result, got: %d", n) } csn := results[0] - if csn.Node == nil || csn.Service == nil || len(csn.Checks) != 2 { + if csn.Node == nil || csn.Service == nil || len(csn.Checks) != 2 || + csn.Checks[0].ServiceID != "" || csn.Checks[0].CheckID != "check1" || + csn.Checks[1].ServiceID != "service1" || csn.Checks[1].CheckID != "check3" { t.Fatalf("bad output: %#v", csn) } - // Node updates alter the returned index + // Node updates alter the returned index and fire the watch. testRegisterNode(t, s, 8, "node1") - idx, results, err = s.CheckServiceNodes("service1") + if !watchFired(ws) { + t.Fatalf("bad") + } + ws = memdb.NewWatchSet() + idx, results, err = s.CheckServiceNodes(ws, "service1") if err != nil { t.Fatalf("err: %s", err) } @@ -1820,9 +2044,13 @@ func TestStateStore_CheckServiceNodes(t *testing.T) { t.Fatalf("bad index: %d", idx) } - // Service updates alter the returned index + // Service updates alter the returned index and fire the watch. testRegisterService(t, s, 9, "node1", "service1") - idx, results, err = s.CheckServiceNodes("service1") + if !watchFired(ws) { + t.Fatalf("bad") + } + ws = memdb.NewWatchSet() + idx, results, err = s.CheckServiceNodes(ws, "service1") if err != nil { t.Fatalf("err: %s", err) } @@ -1830,15 +2058,64 @@ func TestStateStore_CheckServiceNodes(t *testing.T) { t.Fatalf("bad index: %d", idx) } - // Check updates alter the returned index + // Check updates alter the returned index and fire the watch. testRegisterCheck(t, s, 10, "node1", "service1", "check1", structs.HealthCritical) - idx, results, err = s.CheckServiceNodes("service1") + if !watchFired(ws) { + t.Fatalf("bad") + } + ws = memdb.NewWatchSet() + idx, results, err = s.CheckServiceNodes(ws, "service1") if err != nil { t.Fatalf("err: %s", err) } if idx != 10 { t.Fatalf("bad index: %d", idx) } + + // Registering some unrelated node + service should not fire the watch. + testRegisterNode(t, s, 11, "nope") + testRegisterService(t, s, 12, "nope", "nope") + if watchFired(ws) { + t.Fatalf("bad") + } + + // Overwhelm node and check tracking. 
+ idx = 13 + for i := 0; i < 2*watchLimit; i++ { + node := fmt.Sprintf("many%d", i) + testRegisterNode(t, s, idx, node) + idx++ + testRegisterCheck(t, s, idx, node, "", "check1", structs.HealthPassing) + idx++ + testRegisterService(t, s, idx, node, "service1") + idx++ + testRegisterCheck(t, s, idx, node, "service1", "check2", structs.HealthPassing) + idx++ + } + + // Now registering an unrelated node will fire the watch. + ws = memdb.NewWatchSet() + idx, results, err = s.CheckServiceNodes(ws, "service1") + if err != nil { + t.Fatalf("err: %s", err) + } + testRegisterNode(t, s, idx, "more-nope") + idx++ + if !watchFired(ws) { + t.Fatalf("bad") + } + + // Also, registering an unrelated check will fire the watch. + ws = memdb.NewWatchSet() + idx, results, err = s.CheckServiceNodes(ws, "service1") + if err != nil { + t.Fatalf("err: %s", err) + } + testRegisterCheck(t, s, idx, "more-nope", "", "check1", structs.HealthPassing) + idx++ + if !watchFired(ws) { + t.Fatalf("bad") + } } func BenchmarkCheckServiceNodes(b *testing.B) { @@ -1873,8 +2150,9 @@ func BenchmarkCheckServiceNodes(b *testing.B) { b.Fatalf("err: %v", err) } + ws := memdb.NewWatchSet() for i := 0; i < b.N; i++ { - s.CheckServiceNodes("db") + s.CheckServiceNodes(ws, "db") } } @@ -1907,7 +2185,8 @@ func TestStateStore_CheckServiceTagNodes(t *testing.T) { t.Fatalf("err: %v", err) } - idx, nodes, err := s.CheckServiceTagNodes("db", "master") + ws := memdb.NewWatchSet() + idx, nodes, err := s.CheckServiceTagNodes(ws, "db", "master") if err != nil { t.Fatalf("err: %s", err) } @@ -1932,6 +2211,14 @@ func TestStateStore_CheckServiceTagNodes(t *testing.T) { if nodes[0].Checks[1].CheckID != "db" { t.Fatalf("Bad: %v", nodes[0]) } + + // Changing a tag should fire the watch. + if err := s.EnsureService(4, "foo", &structs.NodeService{ID: "db1", Service: "db", Tags: []string{"nope"}, Address: "", Port: 8000}); err != nil { + t.Fatalf("err: %v", err) + } + if !watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_Check_Snapshot(t *testing.T) { @@ -1998,45 +2285,17 @@ func TestStateStore_Check_Snapshot(t *testing.T) { } } -func TestStateStore_Check_Watches(t *testing.T) { - s := testStateStore(t) - - testRegisterNode(t, s, 0, "node1") - hc := &structs.HealthCheck{ - Node: "node1", - CheckID: "check1", - Status: structs.HealthPassing, - } - - // Call functions that update the checks table and make sure a watch fires - // each time. 
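The verifyWatch/getTableWatch tests removed throughout this file are superseded by the memdb.WatchSet pattern the updated tests use: register a watch set while querying, mutate the store, then check whether anything registered on the set fired. The watchFired helper those tests call lives in this package's test support rather than in this diff; a minimal sketch of such a helper, assuming go-memdb's WatchSet.Watch semantics and an arbitrary 50ms settle window, looks like this:

```go
package state

import (
	"time"

	"github.com/hashicorp/go-memdb"
)

// watchFired is a sketch of the helper the updated tests call: it reports
// whether any channel registered on the WatchSet closed before a short
// settle timeout. The 50ms value is an assumption for illustration.
func watchFired(ws memdb.WatchSet) bool {
	// Watch returns true when the timeout channel fires first, meaning
	// nothing we were watching changed during the settle window.
	timedOut := ws.Watch(time.After(50 * time.Millisecond))
	return !timedOut
}
```

The 2*watchLimit loops above exercise the related fallback: once a query would need to register more per-row channels than the limit allows, it watches the whole table instead, which is why unrelated registrations start firing the watch after the overwhelm loops.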
- verifyWatch(t, s.getTableWatch("checks"), func() { - if err := s.EnsureCheck(1, hc); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("checks"), func() { - hc.Status = structs.HealthCritical - if err := s.EnsureCheck(2, hc); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("checks"), func() { - if err := s.DeleteCheck(3, "node1", "check1"); err != nil { - t.Fatalf("err: %s", err) - } - }) -} - func TestStateStore_NodeInfo_NodeDump(t *testing.T) { s := testStateStore(t) // Generating a node dump that matches nothing returns empty - idx, dump, err := s.NodeInfo("node1") + wsInfo := memdb.NewWatchSet() + idx, dump, err := s.NodeInfo(wsInfo, "node1") if idx != 0 || dump != nil || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, dump, err) } - idx, dump, err = s.NodeDump() + wsDump := memdb.NewWatchSet() + idx, dump, err = s.NodeDump(wsDump) if idx != 0 || dump != nil || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, dump, err) } @@ -2059,6 +2318,14 @@ func TestStateStore_NodeInfo_NodeDump(t *testing.T) { testRegisterCheck(t, s, 8, "node1", "", "check2", structs.HealthPassing) testRegisterCheck(t, s, 9, "node2", "", "check2", structs.HealthPassing) + // Both watches should have fired due to the changes above. + if !watchFired(wsInfo) { + t.Fatalf("bad") + } + if !watchFired(wsDump) { + t.Fatalf("bad") + } + // Check that our result matches what we expect. expect := structs.NodeDump{ &structs.NodeInfo{ @@ -2162,7 +2429,8 @@ func TestStateStore_NodeInfo_NodeDump(t *testing.T) { } // Get a dump of just a single node - idx, dump, err = s.NodeInfo("node1") + ws := memdb.NewWatchSet() + idx, dump, err = s.NodeInfo(ws, "node1") if err != nil { t.Fatalf("err: %s", err) } @@ -2174,7 +2442,7 @@ func TestStateStore_NodeInfo_NodeDump(t *testing.T) { } // Generate a dump of all the nodes - idx, dump, err = s.NodeDump() + idx, dump, err = s.NodeDump(nil) if err != nil { t.Fatalf("err: %s", err) } @@ -2184,4 +2452,12 @@ func TestStateStore_NodeInfo_NodeDump(t *testing.T) { if !reflect.DeepEqual(dump, expect) { t.Fatalf("bad: %#v", dump[0].Services[0]) } + + // Registering some unrelated node + service + check should not fire the + // watch. + testRegisterNode(t, s, 10, "nope") + testRegisterService(t, s, 11, "nope", "nope") + if watchFired(ws) { + t.Fatalf("bad") + } } diff --git a/consul/state/coordinate.go b/consul/state/coordinate.go index 376d02b6c5..6cfba415ea 100644 --- a/consul/state/coordinate.go +++ b/consul/state/coordinate.go @@ -31,7 +31,6 @@ func (s *StateRestore) Coordinates(idx uint64, updates structs.Coordinates) erro return fmt.Errorf("failed updating index: %s", err) } - s.watches.Arm("coordinates") return nil } @@ -58,20 +57,22 @@ func (s *StateStore) CoordinateGetRaw(node string) (*coordinate.Coordinate, erro } // Coordinates queries for all nodes with coordinates. -func (s *StateStore) Coordinates() (uint64, structs.Coordinates, error) { +func (s *StateStore) Coordinates(ws memdb.WatchSet) (uint64, structs.Coordinates, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("Coordinates")...) + idx := maxIndexTxn(tx, "coordinates") // Pull all the coordinates. 
- coords, err := tx.Get("coordinates", "id") + iter, err := tx.Get("coordinates", "id") if err != nil { return 0, nil, fmt.Errorf("failed coordinate lookup: %s", err) } + ws.Add(iter.WatchCh()) + var results structs.Coordinates - for coord := coords.Next(); coord != nil; coord = coords.Next() { + for coord := iter.Next(); coord != nil; coord = iter.Next() { results = append(results, coord.(*structs.Coordinate)) } return idx, results, nil @@ -111,7 +112,6 @@ func (s *StateStore) CoordinateBatchUpdate(idx uint64, updates structs.Coordinat return fmt.Errorf("failed updating index: %s", err) } - tx.Defer(func() { s.tableWatches["coordinates"].Notify() }) tx.Commit() return nil } diff --git a/consul/state/coordinate_test.go b/consul/state/coordinate_test.go index 1998333845..c8af3a9f41 100644 --- a/consul/state/coordinate_test.go +++ b/consul/state/coordinate_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/go-memdb" "github.com/hashicorp/serf/coordinate" ) @@ -29,7 +30,8 @@ func TestStateStore_Coordinate_Updates(t *testing.T) { // Make sure the coordinates list starts out empty, and that a query for // a raw coordinate for a nonexistent node doesn't do anything bad. - idx, coords, err := s.Coordinates() + ws := memdb.NewWatchSet() + idx, coords, err := s.Coordinates(ws) if err != nil { t.Fatalf("err: %s", err) } @@ -62,10 +64,14 @@ func TestStateStore_Coordinate_Updates(t *testing.T) { if err := s.CoordinateBatchUpdate(1, updates); err != nil { t.Fatalf("err: %s", err) } + if watchFired(ws) { + t.Fatalf("bad") + } // Should still be empty, though applying an empty batch does bump // the table index. - idx, coords, err = s.Coordinates() + ws = memdb.NewWatchSet() + idx, coords, err = s.Coordinates(ws) if err != nil { t.Fatalf("err: %s", err) } @@ -82,9 +88,13 @@ func TestStateStore_Coordinate_Updates(t *testing.T) { if err := s.CoordinateBatchUpdate(3, updates); err != nil { t.Fatalf("err: %s", err) } + if !watchFired(ws) { + t.Fatalf("bad") + } // Should go through now. - idx, coords, err = s.Coordinates() + ws = memdb.NewWatchSet() + idx, coords, err = s.Coordinates(ws) if err != nil { t.Fatalf("err: %s", err) } @@ -111,9 +121,12 @@ func TestStateStore_Coordinate_Updates(t *testing.T) { if err := s.CoordinateBatchUpdate(4, updates); err != nil { t.Fatalf("err: %s", err) } + if !watchFired(ws) { + t.Fatalf("bad") + } // Verify it got applied. - idx, coords, err = s.Coordinates() + idx, coords, err = s.Coordinates(nil) if err != nil { t.Fatalf("err: %s", err) } @@ -175,7 +188,7 @@ func TestStateStore_Coordinate_Cleanup(t *testing.T) { } // Make sure the index got updated. - idx, coords, err := s.Coordinates() + idx, coords, err := s.Coordinates(nil) if err != nil { t.Fatalf("err: %s", err) } @@ -252,7 +265,7 @@ func TestStateStore_Coordinate_Snapshot_Restore(t *testing.T) { restore.Commit() // Read the restored coordinates back out and verify that they match. - idx, res, err := s.Coordinates() + idx, res, err := s.Coordinates(nil) if err != nil { t.Fatalf("err: %s", err) } @@ -271,28 +284,3 @@ func TestStateStore_Coordinate_Snapshot_Restore(t *testing.T) { }() } - -func TestStateStore_Coordinate_Watches(t *testing.T) { - s := testStateStore(t) - - testRegisterNode(t, s, 1, "node1") - - // Call functions that update the coordinates table and make sure a watch fires - // each time. 
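Coordinates now takes the caller's WatchSet and registers the table iterator's channel on it, so a blocking read wakes on any coordinate change. A caller-side sketch of that flow follows; the function name and one-second timeout are illustrative, and it assumes "node1" is already registered, since batch updates for unknown nodes are dropped without firing the watch (as the updated test checks):

```go
package state_test

import (
	"fmt"
	"time"

	"github.com/hashicorp/consul/consul/state"
	"github.com/hashicorp/consul/consul/structs"
	"github.com/hashicorp/go-memdb"
	"github.com/hashicorp/serf/coordinate"
)

// exampleCoordinateWatch registers a WatchSet while reading coordinates,
// applies a batch update, and then blocks on the set. Illustrative only.
func exampleCoordinateWatch(s *state.StateStore) error {
	ws := memdb.NewWatchSet()
	if _, _, err := s.Coordinates(ws); err != nil {
		return err
	}

	// Assumes "node1" was registered earlier; updates for unknown nodes
	// are skipped and will not wake the watch.
	updates := structs.Coordinates{
		&structs.Coordinate{
			Node:  "node1",
			Coord: coordinate.NewCoordinate(coordinate.DefaultConfig()),
		},
	}
	if err := s.CoordinateBatchUpdate(2, updates); err != nil {
		return err
	}

	// Watch returns true if the timeout fired before any registered channel.
	if timedOut := ws.Watch(time.After(time.Second)); timedOut {
		return fmt.Errorf("expected the coordinates watch to fire")
	}
	return nil
}
```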
- verifyWatch(t, s.getTableWatch("coordinates"), func() { - updates := structs.Coordinates{ - &structs.Coordinate{ - Node: "node1", - Coord: generateRandomCoordinate(), - }, - } - if err := s.CoordinateBatchUpdate(2, updates); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("coordinates"), func() { - if err := s.DeleteNode(3, "node1"); err != nil { - t.Fatalf("err: %s", err) - } - }) -} diff --git a/consul/state/kvs.go b/consul/state/kvs.go index 3dccdebd31..c111380a9a 100644 --- a/consul/state/kvs.go +++ b/consul/state/kvs.go @@ -32,9 +32,6 @@ func (s *StateRestore) KVS(entry *structs.DirEntry) error { if err := indexUpdateMaxTxn(s.tx, entry.ModifyIndex, "kvs"); err != nil { return fmt.Errorf("failed updating index: %s", err) } - - // We have a single top-level KVS watch trigger instead of doing - // tons of prefix watches. return nil } @@ -114,29 +111,29 @@ func (s *StateStore) kvsSetTxn(tx *memdb.Txn, idx uint64, entry *structs.DirEntr return fmt.Errorf("failed updating index: %s", err) } - tx.Defer(func() { s.kvsWatch.Notify(entry.Key, false) }) return nil } // KVSGet is used to retrieve a key/value pair from the state store. -func (s *StateStore) KVSGet(key string) (uint64, *structs.DirEntry, error) { +func (s *StateStore) KVSGet(ws memdb.WatchSet, key string) (uint64, *structs.DirEntry, error) { tx := s.db.Txn(false) defer tx.Abort() - return s.kvsGetTxn(tx, key) + return s.kvsGetTxn(tx, ws, key) } // kvsGetTxn is the inner method that gets a KVS entry inside an existing // transaction. -func (s *StateStore) kvsGetTxn(tx *memdb.Txn, key string) (uint64, *structs.DirEntry, error) { +func (s *StateStore) kvsGetTxn(tx *memdb.Txn, ws memdb.WatchSet, key string) (uint64, *structs.DirEntry, error) { // Get the table index. idx := maxIndexTxn(tx, "kvs", "tombstones") // Retrieve the key. - entry, err := tx.First("kvs", "id", key) + watchCh, entry, err := tx.FirstWatch("kvs", "id", key) if err != nil { return 0, nil, fmt.Errorf("failed kvs lookup: %s", err) } + ws.Add(watchCh) if entry != nil { return idx, entry.(*structs.DirEntry), nil } @@ -147,16 +144,16 @@ func (s *StateStore) kvsGetTxn(tx *memdb.Txn, key string) (uint64, *structs.DirE // prefix is left empty, all keys in the KVS will be returned. The returned // is the max index of the returned kvs entries or applicable tombstones, or // else it's the full table indexes for kvs and tombstones. -func (s *StateStore) KVSList(prefix string) (uint64, structs.DirEntries, error) { +func (s *StateStore) KVSList(ws memdb.WatchSet, prefix string) (uint64, structs.DirEntries, error) { tx := s.db.Txn(false) defer tx.Abort() - return s.kvsListTxn(tx, prefix) + return s.kvsListTxn(tx, ws, prefix) } // kvsListTxn is the inner method that gets a list of KVS entries matching a // prefix. -func (s *StateStore) kvsListTxn(tx *memdb.Txn, prefix string) (uint64, structs.DirEntries, error) { +func (s *StateStore) kvsListTxn(tx *memdb.Txn, ws memdb.WatchSet, prefix string) (uint64, structs.DirEntries, error) { // Get the table indexes. 
idx := maxIndexTxn(tx, "kvs", "tombstones") @@ -165,6 +162,7 @@ func (s *StateStore) kvsListTxn(tx *memdb.Txn, prefix string) (uint64, structs.D if err != nil { return 0, nil, fmt.Errorf("failed kvs lookup: %s", err) } + ws.Add(entries.WatchCh()) // Gather all of the keys found in the store var ents structs.DirEntries @@ -203,7 +201,7 @@ func (s *StateStore) kvsListTxn(tx *memdb.Txn, prefix string) (uint64, structs.D // An optional separator may be specified, which can be used to slice off a part // of the response so that only a subset of the prefix is returned. In this // mode, the keys which are omitted are still counted in the returned index. -func (s *StateStore) KVSListKeys(prefix, sep string) (uint64, []string, error) { +func (s *StateStore) KVSListKeys(ws memdb.WatchSet, prefix, sep string) (uint64, []string, error) { tx := s.db.Txn(false) defer tx.Abort() @@ -215,6 +213,7 @@ func (s *StateStore) KVSListKeys(prefix, sep string) (uint64, []string, error) { if err != nil { return 0, nil, fmt.Errorf("failed kvs lookup: %s", err) } + ws.Add(entries.WatchCh()) prefixLen := len(prefix) sepLen := len(sep) @@ -313,7 +312,6 @@ func (s *StateStore) kvsDeleteTxn(tx *memdb.Txn, idx uint64, key string) error { return fmt.Errorf("failed updating index: %s", err) } - tx.Defer(func() { s.kvsWatch.Notify(key, false) }) return nil } @@ -452,7 +450,6 @@ func (s *StateStore) kvsDeleteTreeTxn(tx *memdb.Txn, idx uint64, prefix string) // Update the index if modified { - tx.Defer(func() { s.kvsWatch.Notify(prefix, true) }) if err := tx.Insert("index", &IndexEntry{"kvs", idx}); err != nil { return fmt.Errorf("failed updating index: %s", err) } diff --git a/consul/state/kvs_test.go b/consul/state/kvs_test.go index bd8996a014..e3134563cc 100644 --- a/consul/state/kvs_test.go +++ b/consul/state/kvs_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/go-memdb" ) func TestStateStore_GC(t *testing.T) { @@ -121,7 +122,7 @@ func TestStateStore_ReapTombstones(t *testing.T) { // Pull out the list and check the index, which should come from the // tombstones. - idx, _, err := s.KVSList("foo/") + idx, _, err := s.KVSList(nil, "foo/") if err != nil { t.Fatalf("err: %s", err) } @@ -135,7 +136,7 @@ func TestStateStore_ReapTombstones(t *testing.T) { } // Should still be good because 7 is in there. - idx, _, err = s.KVSList("foo/") + idx, _, err = s.KVSList(nil, "foo/") if err != nil { t.Fatalf("err: %s", err) } @@ -149,7 +150,7 @@ func TestStateStore_ReapTombstones(t *testing.T) { } // At this point the sub index will slide backwards. - idx, _, err = s.KVSList("foo/") + idx, _, err = s.KVSList(nil, "foo/") if err != nil { t.Fatalf("err: %s", err) } @@ -173,7 +174,8 @@ func TestStateStore_KVSSet_KVSGet(t *testing.T) { s := testStateStore(t) // Get on an nonexistent key returns nil. - idx, result, err := s.KVSGet("foo") + ws := memdb.NewWatchSet() + idx, result, err := s.KVSGet(ws, "foo") if result != nil || err != nil || idx != 0 { t.Fatalf("expected (0, nil, nil), got : (%#v, %#v, %#v)", idx, result, err) } @@ -186,9 +188,13 @@ func TestStateStore_KVSSet_KVSGet(t *testing.T) { if err := s.KVSSet(1, entry); err != nil { t.Fatalf("err: %s", err) } + if !watchFired(ws) { + t.Fatalf("bad") + } // Retrieve the K/V entry again. 
- idx, result, err = s.KVSGet("foo") + ws = memdb.NewWatchSet() + idx, result, err = s.KVSGet(ws, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -217,9 +223,13 @@ func TestStateStore_KVSSet_KVSGet(t *testing.T) { if err := s.KVSSet(2, update); err != nil { t.Fatalf("err: %s", err) } + if !watchFired(ws) { + t.Fatalf("bad") + } // Fetch the kv pair and check. - idx, result, err = s.KVSGet("foo") + ws = memdb.NewWatchSet() + idx, result, err = s.KVSGet(ws, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -242,9 +252,13 @@ func TestStateStore_KVSSet_KVSGet(t *testing.T) { if err := s.KVSSet(3, update); err != nil { t.Fatalf("err: %s", err) } + if !watchFired(ws) { + t.Fatalf("bad") + } // Fetch the kv pair and check. - idx, result, err = s.KVSGet("foo") + ws = memdb.NewWatchSet() + idx, result, err = s.KVSGet(ws, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -276,9 +290,13 @@ func TestStateStore_KVSSet_KVSGet(t *testing.T) { if !ok || err != nil { t.Fatalf("didn't get the lock: %v %s", ok, err) } + if !watchFired(ws) { + t.Fatalf("bad") + } // Fetch the kv pair and check. - idx, result, err = s.KVSGet("foo") + ws = memdb.NewWatchSet() + idx, result, err = s.KVSGet(ws, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -304,9 +322,13 @@ func TestStateStore_KVSSet_KVSGet(t *testing.T) { if err := s.KVSSet(7, update); err != nil { t.Fatalf("err: %s", err) } + if !watchFired(ws) { + t.Fatalf("bad") + } // Fetch the kv pair and check. - idx, result, err = s.KVSGet("foo") + ws = memdb.NewWatchSet() + idx, result, err = s.KVSGet(ws, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -323,11 +345,17 @@ func TestStateStore_KVSSet_KVSGet(t *testing.T) { t.Fatalf("bad index: %d", idx) } + // Setting some unrelated key should not fire the watch. + testSetKey(t, s, 8, "other", "yup") + if watchFired(ws) { + t.Fatalf("bad") + } + // Fetch a key that doesn't exist and make sure we get the right // response. - idx, result, err = s.KVSGet("nope") - if result != nil || err != nil || idx != 7 { - t.Fatalf("expected (7, nil, nil), got : (%#v, %#v, %#v)", idx, result, err) + idx, result, err = s.KVSGet(nil, "nope") + if result != nil || err != nil || idx != 8 { + t.Fatalf("expected (8, nil, nil), got : (%#v, %#v, %#v)", idx, result, err) } } @@ -335,7 +363,8 @@ func TestStateStore_KVSList(t *testing.T) { s := testStateStore(t) // Listing an empty KVS returns nothing - idx, entries, err := s.KVSList("") + ws := memdb.NewWatchSet() + idx, entries, err := s.KVSList(ws, "") if idx != 0 || entries != nil || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, entries, err) } @@ -346,9 +375,12 @@ func TestStateStore_KVSList(t *testing.T) { testSetKey(t, s, 3, "foo/bar/zip", "zip") testSetKey(t, s, 4, "foo/bar/zip/zorp", "zorp") testSetKey(t, s, 5, "foo/bar/baz", "baz") + if !watchFired(ws) { + t.Fatalf("bad") + } // List out all of the keys - idx, entries, err = s.KVSList("") + idx, entries, err = s.KVSList(nil, "") if err != nil { t.Fatalf("err: %s", err) } @@ -362,7 +394,7 @@ func TestStateStore_KVSList(t *testing.T) { } // Try listing with a provided prefix - idx, entries, err = s.KVSList("foo/bar/zip") + idx, entries, err = s.KVSList(nil, "foo/bar/zip") if err != nil { t.Fatalf("err: %s", err) } @@ -379,10 +411,19 @@ func TestStateStore_KVSList(t *testing.T) { } // Delete a key and make sure the index comes from the tombstone. 
+ ws = memdb.NewWatchSet() + idx, _, err = s.KVSList(ws, "foo/bar/baz") + if err != nil { + t.Fatalf("err: %s", err) + } if err := s.KVSDelete(6, "foo/bar/baz"); err != nil { t.Fatalf("err: %s", err) } - idx, _, err = s.KVSList("foo/bar/baz") + if !watchFired(ws) { + t.Fatalf("bad") + } + ws = memdb.NewWatchSet() + idx, _, err = s.KVSList(ws, "foo/bar/baz") if err != nil { t.Fatalf("err: %s", err) } @@ -390,11 +431,15 @@ func TestStateStore_KVSList(t *testing.T) { t.Fatalf("bad index: %d", idx) } - // Set a different key to bump the index. + // Set a different key to bump the index. This shouldn't fire the + // watch since there's a different prefix. testSetKey(t, s, 7, "some/other/key", "") + if watchFired(ws) { + t.Fatalf("bad") + } // Make sure we get the right index from the tombstone. - idx, _, err = s.KVSList("foo/bar/baz") + idx, _, err = s.KVSList(nil, "foo/bar/baz") if err != nil { t.Fatalf("err: %s", err) } @@ -407,7 +452,7 @@ func TestStateStore_KVSList(t *testing.T) { if err := s.ReapTombstones(6); err != nil { t.Fatalf("err: %s", err) } - idx, _, err = s.KVSList("foo/bar/baz") + idx, _, err = s.KVSList(nil, "foo/bar/baz") if err != nil { t.Fatalf("err: %s", err) } @@ -416,7 +461,7 @@ func TestStateStore_KVSList(t *testing.T) { } // List all the keys to make sure the index is also correct. - idx, _, err = s.KVSList("") + idx, _, err = s.KVSList(nil, "") if err != nil { t.Fatalf("err: %s", err) } @@ -429,7 +474,8 @@ func TestStateStore_KVSListKeys(t *testing.T) { s := testStateStore(t) // Listing keys with no results returns nil. - idx, keys, err := s.KVSListKeys("", "") + ws := memdb.NewWatchSet() + idx, keys, err := s.KVSListKeys(ws, "", "") if idx != 0 || keys != nil || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, keys, err) } @@ -442,9 +488,12 @@ func TestStateStore_KVSListKeys(t *testing.T) { testSetKey(t, s, 5, "foo/bar/zip/zam", "zam") testSetKey(t, s, 6, "foo/bar/zip/zorp", "zorp") testSetKey(t, s, 7, "some/other/prefix", "nack") + if !watchFired(ws) { + t.Fatalf("bad") + } // List all the keys. - idx, keys, err = s.KVSListKeys("", "") + idx, keys, err = s.KVSListKeys(nil, "", "") if err != nil { t.Fatalf("err: %s", err) } @@ -456,7 +505,7 @@ func TestStateStore_KVSListKeys(t *testing.T) { } // Query using a prefix and pass a separator. - idx, keys, err = s.KVSListKeys("foo/bar/", "/") + idx, keys, err = s.KVSListKeys(nil, "foo/bar/", "/") if err != nil { t.Fatalf("err: %s", err) } @@ -474,7 +523,7 @@ func TestStateStore_KVSListKeys(t *testing.T) { } // Listing keys with no separator returns everything. - idx, keys, err = s.KVSListKeys("foo", "") + idx, keys, err = s.KVSListKeys(nil, "foo", "") if err != nil { t.Fatalf("err: %s", err) } @@ -488,10 +537,19 @@ func TestStateStore_KVSListKeys(t *testing.T) { } // Delete a key and make sure the index comes from the tombstone. + ws = memdb.NewWatchSet() + idx, _, err = s.KVSListKeys(ws, "foo/bar/baz", "") + if err != nil { + t.Fatalf("err: %s", err) + } if err := s.KVSDelete(8, "foo/bar/baz"); err != nil { t.Fatalf("err: %s", err) } - idx, _, err = s.KVSListKeys("foo/bar/baz", "") + if !watchFired(ws) { + t.Fatalf("bad") + } + ws = memdb.NewWatchSet() + idx, _, err = s.KVSListKeys(ws, "foo/bar/baz", "") if err != nil { t.Fatalf("err: %s", err) } @@ -499,11 +557,15 @@ func TestStateStore_KVSListKeys(t *testing.T) { t.Fatalf("bad index: %d", idx) } - // Set a different key to bump the index. + // Set a different key to bump the index. 
This shouldn't fire the watch + // since there's a different prefix. testSetKey(t, s, 9, "some/other/key", "") + if watchFired(ws) { + t.Fatalf("bad") + } // Make sure the index still comes from the tombstone. - idx, _, err = s.KVSListKeys("foo/bar/baz", "") + idx, _, err = s.KVSListKeys(nil, "foo/bar/baz", "") if err != nil { t.Fatalf("err: %s", err) } @@ -516,7 +578,7 @@ func TestStateStore_KVSListKeys(t *testing.T) { if err := s.ReapTombstones(8); err != nil { t.Fatalf("err: %s", err) } - idx, _, err = s.KVSListKeys("foo/bar/baz", "") + idx, _, err = s.KVSListKeys(nil, "foo/bar/baz", "") if err != nil { t.Fatalf("err: %s", err) } @@ -525,7 +587,7 @@ func TestStateStore_KVSListKeys(t *testing.T) { } // List all the keys to make sure the index is also correct. - idx, _, err = s.KVSListKeys("", "") + idx, _, err = s.KVSListKeys(nil, "", "") if err != nil { t.Fatalf("err: %s", err) } @@ -573,7 +635,7 @@ func TestStateStore_KVSDelete(t *testing.T) { // Check that the tombstone was created and that prevents the index // from sliding backwards. - idx, _, err := s.KVSList("foo") + idx, _, err := s.KVSList(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -586,7 +648,7 @@ func TestStateStore_KVSDelete(t *testing.T) { if err := s.ReapTombstones(3); err != nil { t.Fatalf("err: %s", err) } - idx, _, err = s.KVSList("foo") + idx, _, err = s.KVSList(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -620,7 +682,7 @@ func TestStateStore_KVSDeleteCAS(t *testing.T) { // Check that the index is untouched and the entry // has not been deleted. - idx, e, err := s.KVSGet("foo") + idx, e, err := s.KVSGet(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -639,7 +701,7 @@ func TestStateStore_KVSDeleteCAS(t *testing.T) { } // Entry was deleted and index was updated - idx, e, err = s.KVSGet("bar") + idx, e, err = s.KVSGet(nil, "bar") if err != nil { t.Fatalf("err: %s", err) } @@ -655,7 +717,7 @@ func TestStateStore_KVSDeleteCAS(t *testing.T) { // Check that the tombstone was created and that prevents the index // from sliding backwards. - idx, _, err = s.KVSList("bar") + idx, _, err = s.KVSList(nil, "bar") if err != nil { t.Fatalf("err: %s", err) } @@ -668,7 +730,7 @@ func TestStateStore_KVSDeleteCAS(t *testing.T) { if err := s.ReapTombstones(4); err != nil { t.Fatalf("err: %s", err) } - idx, _, err = s.KVSList("bar") + idx, _, err = s.KVSList(nil, "bar") if err != nil { t.Fatalf("err: %s", err) } @@ -733,7 +795,7 @@ func TestStateStore_KVSSetCAS(t *testing.T) { } // Entry was inserted - idx, entry, err := s.KVSGet("foo") + idx, entry, err := s.KVSGet(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -775,7 +837,7 @@ func TestStateStore_KVSSetCAS(t *testing.T) { } // Entry was not updated in the store - idx, entry, err = s.KVSGet("foo") + idx, entry, err = s.KVSGet(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -802,7 +864,7 @@ func TestStateStore_KVSSetCAS(t *testing.T) { } // Entry was updated - idx, entry, err = s.KVSGet("foo") + idx, entry, err = s.KVSGet(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -829,7 +891,7 @@ func TestStateStore_KVSSetCAS(t *testing.T) { } // Entry was updated, but the session should have been ignored. - idx, entry, err = s.KVSGet("foo") + idx, entry, err = s.KVSGet(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -874,7 +936,7 @@ func TestStateStore_KVSSetCAS(t *testing.T) { } // Entry was updated, and the lock status should have stayed the same. 
- idx, entry, err = s.KVSGet("foo") + idx, entry, err = s.KVSGet(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -938,7 +1000,7 @@ func TestStateStore_KVSDeleteTree(t *testing.T) { // Check that the tombstones ware created and that prevents the index // from sliding backwards. - idx, _, err := s.KVSList("foo") + idx, _, err := s.KVSList(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -951,7 +1013,7 @@ func TestStateStore_KVSDeleteTree(t *testing.T) { if err := s.ReapTombstones(5); err != nil { t.Fatalf("err: %s", err) } - idx, _, err = s.KVSList("foo") + idx, _, err = s.KVSList(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -1000,7 +1062,7 @@ func TestStateStore_KVSLock(t *testing.T) { } // Make sure the indexes got set properly. - idx, result, err := s.KVSGet("foo") + idx, result, err := s.KVSGet(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -1021,7 +1083,7 @@ func TestStateStore_KVSLock(t *testing.T) { // Make sure the indexes got set properly, note that the lock index // won't go up since we didn't lock it again. - idx, result, err = s.KVSGet("foo") + idx, result, err = s.KVSGet(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -1044,7 +1106,7 @@ func TestStateStore_KVSLock(t *testing.T) { } // Make sure the indexes got set properly. - idx, result, err = s.KVSGet("foo") + idx, result, err = s.KVSGet(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -1064,7 +1126,7 @@ func TestStateStore_KVSLock(t *testing.T) { } // Make sure the indexes got set properly. - idx, result, err = s.KVSGet("bar") + idx, result, err = s.KVSGet(nil, "bar") if err != nil { t.Fatalf("err: %s", err) } @@ -1090,7 +1152,7 @@ func TestStateStore_KVSLock(t *testing.T) { } // Make sure the indexes didn't update. - idx, result, err = s.KVSGet("bar") + idx, result, err = s.KVSGet(nil, "bar") if err != nil { t.Fatalf("err: %s", err) } @@ -1134,7 +1196,7 @@ func TestStateStore_KVSUnlock(t *testing.T) { } // Make sure the indexes didn't update. - idx, result, err := s.KVSGet("foo") + idx, result, err := s.KVSGet(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -1163,7 +1225,7 @@ func TestStateStore_KVSUnlock(t *testing.T) { } // Make sure the indexes didn't update. - idx, result, err = s.KVSGet("foo") + idx, result, err = s.KVSGet(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -1182,7 +1244,7 @@ func TestStateStore_KVSUnlock(t *testing.T) { } // Make sure the indexes got set properly. - idx, result, err = s.KVSGet("foo") + idx, result, err = s.KVSGet(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -1201,7 +1263,7 @@ func TestStateStore_KVSUnlock(t *testing.T) { } // Make sure the indexes didn't update. - idx, result, err = s.KVSGet("foo") + idx, result, err = s.KVSGet(nil, "foo") if err != nil { t.Fatalf("err: %s", err) } @@ -1294,7 +1356,7 @@ func TestStateStore_KVS_Snapshot_Restore(t *testing.T) { restore.Commit() // Read the restored keys back out and verify they match. - idx, res, err := s.KVSList("") + idx, res, err := s.KVSList(nil, "") if err != nil { t.Fatalf("err: %s", err) } @@ -1312,142 +1374,6 @@ func TestStateStore_KVS_Snapshot_Restore(t *testing.T) { }() } -func TestStateStore_KVS_Watches(t *testing.T) { - s := testStateStore(t) - - // This is used when locking down below. 
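KVSGet now threads a WatchSet down to FirstWatch, so a single key (present or not) can be watched directly, and callers that do not need notifications can pass nil, as many of the tests above now do. A sketch of the single-key case, with the helper name and settle time assumed:

```go
package state_test

import (
	"time"

	"github.com/hashicorp/consul/consul/state"
	"github.com/hashicorp/go-memdb"
)

// keyChanged reads one key while registering its watch channel, then reports
// whether anything touched that key (or its absence) within one second.
func keyChanged(s *state.StateStore, key string) (bool, error) {
	ws := memdb.NewWatchSet()
	if _, _, err := s.KVSGet(ws, key); err != nil {
		return false, err
	}
	// A true result from Watch means the timeout won, i.e. no change.
	timedOut := ws.Watch(time.After(time.Second))
	return !timedOut, nil
}
```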
- testRegisterNode(t, s, 1, "node1") - session := testUUID() - if err := s.SessionCreate(2, &structs.Session{ID: session, Node: "node1"}); err != nil { - t.Fatalf("err: %s", err) - } - - // An empty prefix watch should hit on all KVS ops, and some other - // prefix should not be affected ever. We also add a positive prefix - // match. - verifyWatch(t, s.GetKVSWatch(""), func() { - verifyWatch(t, s.GetKVSWatch("a"), func() { - verifyNoWatch(t, s.GetKVSWatch("/nope"), func() { - if err := s.KVSSet(1, &structs.DirEntry{Key: "aaa"}); err != nil { - t.Fatalf("err: %s", err) - } - }) - }) - }) - verifyWatch(t, s.GetKVSWatch(""), func() { - verifyWatch(t, s.GetKVSWatch("a"), func() { - verifyNoWatch(t, s.GetKVSWatch("/nope"), func() { - if err := s.KVSSet(2, &structs.DirEntry{Key: "aaa"}); err != nil { - t.Fatalf("err: %s", err) - } - }) - }) - }) - - // Restore just fires off a top-level watch, so we should get hits on - // any prefix, including ones for keys that aren't in there. - verifyWatch(t, s.GetKVSWatch(""), func() { - verifyWatch(t, s.GetKVSWatch("b"), func() { - verifyWatch(t, s.GetKVSWatch("/nope"), func() { - restore := s.Restore() - if err := restore.KVS(&structs.DirEntry{Key: "bbb"}); err != nil { - t.Fatalf("err: %s", err) - } - restore.Commit() - }) - }) - }) - - verifyWatch(t, s.GetKVSWatch(""), func() { - verifyWatch(t, s.GetKVSWatch("a"), func() { - verifyNoWatch(t, s.GetKVSWatch("/nope"), func() { - if err := s.KVSDelete(3, "aaa"); err != nil { - t.Fatalf("err: %s", err) - } - }) - }) - }) - verifyWatch(t, s.GetKVSWatch(""), func() { - verifyWatch(t, s.GetKVSWatch("a"), func() { - verifyNoWatch(t, s.GetKVSWatch("/nope"), func() { - if ok, err := s.KVSSetCAS(4, &structs.DirEntry{Key: "aaa"}); !ok || err != nil { - t.Fatalf("ok: %v err: %s", ok, err) - } - }) - }) - }) - verifyWatch(t, s.GetKVSWatch(""), func() { - verifyWatch(t, s.GetKVSWatch("a"), func() { - verifyNoWatch(t, s.GetKVSWatch("/nope"), func() { - if ok, err := s.KVSLock(5, &structs.DirEntry{Key: "aaa", Session: session}); !ok || err != nil { - t.Fatalf("ok: %v err: %s", ok, err) - } - }) - }) - }) - verifyWatch(t, s.GetKVSWatch(""), func() { - verifyWatch(t, s.GetKVSWatch("a"), func() { - verifyNoWatch(t, s.GetKVSWatch("/nope"), func() { - if ok, err := s.KVSUnlock(6, &structs.DirEntry{Key: "aaa", Session: session}); !ok || err != nil { - t.Fatalf("ok: %v err: %s", ok, err) - } - }) - }) - }) - verifyWatch(t, s.GetKVSWatch(""), func() { - verifyWatch(t, s.GetKVSWatch("a"), func() { - verifyNoWatch(t, s.GetKVSWatch("/nope"), func() { - if err := s.KVSDeleteTree(7, "aaa"); err != nil { - t.Fatalf("err: %s", err) - } - }) - }) - }) - - // A delete tree operation at the top level will notify all the watches. - verifyWatch(t, s.GetKVSWatch(""), func() { - verifyWatch(t, s.GetKVSWatch("a"), func() { - verifyWatch(t, s.GetKVSWatch("/nope"), func() { - if err := s.KVSDeleteTree(8, ""); err != nil { - t.Fatalf("err: %s", err) - } - }) - }) - }) - - // Create a more interesting tree. - testSetKey(t, s, 9, "foo/bar", "bar") - testSetKey(t, s, 10, "foo/bar/baz", "baz") - testSetKey(t, s, 11, "foo/bar/zip", "zip") - testSetKey(t, s, 12, "foo/zorp", "zorp") - - // Deleting just the foo/bar key should not trigger watches on the - // children. 
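The prefix-scoped notifications that the GetKVSWatch tests being removed here exercised are now expressed by registering a WatchSet through KVSList on the prefix of interest; writes outside the prefix leave the watch alone, as the updated list tests verify. Below is a sketch of the min-index style read that blocking callers can build on top of that, with the helper name, loop, and ten-second deadline all assumed:

```go
package state_test

import (
	"time"

	"github.com/hashicorp/consul/consul/state"
	"github.com/hashicorp/consul/consul/structs"
	"github.com/hashicorp/go-memdb"
)

// waitForPrefix re-runs KVSList with a fresh WatchSet until its index moves
// past a known value or the deadline passes, then returns the latest view.
func waitForPrefix(s *state.StateStore, prefix string, minIndex uint64) (uint64, structs.DirEntries, error) {
	timeoutCh := time.After(10 * time.Second)
	for {
		ws := memdb.NewWatchSet()
		idx, entries, err := s.KVSList(ws, prefix)
		if err != nil {
			return 0, nil, err
		}
		if idx > minIndex {
			return idx, entries, nil
		}
		// Block until something under the prefix changes or we give up.
		if timedOut := ws.Watch(timeoutCh); timedOut {
			return idx, entries, nil
		}
	}
}
```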
- verifyWatch(t, s.GetKVSWatch("foo/bar"), func() { - verifyNoWatch(t, s.GetKVSWatch("foo/bar/baz"), func() { - verifyNoWatch(t, s.GetKVSWatch("foo/bar/zip"), func() { - if err := s.KVSDelete(13, "foo/bar"); err != nil { - t.Fatalf("err: %s", err) - } - }) - }) - }) - - // But a delete tree from that point should notify the whole subtree, - // even for keys that don't exist. - verifyWatch(t, s.GetKVSWatch("foo/bar"), func() { - verifyWatch(t, s.GetKVSWatch("foo/bar/baz"), func() { - verifyWatch(t, s.GetKVSWatch("foo/bar/zip"), func() { - verifyWatch(t, s.GetKVSWatch("foo/bar/uh/nope"), func() { - if err := s.KVSDeleteTree(14, "foo/bar"); err != nil { - t.Fatalf("err: %s", err) - } - }) - }) - }) - }) -} - func TestStateStore_Tombstone_Snapshot_Restore(t *testing.T) { s := testStateStore(t) @@ -1467,7 +1393,7 @@ func TestStateStore_Tombstone_Snapshot_Restore(t *testing.T) { if err := s.ReapTombstones(4); err != nil { t.Fatalf("err: %s", err) } - idx, _, err := s.KVSList("foo/bar") + idx, _, err := s.KVSList(nil, "foo/bar") if err != nil { t.Fatalf("err: %s", err) } @@ -1504,7 +1430,7 @@ func TestStateStore_Tombstone_Snapshot_Restore(t *testing.T) { restore.Commit() // See if the stone works properly in a list query. - idx, _, err := s.KVSList("foo/bar") + idx, _, err := s.KVSList(nil, "foo/bar") if err != nil { t.Fatalf("err: %s", err) } @@ -1518,7 +1444,7 @@ func TestStateStore_Tombstone_Snapshot_Restore(t *testing.T) { if err := s.ReapTombstones(4); err != nil { t.Fatalf("err: %s", err) } - idx, _, err = s.KVSList("foo/bar") + idx, _, err = s.KVSList(nil, "foo/bar") if err != nil { t.Fatalf("err: %s", err) } diff --git a/consul/state/prepared_query.go b/consul/state/prepared_query.go index c84496fbdd..c9c4f7e13a 100644 --- a/consul/state/prepared_query.go +++ b/consul/state/prepared_query.go @@ -75,7 +75,6 @@ func (s *StateRestore) PreparedQuery(query *structs.PreparedQuery) error { return fmt.Errorf("failed updating index: %s", err) } - s.watches.Arm("prepared-queries") return nil } @@ -193,7 +192,6 @@ func (s *StateStore) preparedQuerySetTxn(tx *memdb.Txn, idx uint64, query *struc return fmt.Errorf("failed updating index: %s", err) } - tx.Defer(func() { s.tableWatches["prepared-queries"].Notify() }) return nil } @@ -202,20 +200,17 @@ func (s *StateStore) PreparedQueryDelete(idx uint64, queryID string) error { tx := s.db.Txn(true) defer tx.Abort() - watches := NewDumbWatchManager(s.tableWatches) - if err := s.preparedQueryDeleteTxn(tx, idx, watches, queryID); err != nil { + if err := s.preparedQueryDeleteTxn(tx, idx, queryID); err != nil { return fmt.Errorf("failed prepared query delete: %s", err) } - tx.Defer(func() { watches.Notify() }) tx.Commit() return nil } // preparedQueryDeleteTxn is the inner method used to delete a prepared query // with the proper indexes into the state store. -func (s *StateStore) preparedQueryDeleteTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, - queryID string) error { +func (s *StateStore) preparedQueryDeleteTxn(tx *memdb.Txn, idx uint64, queryID string) error { // Pull the query. wrapped, err := tx.First("prepared-queries", "id", queryID) if err != nil { @@ -233,23 +228,23 @@ func (s *StateStore) preparedQueryDeleteTxn(tx *memdb.Txn, idx uint64, watches * return fmt.Errorf("failed updating index: %s", err) } - watches.Arm("prepared-queries") return nil } // PreparedQueryGet returns the given prepared query by ID. 
-func (s *StateStore) PreparedQueryGet(queryID string) (uint64, *structs.PreparedQuery, error) { +func (s *StateStore) PreparedQueryGet(ws memdb.WatchSet, queryID string) (uint64, *structs.PreparedQuery, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("PreparedQueryGet")...) + idx := maxIndexTxn(tx, "prepared-queries") // Look up the query by its ID. - wrapped, err := tx.First("prepared-queries", "id", queryID) + watchCh, wrapped, err := tx.FirstWatch("prepared-queries", "id", queryID) if err != nil { return 0, nil, fmt.Errorf("failed prepared query lookup: %s", err) } + ws.Add(watchCh) return idx, toPreparedQuery(wrapped), nil } @@ -261,7 +256,7 @@ func (s *StateStore) PreparedQueryResolve(queryIDOrName string) (uint64, *struct defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("PreparedQueryResolve")...) + idx := maxIndexTxn(tx, "prepared-queries") // Explicitly ban an empty query. This will never match an ID and the // schema is set up so it will never match a query with an empty name, @@ -331,18 +326,19 @@ func (s *StateStore) PreparedQueryResolve(queryIDOrName string) (uint64, *struct } // PreparedQueryList returns all the prepared queries. -func (s *StateStore) PreparedQueryList() (uint64, structs.PreparedQueries, error) { +func (s *StateStore) PreparedQueryList(ws memdb.WatchSet) (uint64, structs.PreparedQueries, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("PreparedQueryList")...) + idx := maxIndexTxn(tx, "prepared-queries") // Query all of the prepared queries in the state store. queries, err := tx.Get("prepared-queries", "id") if err != nil { return 0, nil, fmt.Errorf("failed prepared query lookup: %s", err) } + ws.Add(queries.WatchCh()) // Go over all of the queries and build the response. var result structs.PreparedQueries diff --git a/consul/state/prepared_query_test.go b/consul/state/prepared_query_test.go index c0581986be..b42bde8e79 100644 --- a/consul/state/prepared_query_test.go +++ b/consul/state/prepared_query_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/go-memdb" ) func TestStateStore_PreparedQuery_isUUID(t *testing.T) { @@ -37,7 +38,8 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { s := testStateStore(t) // Querying with no results returns nil. - idx, res, err := s.PreparedQueryGet(testUUID()) + ws := memdb.NewWatchSet() + idx, res, err := s.PreparedQueryGet(ws, testUUID()) if idx != 0 || res != nil || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) } @@ -51,6 +53,9 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { if idx := s.maxIndex("prepared-queries"); idx != 0 { t.Fatalf("bad index: %d", idx) } + if watchFired(ws) { + t.Fatalf("bad") + } // Build a legit-looking query with the most basic options. query := &structs.PreparedQuery{ @@ -71,6 +76,9 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { if idx := s.maxIndex("prepared-queries"); idx != 0 { t.Fatalf("bad index: %d", idx) } + if watchFired(ws) { + t.Fatalf("bad") + } // Now register the service and remove the bogus session. 
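For orientation, the caller-side pattern that these new WatchSet parameters enable looks roughly like the sketch below. This is illustrative only: the UUID and timeout are arbitrary, the store is built the same way the tests build one, and (as the converted tests rely on) passing a nil WatchSet instead simply skips watch registration, since go-memdb treats Add on a nil set as a no-op.

package main

import (
	"fmt"
	"time"

	"github.com/hashicorp/consul/consul/state"
	"github.com/hashicorp/go-memdb"
)

func main() {
	// An empty store with no tombstone GC, the same way the tests build one.
	s, err := state.NewStateStore(nil)
	if err != nil {
		panic(err)
	}

	// Run the query once, registering its watch channels on a fresh set.
	ws := memdb.NewWatchSet()
	idx, query, err := s.PreparedQueryGet(ws, "8f246b77-f3e1-ff88-5b48-8ec93abf3e05")
	fmt.Println(idx, query, err)

	// Block until one of the watched channels fires or the timeout hits;
	// Watch reports true on timeout, false when something changed and the
	// query should be re-run.
	timedOut := ws.Watch(time.After(100 * time.Millisecond))
	fmt.Println("timed out:", timedOut)
}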
testRegisterNode(t, s, 1, "foo") @@ -86,6 +94,9 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { if idx := s.maxIndex("prepared-queries"); idx != 3 { t.Fatalf("bad index: %d", idx) } + if !watchFired(ws) { + t.Fatalf("bad") + } // Read it back out and verify it. expected := &structs.PreparedQuery{ @@ -98,7 +109,8 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { ModifyIndex: 3, }, } - idx, actual, err := s.PreparedQueryGet(query.ID) + ws = memdb.NewWatchSet() + idx, actual, err := s.PreparedQueryGet(ws, query.ID) if err != nil { t.Fatalf("err: %s", err) } @@ -119,11 +131,15 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { if idx := s.maxIndex("prepared-queries"); idx != 4 { t.Fatalf("bad index: %d", idx) } + if !watchFired(ws) { + t.Fatalf("bad") + } // Read it back and verify the data was updated as well as the index. expected.Name = "test-query" expected.ModifyIndex = 4 - idx, actual, err = s.PreparedQueryGet(query.ID) + ws = memdb.NewWatchSet() + idx, actual, err = s.PreparedQueryGet(ws, query.ID) if err != nil { t.Fatalf("err: %s", err) } @@ -145,6 +161,9 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { if idx := s.maxIndex("prepared-queries"); idx != 4 { t.Fatalf("bad index: %d", idx) } + if watchFired(ws) { + t.Fatalf("bad") + } // Now make a session and try again. session := &structs.Session{ @@ -162,11 +181,15 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { if idx := s.maxIndex("prepared-queries"); idx != 6 { t.Fatalf("bad index: %d", idx) } + if !watchFired(ws) { + t.Fatalf("bad") + } // Read it back and verify the data was updated as well as the index. expected.Session = query.Session expected.ModifyIndex = 6 - idx, actual, err = s.PreparedQueryGet(query.ID) + ws = memdb.NewWatchSet() + idx, actual, err = s.PreparedQueryGet(ws, query.ID) if err != nil { t.Fatalf("err: %s", err) } @@ -192,7 +215,7 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { } // Sanity check to make sure it's not there. - idx, actual, err := s.PreparedQueryGet(evil.ID) + idx, actual, err := s.PreparedQueryGet(nil, evil.ID) if err != nil { t.Fatalf("err: %s", err) } @@ -220,7 +243,7 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { } // Sanity check to make sure it's not there. - idx, actual, err := s.PreparedQueryGet(evil.ID) + idx, actual, err := s.PreparedQueryGet(nil, evil.ID) if err != nil { t.Fatalf("err: %s", err) } @@ -250,7 +273,7 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { } // Sanity check to make sure it's not there. - idx, actual, err := s.PreparedQueryGet(evil.ID) + idx, actual, err := s.PreparedQueryGet(nil, evil.ID) if err != nil { t.Fatalf("err: %s", err) } @@ -266,6 +289,9 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { if idx := s.maxIndex("prepared-queries"); idx != 6 { t.Fatalf("bad index: %d", idx) } + if watchFired(ws) { + t.Fatalf("bad") + } // Turn the query into a template with an empty name. query.Name = "" @@ -280,6 +306,9 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { if idx := s.maxIndex("prepared-queries"); idx != 9 { t.Fatalf("bad index: %d", idx) } + if !watchFired(ws) { + t.Fatalf("bad") + } // Read it back and verify the data was updated as well as the index. 
expected.Name = "" @@ -287,7 +316,8 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { Type: structs.QueryTemplateTypeNamePrefixMatch, } expected.ModifyIndex = 9 - idx, actual, err = s.PreparedQueryGet(query.ID) + ws = memdb.NewWatchSet() + idx, actual, err = s.PreparedQueryGet(ws, query.ID) if err != nil { t.Fatalf("err: %s", err) } @@ -316,7 +346,7 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { } // Sanity check to make sure it's not there. - idx, actual, err := s.PreparedQueryGet(evil.ID) + idx, actual, err := s.PreparedQueryGet(nil, evil.ID) if err != nil { t.Fatalf("err: %s", err) } @@ -338,11 +368,15 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { if idx := s.maxIndex("prepared-queries"); idx != 11 { t.Fatalf("bad index: %d", idx) } + if !watchFired(ws) { + t.Fatalf("bad") + } // Read it back and verify the data was updated as well as the index. expected.Name = "prefix" expected.ModifyIndex = 11 - idx, actual, err = s.PreparedQueryGet(query.ID) + ws = memdb.NewWatchSet() + idx, actual, err = s.PreparedQueryGet(ws, query.ID) if err != nil { t.Fatalf("err: %s", err) } @@ -371,7 +405,7 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { } // Sanity check to make sure it's not there. - idx, actual, err := s.PreparedQueryGet(evil.ID) + idx, actual, err := s.PreparedQueryGet(nil, evil.ID) if err != nil { t.Fatalf("err: %s", err) } @@ -401,7 +435,7 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { } // Sanity check to make sure it's not there. - idx, actual, err := s.PreparedQueryGet(evil.ID) + idx, actual, err := s.PreparedQueryGet(nil, evil.ID) if err != nil { t.Fatalf("err: %s", err) } @@ -412,6 +446,10 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { t.Fatalf("bad: %v", actual) } } + + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_PreparedQueryDelete(t *testing.T) { @@ -460,7 +498,8 @@ func TestStateStore_PreparedQueryDelete(t *testing.T) { ModifyIndex: 3, }, } - idx, actual, err := s.PreparedQueryGet(query.ID) + ws := memdb.NewWatchSet() + idx, actual, err := s.PreparedQueryGet(ws, query.ID) if err != nil { t.Fatalf("err: %s", err) } @@ -480,9 +519,12 @@ func TestStateStore_PreparedQueryDelete(t *testing.T) { if idx := s.maxIndex("prepared-queries"); idx != 4 { t.Fatalf("bad index: %d", idx) } + if !watchFired(ws) { + t.Fatalf("bad") + } // Sanity check to make sure it's not there. - idx, actual, err = s.PreparedQueryGet(query.ID) + idx, actual, err = s.PreparedQueryGet(nil, query.ID) if err != nil { t.Fatalf("err: %s", err) } @@ -716,7 +758,8 @@ func TestStateStore_PreparedQueryList(t *testing.T) { s := testStateStore(t) // Make sure nothing is returned for an empty query - idx, actual, err := s.PreparedQueryList() + ws := memdb.NewWatchSet() + idx, actual, err := s.PreparedQueryList(ws) if err != nil { t.Fatalf("err: %s", err) } @@ -761,6 +804,9 @@ func TestStateStore_PreparedQueryList(t *testing.T) { t.Fatalf("err: %s", err) } } + if !watchFired(ws) { + t.Fatalf("bad") + } // Read it back and verify. expected := structs.PreparedQueries{ @@ -787,7 +833,7 @@ func TestStateStore_PreparedQueryList(t *testing.T) { }, }, } - idx, actual, err = s.PreparedQueryList() + idx, actual, err = s.PreparedQueryList(nil) if err != nil { t.Fatalf("err: %s", err) } @@ -901,7 +947,7 @@ func TestStateStore_PreparedQuery_Snapshot_Restore(t *testing.T) { // Read the restored queries back out and verify that they // match. 
- idx, actual, err := s.PreparedQueryList() + idx, actual, err := s.PreparedQueryList(nil) if err != nil { t.Fatalf("err: %s", err) } @@ -926,38 +972,3 @@ func TestStateStore_PreparedQuery_Snapshot_Restore(t *testing.T) { } }() } - -func TestStateStore_PreparedQuery_Watches(t *testing.T) { - s := testStateStore(t) - - // Set up our test environment. - testRegisterNode(t, s, 1, "foo") - testRegisterService(t, s, 2, "foo", "redis") - - query := &structs.PreparedQuery{ - ID: testUUID(), - Service: structs.ServiceQuery{ - Service: "redis", - }, - } - - // Call functions that update the queries table and make sure a watch - // fires each time. - verifyWatch(t, s.getTableWatch("prepared-queries"), func() { - if err := s.PreparedQuerySet(3, query); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("prepared-queries"), func() { - if err := s.PreparedQueryDelete(4, query.ID); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("prepared-queries"), func() { - restore := s.Restore() - if err := restore.PreparedQuery(query); err != nil { - t.Fatalf("err: %s", err) - } - restore.Commit() - }) -} diff --git a/consul/state/schema.go b/consul/state/schema.go index 3d3ed14503..cf701d64e6 100644 --- a/consul/state/schema.go +++ b/consul/state/schema.go @@ -188,6 +188,22 @@ func checksTableSchema() *memdb.TableSchema { Lowercase: true, }, }, + "node_service_check": &memdb.IndexSchema{ + Name: "node_service_check", + AllowMissing: true, + Unique: false, + Indexer: &memdb.CompoundIndex{ + Indexes: []memdb.Indexer{ + &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + &memdb.FieldSetIndex{ + Field: "ServiceID", + }, + }, + }, + }, "node_service": &memdb.IndexSchema{ Name: "node_service", AllowMissing: true, diff --git a/consul/state/session.go b/consul/state/session.go index 08e6c521df..78442146dd 100644 --- a/consul/state/session.go +++ b/consul/state/session.go @@ -42,7 +42,6 @@ func (s *StateRestore) Session(sess *structs.Session) error { return fmt.Errorf("failed updating index: %s", err) } - s.watches.Arm("sessions") return nil } @@ -140,23 +139,23 @@ func (s *StateStore) sessionCreateTxn(tx *memdb.Txn, idx uint64, sess *structs.S return fmt.Errorf("failed updating index: %s", err) } - tx.Defer(func() { s.tableWatches["sessions"].Notify() }) return nil } // SessionGet is used to retrieve an active session from the state store. -func (s *StateStore) SessionGet(sessionID string) (uint64, *structs.Session, error) { +func (s *StateStore) SessionGet(ws memdb.WatchSet, sessionID string) (uint64, *structs.Session, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("SessionGet")...) + idx := maxIndexTxn(tx, "sessions") // Look up the session by its ID - session, err := tx.First("sessions", "id", sessionID) + watchCh, session, err := tx.FirstWatch("sessions", "id", sessionID) if err != nil { return 0, nil, fmt.Errorf("failed session lookup: %s", err) } + ws.Add(watchCh) if session != nil { return idx, session.(*structs.Session), nil } @@ -164,18 +163,19 @@ func (s *StateStore) SessionGet(sessionID string) (uint64, *structs.Session, err } // SessionList returns a slice containing all of the active sessions. -func (s *StateStore) SessionList() (uint64, structs.Sessions, error) { +func (s *StateStore) SessionList(ws memdb.WatchSet) (uint64, structs.Sessions, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. 
- idx := maxIndexTxn(tx, s.getWatchTables("SessionList")...) + idx := maxIndexTxn(tx, "sessions") // Query all of the active sessions. sessions, err := tx.Get("sessions", "id") if err != nil { return 0, nil, fmt.Errorf("failed session lookup: %s", err) } + ws.Add(sessions.WatchCh()) // Go over the sessions and create a slice of them. var result structs.Sessions @@ -188,18 +188,19 @@ func (s *StateStore) SessionList() (uint64, structs.Sessions, error) { // NodeSessions returns a set of active sessions associated // with the given node ID. The returned index is the highest // index seen from the result set. -func (s *StateStore) NodeSessions(nodeID string) (uint64, structs.Sessions, error) { +func (s *StateStore) NodeSessions(ws memdb.WatchSet, nodeID string) (uint64, structs.Sessions, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("NodeSessions")...) + idx := maxIndexTxn(tx, "sessions") // Get all of the sessions which belong to the node sessions, err := tx.Get("sessions", "node", nodeID) if err != nil { return 0, nil, fmt.Errorf("failed session lookup: %s", err) } + ws.Add(sessions.WatchCh()) // Go over all of the sessions and return them as a slice var result structs.Sessions @@ -217,19 +218,17 @@ func (s *StateStore) SessionDestroy(idx uint64, sessionID string) error { defer tx.Abort() // Call the session deletion. - watches := NewDumbWatchManager(s.tableWatches) - if err := s.deleteSessionTxn(tx, idx, watches, sessionID); err != nil { + if err := s.deleteSessionTxn(tx, idx, sessionID); err != nil { return err } - tx.Defer(func() { watches.Notify() }) tx.Commit() return nil } // deleteSessionTxn is the inner method, which is used to do the actual -// session deletion and handle session invalidation, watch triggers, etc. -func (s *StateStore) deleteSessionTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, sessionID string) error { +// session deletion and handle session invalidation, etc. +func (s *StateStore) deleteSessionTxn(tx *memdb.Txn, idx uint64, sessionID string) error { // Look up the session. sess, err := tx.First("sessions", "id", sessionID) if err != nil { @@ -334,12 +333,11 @@ func (s *StateStore) deleteSessionTxn(tx *memdb.Txn, idx uint64, watches *DumbWa // Do the delete in a separate loop so we don't trash the iterator. 
for _, id := range ids { - if err := s.preparedQueryDeleteTxn(tx, idx, watches, id); err != nil { + if err := s.preparedQueryDeleteTxn(tx, idx, id); err != nil { return fmt.Errorf("failed prepared query delete: %s", err) } } } - watches.Arm("sessions") return nil } diff --git a/consul/state/session_test.go b/consul/state/session_test.go index 3e435e7e16..e3280ae167 100644 --- a/consul/state/session_test.go +++ b/consul/state/session_test.go @@ -9,13 +9,15 @@ import ( "github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/consul/types" + "github.com/hashicorp/go-memdb" ) func TestStateStore_SessionCreate_SessionGet(t *testing.T) { s := testStateStore(t) // SessionGet returns nil if the session doesn't exist - idx, session, err := s.SessionGet(testUUID()) + ws := memdb.NewWatchSet() + idx, session, err := s.SessionGet(ws, testUUID()) if session != nil || err != nil { t.Fatalf("expected (nil, nil), got: (%#v, %#v)", session, err) } @@ -49,6 +51,9 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) { if idx := s.maxIndex("sessions"); idx != 0 { t.Fatalf("bad index: %d", idx) } + if watchFired(ws) { + t.Fatalf("bad") + } // Valid session is able to register testRegisterNode(t, s, 1, "node1") @@ -62,9 +67,13 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) { if idx := s.maxIndex("sessions"); idx != 2 { t.Fatalf("bad index: %s", err) } + if !watchFired(ws) { + t.Fatalf("bad") + } // Retrieve the session again - idx, session, err = s.SessionGet(sess.ID) + ws = memdb.NewWatchSet() + idx, session, err = s.SessionGet(ws, sess.ID) if err != nil { t.Fatalf("err: %s", err) } @@ -104,12 +113,19 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) { if err == nil || !strings.Contains(err.Error(), structs.HealthCritical) { t.Fatalf("expected critical state error, got: %#v", err) } + if watchFired(ws) { + t.Fatalf("bad") + } - // Registering with a healthy check succeeds + // Registering with a healthy check succeeds (doesn't hit the watch since + // we are looking at the old session). testRegisterCheck(t, s, 4, "node1", "", "check1", structs.HealthPassing) if err := s.SessionCreate(5, sess); err != nil { t.Fatalf("err: %s", err) } + if watchFired(ws) { + t.Fatalf("bad") + } // Register a session against two checks. testRegisterCheck(t, s, 5, "node1", "", "check2", structs.HealthPassing) @@ -159,7 +175,7 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) { } // Pulling a nonexistent session gives the table index. 
- idx, session, err = s.SessionGet(testUUID()) + idx, session, err = s.SessionGet(nil, testUUID()) if err != nil { t.Fatalf("err: %s", err) } @@ -175,7 +191,8 @@ func TegstStateStore_SessionList(t *testing.T) { s := testStateStore(t) // Listing when no sessions exist returns nil - idx, res, err := s.SessionList() + ws := memdb.NewWatchSet() + idx, res, err := s.SessionList(ws) if idx != 0 || res != nil || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) } @@ -208,9 +225,12 @@ func TegstStateStore_SessionList(t *testing.T) { t.Fatalf("err: %s", err) } } + if !watchFired(ws) { + t.Fatalf("bad") + } // List out all of the sessions - idx, sessionList, err := s.SessionList() + idx, sessionList, err := s.SessionList(nil) if err != nil { t.Fatalf("err: %s", err) } @@ -226,7 +246,8 @@ func TestStateStore_NodeSessions(t *testing.T) { s := testStateStore(t) // Listing sessions with no results returns nil - idx, res, err := s.NodeSessions("node1") + ws := memdb.NewWatchSet() + idx, res, err := s.NodeSessions(ws, "node1") if idx != 0 || res != nil || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) } @@ -261,10 +282,14 @@ func TestStateStore_NodeSessions(t *testing.T) { t.Fatalf("err: %s", err) } } + if !watchFired(ws) { + t.Fatalf("bad") + } // Query all of the sessions associated with a specific // node in the state store. - idx, res, err = s.NodeSessions("node1") + ws1 := memdb.NewWatchSet() + idx, res, err = s.NodeSessions(ws1, "node1") if err != nil { t.Fatalf("err: %s", err) } @@ -275,7 +300,8 @@ func TestStateStore_NodeSessions(t *testing.T) { t.Fatalf("bad index: %d", idx) } - idx, res, err = s.NodeSessions("node2") + ws2 := memdb.NewWatchSet() + idx, res, err = s.NodeSessions(ws2, "node2") if err != nil { t.Fatalf("err: %s", err) } @@ -285,6 +311,17 @@ func TestStateStore_NodeSessions(t *testing.T) { if idx != 6 { t.Fatalf("bad index: %d", idx) } + + // Destroying a session on node1 should not affect node2's watch. + if err := s.SessionDestroy(100, sessions1[0].ID); err != nil { + t.Fatalf("err: %s", err) + } + if !watchFired(ws1) { + t.Fatalf("bad") + } + if watchFired(ws2) { + t.Fatalf("bad") + } } func TestStateStore_SessionDestroy(t *testing.T) { @@ -418,7 +455,7 @@ func TestStateStore_Session_Snapshot_Restore(t *testing.T) { // Read the restored sessions back out and verify that they // match. - idx, res, err := s.SessionList() + idx, res, err := s.SessionList(nil) if err != nil { t.Fatalf("err: %s", err) } @@ -467,44 +504,6 @@ func TestStateStore_Session_Snapshot_Restore(t *testing.T) { }() } -func TestStateStore_Session_Watches(t *testing.T) { - s := testStateStore(t) - - // Register a test node. - testRegisterNode(t, s, 1, "node1") - - // This just covers the basics. The session invalidation tests above - // cover the more nuanced multiple table watches. 
- session := testUUID() - verifyWatch(t, s.getTableWatch("sessions"), func() { - sess := &structs.Session{ - ID: session, - Node: "node1", - Behavior: structs.SessionKeysDelete, - } - if err := s.SessionCreate(2, sess); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("sessions"), func() { - if err := s.SessionDestroy(3, session); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("sessions"), func() { - restore := s.Restore() - sess := &structs.Session{ - ID: session, - Node: "node1", - Behavior: structs.SessionKeysDelete, - } - if err := restore.Session(sess); err != nil { - t.Fatalf("err: %s", err) - } - restore.Commit() - }) -} - func TestStateStore_Session_Invalidate_DeleteNode(t *testing.T) { s := testStateStore(t) @@ -520,17 +519,21 @@ func TestStateStore_Session_Invalidate_DeleteNode(t *testing.T) { t.Fatalf("err: %v", err) } - // Delete the node and make sure the watches fire. - verifyWatch(t, s.getTableWatch("sessions"), func() { - verifyWatch(t, s.getTableWatch("nodes"), func() { - if err := s.DeleteNode(15, "foo"); err != nil { - t.Fatalf("err: %v", err) - } - }) - }) + // Delete the node and make sure the watch fires. + ws := memdb.NewWatchSet() + idx, s2, err := s.SessionGet(ws, session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if err := s.DeleteNode(15, "foo"); err != nil { + t.Fatalf("err: %v", err) + } + if !watchFired(ws) { + t.Fatalf("bad") + } // Lookup by ID, should be nil. - idx, s2, err := s.SessionGet(session.ID) + idx, s2, err = s.SessionGet(nil, session.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -571,19 +574,21 @@ func TestStateStore_Session_Invalidate_DeleteService(t *testing.T) { t.Fatalf("err: %v", err) } - // Delete the service and make sure the watches fire. - verifyWatch(t, s.getTableWatch("sessions"), func() { - verifyWatch(t, s.getTableWatch("services"), func() { - verifyWatch(t, s.getTableWatch("checks"), func() { - if err := s.DeleteService(15, "foo", "api"); err != nil { - t.Fatalf("err: %v", err) - } - }) - }) - }) + // Delete the service and make sure the watch fires. + ws := memdb.NewWatchSet() + idx, s2, err := s.SessionGet(ws, session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if err := s.DeleteService(15, "foo", "api"); err != nil { + t.Fatalf("err: %v", err) + } + if !watchFired(ws) { + t.Fatalf("bad") + } // Lookup by ID, should be nil. - idx, s2, err := s.SessionGet(session.ID) + idx, s2, err = s.SessionGet(nil, session.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -620,17 +625,21 @@ func TestStateStore_Session_Invalidate_Critical_Check(t *testing.T) { } // Invalidate the check and make sure the watches fire. - verifyWatch(t, s.getTableWatch("sessions"), func() { - verifyWatch(t, s.getTableWatch("checks"), func() { - check.Status = structs.HealthCritical - if err := s.EnsureCheck(15, check); err != nil { - t.Fatalf("err: %v", err) - } - }) - }) + ws := memdb.NewWatchSet() + idx, s2, err := s.SessionGet(ws, session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + check.Status = structs.HealthCritical + if err := s.EnsureCheck(15, check); err != nil { + t.Fatalf("err: %v", err) + } + if !watchFired(ws) { + t.Fatalf("bad") + } // Lookup by ID, should be nil. - idx, s2, err := s.SessionGet(session.ID) + idx, s2, err = s.SessionGet(nil, session.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -667,16 +676,20 @@ func TestStateStore_Session_Invalidate_DeleteCheck(t *testing.T) { } // Delete the check and make sure the watches fire. 
- verifyWatch(t, s.getTableWatch("sessions"), func() { - verifyWatch(t, s.getTableWatch("checks"), func() { - if err := s.DeleteCheck(15, "foo", "bar"); err != nil { - t.Fatalf("err: %v", err) - } - }) - }) + ws := memdb.NewWatchSet() + idx, s2, err := s.SessionGet(ws, session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if err := s.DeleteCheck(15, "foo", "bar"); err != nil { + t.Fatalf("err: %v", err) + } + if !watchFired(ws) { + t.Fatalf("bad") + } // Lookup by ID, should be nil. - idx, s2, err := s.SessionGet(session.ID) + idx, s2, err = s.SessionGet(nil, session.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -731,18 +744,20 @@ func TestStateStore_Session_Invalidate_Key_Unlock_Behavior(t *testing.T) { } // Delete the node and make sure the watches fire. - verifyWatch(t, s.getTableWatch("sessions"), func() { - verifyWatch(t, s.getTableWatch("nodes"), func() { - verifyWatch(t, s.GetKVSWatch("/f"), func() { - if err := s.DeleteNode(6, "foo"); err != nil { - t.Fatalf("err: %v", err) - } - }) - }) - }) + ws := memdb.NewWatchSet() + idx, s2, err := s.SessionGet(ws, session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if err := s.DeleteNode(6, "foo"); err != nil { + t.Fatalf("err: %v", err) + } + if !watchFired(ws) { + t.Fatalf("bad") + } // Lookup by ID, should be nil. - idx, s2, err := s.SessionGet(session.ID) + idx, s2, err = s.SessionGet(nil, session.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -754,7 +769,7 @@ func TestStateStore_Session_Invalidate_Key_Unlock_Behavior(t *testing.T) { } // Key should be unlocked. - idx, d2, err := s.KVSGet("/foo") + idx, d2, err := s.KVSGet(nil, "/foo") if err != nil { t.Fatalf("err: %s", err) } @@ -811,18 +826,20 @@ func TestStateStore_Session_Invalidate_Key_Delete_Behavior(t *testing.T) { } // Delete the node and make sure the watches fire. - verifyWatch(t, s.getTableWatch("sessions"), func() { - verifyWatch(t, s.getTableWatch("nodes"), func() { - verifyWatch(t, s.GetKVSWatch("/b"), func() { - if err := s.DeleteNode(6, "foo"); err != nil { - t.Fatalf("err: %v", err) - } - }) - }) - }) + ws := memdb.NewWatchSet() + idx, s2, err := s.SessionGet(ws, session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if err := s.DeleteNode(6, "foo"); err != nil { + t.Fatalf("err: %v", err) + } + if !watchFired(ws) { + t.Fatalf("bad") + } // Lookup by ID, should be nil. - idx, s2, err := s.SessionGet(session.ID) + idx, s2, err = s.SessionGet(nil, session.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -834,7 +851,7 @@ func TestStateStore_Session_Invalidate_Key_Delete_Behavior(t *testing.T) { } // Key should be deleted. - idx, d2, err := s.KVSGet("/bar") + idx, d2, err := s.KVSGet(nil, "/bar") if err != nil { t.Fatalf("err: %s", err) } @@ -877,16 +894,20 @@ func TestStateStore_Session_Invalidate_PreparedQuery_Delete(t *testing.T) { } // Invalidate the session and make sure the watches fire. - verifyWatch(t, s.getTableWatch("sessions"), func() { - verifyWatch(t, s.getTableWatch("prepared-queries"), func() { - if err := s.SessionDestroy(5, session.ID); err != nil { - t.Fatalf("err: %v", err) - } - }) - }) + ws := memdb.NewWatchSet() + idx, s2, err := s.SessionGet(ws, session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if err := s.SessionDestroy(5, session.ID); err != nil { + t.Fatalf("err: %v", err) + } + if !watchFired(ws) { + t.Fatalf("bad") + } // Make sure the session is gone. 
- idx, s2, err := s.SessionGet(session.ID) + idx, s2, err = s.SessionGet(nil, session.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -898,7 +919,7 @@ func TestStateStore_Session_Invalidate_PreparedQuery_Delete(t *testing.T) { } // Make sure the query is gone and the index is updated. - idx, q2, err := s.PreparedQueryGet(query.ID) + idx, q2, err := s.PreparedQueryGet(nil, query.ID) if err != nil { t.Fatalf("err: %s", err) } diff --git a/consul/state/state_store.go b/consul/state/state_store.go index dc72726564..06481db031 100644 --- a/consul/state/state_store.go +++ b/consul/state/state_store.go @@ -30,6 +30,18 @@ var ( ErrMissingQueryID = errors.New("Missing Query ID") ) +const ( + // watchLimit is used as a soft limit to cap how many watches we allow + // for a given blocking query. If this is exceeded, then we will use a + // higher-level watch that's less fine-grained. This isn't as bad as it + // seems since we have made the main culprits (nodes and services) more + // efficient by diffing before we update via register requests. + // + // Given the current size of aFew == 32 in memdb's watch_few.go, this + // will allow for up to ~64 goroutines per blocking query. + watchLimit = 2048 +) + // StateStore is where we store all of Consul's state, including // records of node registrations, services, checks, key/value // pairs and more. The DB is entirely in-memory and is constructed @@ -38,11 +50,9 @@ type StateStore struct { schema *memdb.DBSchema db *memdb.MemDB - // tableWatches holds all the full table watches, indexed by table name. - tableWatches map[string]*FullTableWatch - - // kvsWatch holds the special prefix watch for the key value store. - kvsWatch *PrefixWatchManager + // abandonCh is used to signal watchers that this state store has been + // abandoned (usually during a restore). This is only ever closed. + abandonCh chan struct{} // kvsGraveyard manages tombstones for the key value store. kvsGraveyard *Graveyard @@ -62,9 +72,8 @@ type StateSnapshot struct { // StateRestore is used to efficiently manage restoring a large amount of // data to a state store. type StateRestore struct { - store *StateStore - tx *memdb.Txn - watches *DumbWatchManager + store *StateStore + tx *memdb.Txn } // IndexEntry keeps a record of the last index per-table. @@ -92,22 +101,11 @@ func NewStateStore(gc *TombstoneGC) (*StateStore, error) { return nil, fmt.Errorf("Failed setting up state store: %s", err) } - // Build up the all-table watches. - tableWatches := make(map[string]*FullTableWatch) - for table, _ := range schema.Tables { - if table == "kvs" || table == "tombstones" { - continue - } - - tableWatches[table] = NewFullTableWatch() - } - // Create and return the state store. s := &StateStore{ schema: schema, db: db, - tableWatches: tableWatches, - kvsWatch: NewPrefixWatchManager(), + abandonCh: make(chan struct{}), kvsGraveyard: NewGraveyard(gc), lockDelay: NewDelay(), } @@ -142,8 +140,7 @@ func (s *StateSnapshot) Close() { // transaction. func (s *StateStore) Restore() *StateRestore { tx := s.db.Txn(true) - watches := NewDumbWatchManager(s.tableWatches) - return &StateRestore{s, tx, watches} + return &StateRestore{s, tx} } // Abort abandons the changes made by a restore. This or Commit should always be @@ -155,14 +152,21 @@ func (s *StateRestore) Abort() { // Commit commits the changes made by a restore. This or Abort should always be // called. 
func (s *StateRestore) Commit() { - // Fire off a single KVS watch instead of a zillion prefix ones, and use - // a dumb watch manager to single-fire all the full table watches. - s.tx.Defer(func() { s.store.kvsWatch.Notify("", true) }) - s.tx.Defer(func() { s.watches.Notify() }) - s.tx.Commit() } +// AbandonCh returns a channel you can wait on to know if the state store was +// abandoned. +func (s *StateStore) AbandonCh() <-chan struct{} { + return s.abandonCh +} + +// Abandon is used to signal that the given state store has been abandoned. +// Calling this more than one time will panic. +func (s *StateStore) Abandon() { + close(s.abandonCh) +} + // maxIndex is a helper used to retrieve the highest known index // amongst a set of tables in the db. func (s *StateStore) maxIndex(tables ...string) uint64 { @@ -208,64 +212,3 @@ func indexUpdateMaxTxn(tx *memdb.Txn, idx uint64, table string) error { return nil } - -// getWatchTables returns the list of tables that should be watched and used for -// max index calculations for the given query method. This is used for all -// methods except for KVS. This will panic if the method is unknown. -func (s *StateStore) getWatchTables(method string) []string { - switch method { - case "GetNode", "Nodes": - return []string{"nodes"} - case "Services": - return []string{"services"} - case "NodeService", "NodeServices", "ServiceNodes": - return []string{"nodes", "services"} - case "NodeCheck", "NodeChecks", "ServiceChecks", "ChecksInState": - return []string{"checks"} - case "ChecksInStateByNodeMeta", "ServiceChecksByNodeMeta": - return []string{"nodes", "checks"} - case "CheckServiceNodes", "NodeInfo", "NodeDump": - return []string{"nodes", "services", "checks"} - case "SessionGet", "SessionList", "NodeSessions": - return []string{"sessions"} - case "ACLGet", "ACLList": - return []string{"acls"} - case "Coordinates": - return []string{"coordinates"} - case "PreparedQueryGet", "PreparedQueryResolve", "PreparedQueryList": - return []string{"prepared-queries"} - } - - panic(fmt.Sprintf("Unknown method %s", method)) -} - -// getTableWatch returns a full table watch for the given table. This will panic -// if the table doesn't have a full table watch. -func (s *StateStore) getTableWatch(table string) Watch { - if watch, ok := s.tableWatches[table]; ok { - return watch - } - - panic(fmt.Sprintf("Unknown watch for table %s", table)) -} - -// GetQueryWatch returns a watch for the given query method. This is -// used for all methods except for KV; you should call GetKVSWatch instead. -// This will panic if the method is unknown. -func (s *StateStore) GetQueryWatch(method string) Watch { - tables := s.getWatchTables(method) - if len(tables) == 1 { - return s.getTableWatch(tables[0]) - } - - var watches []Watch - for _, table := range tables { - watches = append(watches, s.getTableWatch(table)) - } - return NewMultiWatch(watches...) -} - -// GetKVSWatch returns a watch for the given prefix in the key value store. 
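The abandon channel introduced above is the replacement signal for "the whole store got swapped out" (per its comment, usually a restore). A blocking reader can add it to the same WatchSet as its per-query channels so an abandon wakes it like any other change; a minimal sketch, with the goroutine and timeout purely illustrative:

package main

import (
	"fmt"
	"time"

	"github.com/hashicorp/consul/consul/state"
	"github.com/hashicorp/go-memdb"
)

func main() {
	s, err := state.NewStateStore(nil)
	if err != nil {
		panic(err)
	}

	// Register a watch on the (empty) sessions table and also on the
	// store's abandon channel, so a restore/abandon wakes us too.
	ws := memdb.NewWatchSet()
	if _, _, err := s.SessionList(ws); err != nil {
		panic(err)
	}
	ws.Add(s.AbandonCh())

	// Abandon the store from another goroutine; Watch below returns with
	// timedOut == false because the abandon channel was closed.
	go s.Abandon()
	timedOut := ws.Watch(time.After(time.Second))
	fmt.Println("timed out:", timedOut)
}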
-func (s *StateStore) GetKVSWatch(prefix string) Watch { - return s.kvsWatch.NewPrefixWatch(prefix) -} diff --git a/consul/state/state_store_test.go b/consul/state/state_store_test.go index 5a3c781719..e58b71e6dd 100644 --- a/consul/state/state_store_test.go +++ b/consul/state/state_store_test.go @@ -4,9 +4,11 @@ import ( crand "crypto/rand" "fmt" "testing" + "time" "github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/consul/types" + "github.com/hashicorp/go-memdb" ) func testUUID() string { @@ -122,6 +124,16 @@ func testSetKey(t *testing.T, s *StateStore, idx uint64, key, value string) { } } +// watchFired is a helper for unit tests that returns if the given watch set +// fired (it doesn't care which watch actually fired). This uses a fixed +// timeout since we already expect the event happened before calling this and +// just need to distinguish a fire from a timeout. We do need a little time to +// allow the watch to set up any goroutines, though. +func watchFired(ws memdb.WatchSet) bool { + timedOut := ws.Watch(time.After(50 * time.Millisecond)) + return !timedOut +} + func TestStateStore_Restore_Abort(t *testing.T) { s := testStateStore(t) @@ -140,7 +152,7 @@ func TestStateStore_Restore_Abort(t *testing.T) { } restore.Abort() - idx, entries, err := s.KVSList("") + idx, entries, err := s.KVSList(nil, "") if err != nil { t.Fatalf("err: %s", err) } @@ -152,6 +164,17 @@ func TestStateStore_Restore_Abort(t *testing.T) { } } +func TestStateStore_Abandon(t *testing.T) { + s := testStateStore(t) + abandonCh := s.AbandonCh() + s.Abandon() + select { + case <-abandonCh: + default: + t.Fatalf("bad") + } +} + func TestStateStore_maxIndex(t *testing.T) { s := testStateStore(t) @@ -180,50 +203,3 @@ func TestStateStore_indexUpdateMaxTxn(t *testing.T) { t.Fatalf("bad max: %d", max) } } - -func TestStateStore_GetWatches(t *testing.T) { - s := testStateStore(t) - - // This test does two things - it makes sure there's no full table - // watch for KVS, and it makes sure that asking for a watch that - // doesn't exist causes a panic. - func() { - defer func() { - if r := recover(); r == nil { - t.Fatalf("didn't get expected panic") - } - }() - s.getTableWatch("kvs") - }() - - // Similar for tombstones; those don't support watches at all. - func() { - defer func() { - if r := recover(); r == nil { - t.Fatalf("didn't get expected panic") - } - }() - s.getTableWatch("tombstones") - }() - - // Make sure requesting a bogus method causes a panic. - func() { - defer func() { - if r := recover(); r == nil { - t.Fatalf("didn't get expected panic") - } - }() - s.GetQueryWatch("dogs") - }() - - // Request valid watches. 
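The watchFired helper above is how the converted tests tell a fired WatchSet apart from a quiescent one. A sketch of typical usage, written as if it lived next to the other tests in package state and relying on that file's existing imports (the test name is invented, and it assumes KVSGet registers a watch even for a missing key, as the converted prepared-query getter tests demonstrate for their table):

func TestStateStore_Example_WatchFired(t *testing.T) {
	s := testStateStore(t)

	// Register interest in a key that doesn't exist yet.
	ws := memdb.NewWatchSet()
	if _, _, err := s.KVSGet(ws, "foo"); err != nil {
		t.Fatalf("err: %s", err)
	}

	// Creating the key should wake the watch set within watchFired's
	// 50ms window.
	if err := s.KVSSet(1, &structs.DirEntry{Key: "foo", Value: []byte("bar")}); err != nil {
		t.Fatalf("err: %s", err)
	}
	if !watchFired(ws) {
		t.Fatalf("watch should have fired")
	}
}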
- if w := s.GetQueryWatch("Nodes"); w == nil { - t.Fatalf("didn't get a watch") - } - if w := s.GetQueryWatch("NodeDump"); w == nil { - t.Fatalf("didn't get a watch") - } - if w := s.GetKVSWatch("/dogs"); w == nil { - t.Fatalf("didn't get a watch") - } -} diff --git a/consul/state/txn.go b/consul/state/txn.go index 00d7905a2c..d2b3c6f1f5 100644 --- a/consul/state/txn.go +++ b/consul/state/txn.go @@ -55,14 +55,14 @@ func (s *StateStore) txnKVS(tx *memdb.Txn, idx uint64, op *structs.TxnKVOp) (str } case structs.KVSGet: - _, entry, err = s.kvsGetTxn(tx, op.DirEnt.Key) + _, entry, err = s.kvsGetTxn(tx, nil, op.DirEnt.Key) if entry == nil && err == nil { err = fmt.Errorf("key %q doesn't exist", op.DirEnt.Key) } case structs.KVSGetTree: var entries structs.DirEntries - _, entries, err = s.kvsListTxn(tx, op.DirEnt.Key) + _, entries, err = s.kvsListTxn(tx, nil, op.DirEnt.Key) if err == nil { results := make(structs.TxnResults, 0, len(entries)) for _, e := range entries { diff --git a/consul/state/txn_test.go b/consul/state/txn_test.go index d868c13523..0949e7aa58 100644 --- a/consul/state/txn_test.go +++ b/consul/state/txn_test.go @@ -295,7 +295,7 @@ func TestStateStore_Txn_KVS(t *testing.T) { } // Pull the resulting state store contents. - idx, actual, err := s.KVSList("") + idx, actual, err := s.KVSList(nil, "") if err != nil { t.Fatalf("err: %s", err) } @@ -364,7 +364,7 @@ func TestStateStore_Txn_KVS_Rollback(t *testing.T) { // This function verifies that the state store wasn't changed. verifyStateStore := func(desc string) { - idx, actual, err := s.KVSList("") + idx, actual, err := s.KVSList(nil, "") if err != nil { t.Fatalf("err (%s): %s", desc, err) } @@ -711,84 +711,3 @@ func TestStateStore_Txn_KVS_RO_Safety(t *testing.T) { } } } - -func TestStateStore_Txn_Watches(t *testing.T) { - s := testStateStore(t) - - // Verify that a basic transaction triggers multiple watches. We call - // the same underlying methods that are called above so this is more - // of a sanity check. - verifyWatch(t, s.GetKVSWatch("multi/one"), func() { - verifyWatch(t, s.GetKVSWatch("multi/two"), func() { - ops := structs.TxnOps{ - &structs.TxnOp{ - KV: &structs.TxnKVOp{ - Verb: structs.KVSSet, - DirEnt: structs.DirEntry{ - Key: "multi/one", - Value: []byte("one"), - }, - }, - }, - &structs.TxnOp{ - KV: &structs.TxnKVOp{ - Verb: structs.KVSSet, - DirEnt: structs.DirEntry{ - Key: "multi/two", - Value: []byte("two"), - }, - }, - }, - } - results, errors := s.TxnRW(15, ops) - if len(results) != len(ops) { - t.Fatalf("bad len: %d != %d", len(results), len(ops)) - } - if len(errors) != 0 { - t.Fatalf("bad len: %d != 0", len(errors)) - } - }) - }) - - // Verify that a rolled back transaction doesn't trigger any watches. 
- verifyNoWatch(t, s.GetKVSWatch("multi/one"), func() { - verifyNoWatch(t, s.GetKVSWatch("multi/two"), func() { - ops := structs.TxnOps{ - &structs.TxnOp{ - KV: &structs.TxnKVOp{ - Verb: structs.KVSSet, - DirEnt: structs.DirEntry{ - Key: "multi/one", - Value: []byte("one-updated"), - }, - }, - }, - &structs.TxnOp{ - KV: &structs.TxnKVOp{ - Verb: structs.KVSSet, - DirEnt: structs.DirEntry{ - Key: "multi/two", - Value: []byte("two-updated"), - }, - }, - }, - &structs.TxnOp{ - KV: &structs.TxnKVOp{ - Verb: structs.KVSLock, - DirEnt: structs.DirEntry{ - Key: "multi/nope", - Value: []byte("nope"), - }, - }, - }, - } - results, errors := s.TxnRW(16, ops) - if len(errors) != 1 { - t.Fatalf("bad len: %d != 1", len(errors)) - } - if len(results) != 0 { - t.Fatalf("bad len: %d != 0", len(results)) - } - }) - }) -} diff --git a/consul/state/watch.go b/consul/state/watch.go deleted file mode 100644 index 93a3329b07..0000000000 --- a/consul/state/watch.go +++ /dev/null @@ -1,219 +0,0 @@ -package state - -import ( - "fmt" - "sync" - - "github.com/armon/go-radix" -) - -// Watch is the external interface that's common to all the different flavors. -type Watch interface { - // Wait registers the given channel and calls it back when the watch - // fires. - Wait(notifyCh chan struct{}) - - // Clear deregisters the given channel. - Clear(notifyCh chan struct{}) -} - -// FullTableWatch implements a single notify group for a table. -type FullTableWatch struct { - group NotifyGroup -} - -// NewFullTableWatch returns a new full table watch. -func NewFullTableWatch() *FullTableWatch { - return &FullTableWatch{} -} - -// See Watch. -func (w *FullTableWatch) Wait(notifyCh chan struct{}) { - w.group.Wait(notifyCh) -} - -// See Watch. -func (w *FullTableWatch) Clear(notifyCh chan struct{}) { - w.group.Clear(notifyCh) -} - -// Notify wakes up all the watchers registered for this table. -func (w *FullTableWatch) Notify() { - w.group.Notify() -} - -// DumbWatchManager is a wrapper that allows nested code to arm full table -// watches multiple times but fire them only once. This doesn't have any -// way to clear the state, and it's not thread-safe, so it should be used once -// and thrown away inside the context of a single thread. -type DumbWatchManager struct { - // tableWatches holds the full table watches. - tableWatches map[string]*FullTableWatch - - // armed tracks whether the table should be notified. - armed map[string]bool -} - -// NewDumbWatchManager returns a new dumb watch manager. -func NewDumbWatchManager(tableWatches map[string]*FullTableWatch) *DumbWatchManager { - return &DumbWatchManager{ - tableWatches: tableWatches, - armed: make(map[string]bool), - } -} - -// Arm arms the given table's watch. -func (d *DumbWatchManager) Arm(table string) { - if _, ok := d.tableWatches[table]; !ok { - panic(fmt.Sprintf("unknown table: %s", table)) - } - - if _, ok := d.armed[table]; !ok { - d.armed[table] = true - } -} - -// Notify fires watches for all the armed tables. -func (d *DumbWatchManager) Notify() { - for table, _ := range d.armed { - d.tableWatches[table].Notify() - } -} - -// PrefixWatch provides a Watch-compatible interface for a PrefixWatchManager, -// bound to a specific prefix. -type PrefixWatch struct { - // manager is the underlying watch manager. - manager *PrefixWatchManager - - // prefix is the prefix we are watching. - prefix string -} - -// Wait registers the given channel with the notify group for our prefix. 
-func (w *PrefixWatch) Wait(notifyCh chan struct{}) { - w.manager.Wait(w.prefix, notifyCh) -} - -// Clear deregisters the given channel from the the notify group for our prefix. -func (w *PrefixWatch) Clear(notifyCh chan struct{}) { - w.manager.Clear(w.prefix, notifyCh) -} - -// PrefixWatchManager maintains a notify group for each prefix, allowing for -// much more fine-grained watches. -type PrefixWatchManager struct { - // watches has the set of notify groups, organized by prefix. - watches *radix.Tree - - // lock protects the watches tree. - lock sync.Mutex -} - -// NewPrefixWatchManager returns a new prefix watch manager. -func NewPrefixWatchManager() *PrefixWatchManager { - return &PrefixWatchManager{ - watches: radix.New(), - } -} - -// NewPrefixWatch returns a Watch-compatible interface for watching the given -// prefix. -func (w *PrefixWatchManager) NewPrefixWatch(prefix string) Watch { - return &PrefixWatch{ - manager: w, - prefix: prefix, - } -} - -// Wait registers the given channel on a prefix. -func (w *PrefixWatchManager) Wait(prefix string, notifyCh chan struct{}) { - w.lock.Lock() - defer w.lock.Unlock() - - var group *NotifyGroup - if raw, ok := w.watches.Get(prefix); ok { - group = raw.(*NotifyGroup) - } else { - group = &NotifyGroup{} - w.watches.Insert(prefix, group) - } - group.Wait(notifyCh) -} - -// Clear deregisters the given channel from the notify group for a prefix (if -// one exists). -func (w *PrefixWatchManager) Clear(prefix string, notifyCh chan struct{}) { - w.lock.Lock() - defer w.lock.Unlock() - - if raw, ok := w.watches.Get(prefix); ok { - group := raw.(*NotifyGroup) - group.Clear(notifyCh) - } -} - -// Notify wakes up all the watchers associated with the given prefix. If subtree -// is true then we will also notify all the tree under the prefix, such as when -// a key is being deleted. -func (w *PrefixWatchManager) Notify(prefix string, subtree bool) { - w.lock.Lock() - defer w.lock.Unlock() - - var cleanup []string - fn := func(k string, raw interface{}) bool { - group := raw.(*NotifyGroup) - group.Notify() - if k != "" { - cleanup = append(cleanup, k) - } - return false - } - - // Invoke any watcher on the path downward to the key. - w.watches.WalkPath(prefix, fn) - - // If the entire prefix may be affected (e.g. delete tree), - // invoke the entire prefix. - if subtree { - w.watches.WalkPrefix(prefix, fn) - } - - // Delete the old notify groups. - for i := len(cleanup) - 1; i >= 0; i-- { - w.watches.Delete(cleanup[i]) - } - - // TODO (slackpad) If a watch never fires then we will never clear it - // out of the tree. The old state store had the same behavior, so this - // has been around for a while. We should probably add a prefix scan - // with a function that clears out any notify groups that are empty. -} - -// MultiWatch wraps several watches and allows any of them to trigger the -// caller. -type MultiWatch struct { - // watches holds the list of subordinate watches to forward events to. - watches []Watch -} - -// NewMultiWatch returns a new new multi watch over the given set of watches. -func NewMultiWatch(watches ...Watch) *MultiWatch { - return &MultiWatch{ - watches: watches, - } -} - -// See Watch. -func (w *MultiWatch) Wait(notifyCh chan struct{}) { - for _, watch := range w.watches { - watch.Wait(notifyCh) - } -} - -// See Watch. 
-func (w *MultiWatch) Clear(notifyCh chan struct{}) { - for _, watch := range w.watches { - watch.Clear(notifyCh) - } -} diff --git a/consul/state/watch_test.go b/consul/state/watch_test.go deleted file mode 100644 index 6eaf85d678..0000000000 --- a/consul/state/watch_test.go +++ /dev/null @@ -1,377 +0,0 @@ -package state - -import ( - "sort" - "strings" - "testing" -) - -// verifyWatch will set up a watch channel, call the given function, and then -// make sure the watch fires. -func verifyWatch(t *testing.T, watch Watch, fn func()) { - ch := make(chan struct{}, 1) - watch.Wait(ch) - - fn() - - select { - case <-ch: - default: - t.Fatalf("watch should have been notified") - } -} - -// verifyNoWatch will set up a watch channel, call the given function, and then -// make sure the watch never fires. -func verifyNoWatch(t *testing.T, watch Watch, fn func()) { - ch := make(chan struct{}, 1) - watch.Wait(ch) - - fn() - - select { - case <-ch: - t.Fatalf("watch should not been notified") - default: - } -} - -func TestWatch_FullTableWatch(t *testing.T) { - w := NewFullTableWatch() - - // Test the basic trigger with a single watcher. - verifyWatch(t, w, func() { - w.Notify() - }) - - // Run multiple watchers and make sure they both fire. - verifyWatch(t, w, func() { - verifyWatch(t, w, func() { - w.Notify() - }) - }) - - // Make sure clear works. - ch := make(chan struct{}, 1) - w.Wait(ch) - w.Clear(ch) - w.Notify() - select { - case <-ch: - t.Fatalf("watch should not have been notified") - default: - } - - // Make sure notify is a one shot. - w.Wait(ch) - w.Notify() - select { - case <-ch: - default: - t.Fatalf("watch should have been notified") - } - w.Notify() - select { - case <-ch: - t.Fatalf("watch should not have been notified") - default: - } -} - -func TestWatch_DumbWatchManager(t *testing.T) { - watches := map[string]*FullTableWatch{ - "alice": NewFullTableWatch(), - "bob": NewFullTableWatch(), - "carol": NewFullTableWatch(), - } - - // Notify with nothing armed and make sure nothing triggers. - func() { - w := NewDumbWatchManager(watches) - verifyNoWatch(t, watches["alice"], func() { - verifyNoWatch(t, watches["bob"], func() { - verifyNoWatch(t, watches["carol"], func() { - w.Notify() - }) - }) - }) - }() - - // Trigger one watch. - func() { - w := NewDumbWatchManager(watches) - verifyWatch(t, watches["alice"], func() { - verifyNoWatch(t, watches["bob"], func() { - verifyNoWatch(t, watches["carol"], func() { - w.Arm("alice") - w.Notify() - }) - }) - }) - }() - - // Trigger two watches. - func() { - w := NewDumbWatchManager(watches) - verifyWatch(t, watches["alice"], func() { - verifyNoWatch(t, watches["bob"], func() { - verifyWatch(t, watches["carol"], func() { - w.Arm("alice") - w.Arm("carol") - w.Notify() - }) - }) - }) - }() - - // Trigger all three watches. - func() { - w := NewDumbWatchManager(watches) - verifyWatch(t, watches["alice"], func() { - verifyWatch(t, watches["bob"], func() { - verifyWatch(t, watches["carol"], func() { - w.Arm("alice") - w.Arm("bob") - w.Arm("carol") - w.Notify() - }) - }) - }) - }() - - // Trigger multiple times. - func() { - w := NewDumbWatchManager(watches) - verifyWatch(t, watches["alice"], func() { - verifyNoWatch(t, watches["bob"], func() { - verifyNoWatch(t, watches["carol"], func() { - w.Arm("alice") - w.Arm("alice") - w.Notify() - }) - }) - }) - }() - - // Make sure it panics when asked to arm an unknown table. 
- func() { - defer func() { - if r := recover(); r == nil { - t.Fatalf("didn't get expected panic") - } - }() - w := NewDumbWatchManager(watches) - w.Arm("nope") - }() -} - -func verifyWatches(t *testing.T, w *PrefixWatchManager, expected string) { - var found []string - fn := func(k string, v interface{}) bool { - if k == "" { - k = "(full)" - } - found = append(found, k) - return false - } - w.watches.WalkPrefix("", fn) - - sort.Strings(found) - actual := strings.Join(found, "|") - if expected != actual { - t.Fatalf("bad: %s != %s", expected, actual) - } -} - -func TestWatch_PrefixWatchManager(t *testing.T) { - w := NewPrefixWatchManager() - verifyWatches(t, w, "") - - // This will create the watch group. - ch1 := make(chan struct{}, 1) - w.Wait("hello", ch1) - verifyWatches(t, w, "hello") - - // This will add to the existing one. - ch2 := make(chan struct{}, 1) - w.Wait("hello", ch2) - verifyWatches(t, w, "hello") - - // This will add to the existing as well. - ch3 := make(chan struct{}, 1) - w.Wait("hello", ch3) - verifyWatches(t, w, "hello") - - // Remove one of the watches. - w.Clear("hello", ch2) - verifyWatches(t, w, "hello") - - // Do "clear" for one that was never added. - ch4 := make(chan struct{}, 1) - w.Clear("hello", ch4) - verifyWatches(t, w, "hello") - - // Add a full table watch. - full := make(chan struct{}, 1) - w.Wait("", full) - verifyWatches(t, w, "(full)|hello") - - // Add another channel for a different prefix. - nope := make(chan struct{}, 1) - w.Wait("nope", nope) - verifyWatches(t, w, "(full)|hello|nope") - - // Fire off the notification and make sure channels were pinged (or not) - // as expected. - w.Notify("hello", false) - verifyWatches(t, w, "(full)|nope") - select { - case <-ch1: - default: - t.Fatalf("ch1 should have been notified") - } - select { - case <-ch2: - t.Fatalf("ch2 should not have been notified") - default: - } - select { - case <-ch3: - default: - t.Fatalf("ch3 should have been notified") - } - select { - case <-ch4: - t.Fatalf("ch4 should not have been notified") - default: - } - select { - case <-nope: - t.Fatalf("nope should not have been notified") - default: - } - select { - case <-full: - default: - t.Fatalf("full should have been notified") - } -} - -func TestWatch_PrefixWatch(t *testing.T) { - w := NewPrefixWatchManager() - - // Hit a specific key. - verifyWatch(t, w.NewPrefixWatch(""), func() { - verifyWatch(t, w.NewPrefixWatch("foo/bar/baz"), func() { - verifyNoWatch(t, w.NewPrefixWatch("foo/bar/zoo"), func() { - verifyNoWatch(t, w.NewPrefixWatch("nope"), func() { - w.Notify("foo/bar/baz", false) - }) - }) - }) - }) - - // Make sure cleanup is happening. All that should be left is the - // full-table watch and the un-fired watches. - verifyWatches(t, w, "(full)|foo/bar/zoo|nope") - - // Delete a subtree. - verifyWatch(t, w.NewPrefixWatch(""), func() { - verifyWatch(t, w.NewPrefixWatch("foo/bar/baz"), func() { - verifyWatch(t, w.NewPrefixWatch("foo/bar/zoo"), func() { - verifyNoWatch(t, w.NewPrefixWatch("nope"), func() { - w.Notify("foo/", true) - }) - }) - }) - }) - verifyWatches(t, w, "(full)|nope") - - // Hit an unknown key. - verifyWatch(t, w.NewPrefixWatch(""), func() { - verifyNoWatch(t, w.NewPrefixWatch("foo/bar/baz"), func() { - verifyNoWatch(t, w.NewPrefixWatch("foo/bar/zoo"), func() { - verifyNoWatch(t, w.NewPrefixWatch("nope"), func() { - w.Notify("not/in/there", false) - }) - }) - }) - }) - verifyWatches(t, w, "(full)|foo/bar/baz|foo/bar/zoo|nope") - - // Make sure a watch can be reused. 
- watch := w.NewPrefixWatch("over/and/over") - for i := 0; i < 10; i++ { - verifyWatch(t, watch, func() { - w.Notify("over/and/over", false) - }) - } -} - -type MockWatch struct { - Waits map[chan struct{}]int - Clears map[chan struct{}]int -} - -func NewMockWatch() *MockWatch { - return &MockWatch{ - Waits: make(map[chan struct{}]int), - Clears: make(map[chan struct{}]int), - } -} - -func (m *MockWatch) Wait(notifyCh chan struct{}) { - if _, ok := m.Waits[notifyCh]; ok { - m.Waits[notifyCh]++ - } else { - m.Waits[notifyCh] = 1 - } -} - -func (m *MockWatch) Clear(notifyCh chan struct{}) { - if _, ok := m.Clears[notifyCh]; ok { - m.Clears[notifyCh]++ - } else { - m.Clears[notifyCh] = 1 - } -} - -func TestWatch_MultiWatch(t *testing.T) { - w1, w2 := NewMockWatch(), NewMockWatch() - w := NewMultiWatch(w1, w2) - - // Do some activity. - c1, c2 := make(chan struct{}), make(chan struct{}) - w.Wait(c1) - w.Clear(c1) - w.Wait(c1) - w.Wait(c2) - w.Clear(c1) - w.Clear(c2) - - // Make sure all the events were forwarded. - if cnt, ok := w1.Waits[c1]; !ok || cnt != 2 { - t.Fatalf("bad: %d", w1.Waits[c1]) - } - if cnt, ok := w1.Clears[c1]; !ok || cnt != 2 { - t.Fatalf("bad: %d", w1.Clears[c1]) - } - if cnt, ok := w1.Waits[c2]; !ok || cnt != 1 { - t.Fatalf("bad: %d", w1.Waits[c2]) - } - if cnt, ok := w1.Clears[c2]; !ok || cnt != 1 { - t.Fatalf("bad: %d", w1.Clears[c2]) - } - if cnt, ok := w2.Waits[c1]; !ok || cnt != 2 { - t.Fatalf("bad: %d", w2.Waits[c1]) - } - if cnt, ok := w2.Clears[c1]; !ok || cnt != 2 { - t.Fatalf("bad: %d", w2.Clears[c1]) - } - if cnt, ok := w2.Waits[c2]; !ok || cnt != 1 { - t.Fatalf("bad: %d", w2.Waits[c2]) - } - if cnt, ok := w2.Clears[c2]; !ok || cnt != 1 { - t.Fatalf("bad: %d", w2.Clears[c2]) - } -} diff --git a/consul/txn_endpoint_test.go b/consul/txn_endpoint_test.go index b1e60021c8..e502589a01 100644 --- a/consul/txn_endpoint_test.go +++ b/consul/txn_endpoint_test.go @@ -55,7 +55,7 @@ func TestTxn_Apply(t *testing.T) { // Verify the state store directly. state := s1.fsm.State() - _, d, err := state.KVSGet("test") + _, d, err := state.KVSGet(nil, "test") if err != nil { t.Fatalf("err: %v", err) } diff --git a/vendor/github.com/hashicorp/go-immutable-radix/iradix.go b/vendor/github.com/hashicorp/go-immutable-radix/iradix.go index 8d26fc95f4..1f63f769eb 100644 --- a/vendor/github.com/hashicorp/go-immutable-radix/iradix.go +++ b/vendor/github.com/hashicorp/go-immutable-radix/iradix.go @@ -2,6 +2,7 @@ package iradix import ( "bytes" + "strings" "github.com/hashicorp/golang-lru/simplelru" ) @@ -11,7 +12,9 @@ const ( // cache used per transaction. This is used to cache the updates // to the nodes near the root, while the leaves do not need to be // cached. This is important for very large transactions to prevent - // the modified cache from growing to be enormous. + // the modified cache from growing to be enormous. This is also used + // to set the max size of the mutation notify maps since those should + // also be bounded in a similar way. defaultModifiedCache = 8192 ) @@ -27,7 +30,11 @@ type Tree struct { // New returns an empty Tree func New() *Tree { - t := &Tree{root: &Node{}} + t := &Tree{ + root: &Node{ + mutateCh: make(chan struct{}), + }, + } return t } @@ -40,75 +47,148 @@ func (t *Tree) Len() int { // atomically and returns a new tree when committed. A transaction // is not thread safe, and should only be used by a single goroutine. type Txn struct { - root *Node - size int - modified *simplelru.LRU + // root is the modified root for the transaction. 
+ root *Node + + // snap is a snapshot of the root node for use if we have to run the + // slow notify algorithm. + snap *Node + + // size tracks the size of the tree as it is modified during the + // transaction. + size int + + // writable is a cache of writable nodes that have been created during + // the course of the transaction. This allows us to re-use the same + // nodes for further writes and avoid unnecessary copies of nodes that + // have never been exposed outside the transaction. This will only hold + // up to defaultModifiedCache number of entries. + writable *simplelru.LRU + + // trackChannels is used to hold channels that need to be notified to + // signal mutation of the tree. This will only hold up to + // defaultModifiedCache number of entries, after which we will set the + // trackOverflow flag, which will cause us to use a more expensive + // algorithm to perform the notifications. Mutation tracking is only + // performed if trackMutate is true. + trackChannels map[*chan struct{}]struct{} + trackOverflow bool + trackMutate bool } // Txn starts a new transaction that can be used to mutate the tree func (t *Tree) Txn() *Txn { txn := &Txn{ root: t.root, + snap: t.root, size: t.size, } return txn } -// writeNode returns a node to be modified, if the current -// node as already been modified during the course of -// the transaction, it is used in-place. -func (t *Txn) writeNode(n *Node) *Node { - // Ensure the modified set exists - if t.modified == nil { +// TrackMutate can be used to toggle if mutations are tracked. If this is enabled +// then notifications will be issued for affected internal nodes and leaves when +// the transaction is committed. +func (t *Txn) TrackMutate(track bool) { + t.trackMutate = track +} + +// trackChannel safely attempts to track the given mutation channel, setting the +// overflow flag if we can no longer track any more. This limits the amount of +// state that will accumulate during a transaction and we have a slower algorithm +// to switch to if we overflow. +func (t *Txn) trackChannel(ch *chan struct{}) { + // In overflow, make sure we don't store any more objects. + if t.trackOverflow { + return + } + + // Create the map on the fly when we need it. + if t.trackChannels == nil { + t.trackChannels = make(map[*chan struct{}]struct{}) + } + + // If this would overflow the state we reject it and set the flag (since + // we aren't tracking everything that's required any longer). + if len(t.trackChannels) >= defaultModifiedCache { + t.trackOverflow = true + return + } + + // Otherwise we are good to track it. + t.trackChannels[ch] = struct{}{} +} + +// writeNode returns a node to be modified, if the current node has already been +// modified during the course of the transaction, it is used in-place. Set +// forLeafUpdate to true if you are getting a write node to update the leaf, +// which will set leaf mutation tracking appropriately as well. +func (t *Txn) writeNode(n *Node, forLeafUpdate bool) *Node { + // Ensure the writable set exists. + if t.writable == nil { lru, err := simplelru.NewLRU(defaultModifiedCache, nil) if err != nil { panic(err) } - t.modified = lru + t.writable = lru } - // If this node has already been modified, we can - // continue to use it during this transaction. - if _, ok := t.modified.Get(n); ok { + // If this node has already been modified, we can continue to use it + // during this transaction. 
If a node gets kicked out of cache then we + // *may* notify for its mutation if we end up copying the node again, + // but we don't make any guarantees about notifying for intermediate + // mutations that were never exposed outside of a transaction. + if _, ok := t.writable.Get(n); ok { return n } - // Copy the existing node - nc := new(Node) + // Mark this node as being mutated. + if t.trackMutate { + t.trackChannel(&(n.mutateCh)) + } + + // Mark its leaf as being mutated, if appropriate. + if t.trackMutate && forLeafUpdate && n.leaf != nil { + t.trackChannel(&(n.leaf.mutateCh)) + } + + // Copy the existing node. + nc := &Node{ + mutateCh: make(chan struct{}), + leaf: n.leaf, + } if n.prefix != nil { nc.prefix = make([]byte, len(n.prefix)) copy(nc.prefix, n.prefix) } - if n.leaf != nil { - nc.leaf = new(leafNode) - *nc.leaf = *n.leaf - } if len(n.edges) != 0 { nc.edges = make([]edge, len(n.edges)) copy(nc.edges, n.edges) } - // Mark this node as modified - t.modified.Add(nc, nil) + // Mark this node as writable. + t.writable.Add(nc, nil) return nc } // insert does a recursive insertion func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface{}, bool) { - // Handle key exhaution + // Handle key exhaustion if len(search) == 0 { - nc := t.writeNode(n) + var oldVal interface{} + didUpdate := false if n.isLeaf() { - old := nc.leaf.val - nc.leaf.val = v - return nc, old, true - } else { - nc.leaf = &leafNode{ - key: k, - val: v, - } - return nc, nil, false + oldVal = n.leaf.val + didUpdate = true } + + nc := t.writeNode(n, true) + nc.leaf = &leafNode{ + mutateCh: make(chan struct{}), + key: k, + val: v, + } + return nc, oldVal, didUpdate } // Look for the edge @@ -119,14 +199,16 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface e := edge{ label: search[0], node: &Node{ + mutateCh: make(chan struct{}), leaf: &leafNode{ - key: k, - val: v, + mutateCh: make(chan struct{}), + key: k, + val: v, }, prefix: search, }, } - nc := t.writeNode(n) + nc := t.writeNode(n, false) nc.addEdge(e) return nc, nil, false } @@ -137,7 +219,7 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface search = search[commonPrefix:] newChild, oldVal, didUpdate := t.insert(child, k, search, v) if newChild != nil { - nc := t.writeNode(n) + nc := t.writeNode(n, false) nc.edges[idx].node = newChild return nc, oldVal, didUpdate } @@ -145,9 +227,10 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface } // Split the node - nc := t.writeNode(n) + nc := t.writeNode(n, false) splitNode := &Node{ - prefix: search[:commonPrefix], + mutateCh: make(chan struct{}), + prefix: search[:commonPrefix], } nc.replaceEdge(edge{ label: search[0], @@ -155,7 +238,7 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface }) // Restore the existing child node - modChild := t.writeNode(child) + modChild := t.writeNode(child, false) splitNode.addEdge(edge{ label: modChild.prefix[commonPrefix], node: modChild, @@ -164,8 +247,9 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface // Create a new leaf node leaf := &leafNode{ - key: k, - val: v, + mutateCh: make(chan struct{}), + key: k, + val: v, } // If the new key is a subset, add to to this node @@ -179,8 +263,9 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface splitNode.addEdge(edge{ label: search[0], node: &Node{ - leaf: leaf, - prefix: search, + mutateCh: make(chan struct{}), + leaf: leaf, + 
prefix: search, }, }) return nc, nil, false @@ -188,14 +273,14 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface // delete does a recursive deletion func (t *Txn) delete(parent, n *Node, search []byte) (*Node, *leafNode) { - // Check for key exhaution + // Check for key exhaustion if len(search) == 0 { if !n.isLeaf() { return nil, nil } // Remove the leaf node - nc := t.writeNode(n) + nc := t.writeNode(n, true) nc.leaf = nil // Check if this node should be merged @@ -219,8 +304,11 @@ func (t *Txn) delete(parent, n *Node, search []byte) (*Node, *leafNode) { return nil, nil } - // Copy this node - nc := t.writeNode(n) + // Copy this node. WATCH OUT - it's safe to pass "false" here because we + // will only ADD a leaf via nc.mergeChilde() if there isn't one due to + // the !nc.isLeaf() check in the logic just below. This is pretty subtle, + // so be careful if you change any of the logic here. + nc := t.writeNode(n, false) // Delete the edge if the node has no edges if newChild.leaf == nil && len(newChild.edges) == 0 { @@ -274,10 +362,109 @@ func (t *Txn) Get(k []byte) (interface{}, bool) { return t.root.Get(k) } -// Commit is used to finalize the transaction and return a new tree +// GetWatch is used to lookup a specific key, returning +// the watch channel, value and if it was found +func (t *Txn) GetWatch(k []byte) (<-chan struct{}, interface{}, bool) { + return t.root.GetWatch(k) +} + +// Commit is used to finalize the transaction and return a new tree. If mutation +// tracking is turned on then notifications will also be issued. func (t *Txn) Commit() *Tree { - t.modified = nil - return &Tree{t.root, t.size} + nt := t.commit() + if t.trackMutate { + t.notify() + } + return nt +} + +// commit is an internal helper for Commit(), useful for unit tests. +func (t *Txn) commit() *Tree { + nt := &Tree{t.root, t.size} + t.writable = nil + return nt +} + +// slowNotify does a complete comparison of the before and after trees in order +// to trigger notifications. This doesn't require any additional state but it +// is very expensive to compute. +func (t *Txn) slowNotify() { + snapIter := t.snap.rawIterator() + rootIter := t.root.rawIterator() + for snapIter.Front() != nil || rootIter.Front() != nil { + // If we've exhausted the nodes in the old snapshot, we know + // there's nothing remaining to notify. + if snapIter.Front() == nil { + return + } + snapElem := snapIter.Front() + + // If we've exhausted the nodes in the new root, we know we need + // to invalidate everything that remains in the old snapshot. We + // know from the loop condition there's something in the old + // snapshot. + if rootIter.Front() == nil { + close(snapElem.mutateCh) + if snapElem.isLeaf() { + close(snapElem.leaf.mutateCh) + } + snapIter.Next() + continue + } + + // Do one string compare so we can check the various conditions + // below without repeating the compare. + cmp := strings.Compare(snapIter.Path(), rootIter.Path()) + + // If the snapshot is behind the root, then we must have deleted + // this node during the transaction. + if cmp < 0 { + close(snapElem.mutateCh) + if snapElem.isLeaf() { + close(snapElem.leaf.mutateCh) + } + snapIter.Next() + continue + } + + // If the snapshot is ahead of the root, then we must have added + // this node during the transaction. + if cmp > 0 { + rootIter.Next() + continue + } + + // If we have the same path, then we need to see if we mutated a + // node and possibly the leaf. 
+ rootElem := rootIter.Front() + if snapElem != rootElem { + close(snapElem.mutateCh) + if snapElem.leaf != nil && (snapElem.leaf != rootElem.leaf) { + close(snapElem.leaf.mutateCh) + } + } + snapIter.Next() + rootIter.Next() + } +} + +// notify is used along with TrackMutate to trigger notifications. This should +// only be done once a transaction is committed. +func (t *Txn) notify() { + // If we've overflowed the tracking state we can't use it in any way and + // need to do a full tree compare. + if t.trackOverflow { + t.slowNotify() + } else { + for ch := range t.trackChannels { + close(*ch) + } + } + + // Clean up the tracking state so that a re-notify is safe (will trigger + // the else clause above which will be a no-op). + t.trackChannels = nil + t.trackOverflow = false } // Insert is used to add or update a given key. The return provides diff --git a/vendor/github.com/hashicorp/go-immutable-radix/iter.go b/vendor/github.com/hashicorp/go-immutable-radix/iter.go index 75cbaa110f..9815e02538 100644 --- a/vendor/github.com/hashicorp/go-immutable-radix/iter.go +++ b/vendor/github.com/hashicorp/go-immutable-radix/iter.go @@ -9,11 +9,13 @@ type Iterator struct { stack []edges } -// SeekPrefix is used to seek the iterator to a given prefix -func (i *Iterator) SeekPrefix(prefix []byte) { +// SeekPrefixWatch is used to seek the iterator to a given prefix +// and returns the watch channel of the finest granularity +func (i *Iterator) SeekPrefixWatch(prefix []byte) (watch <-chan struct{}) { // Wipe the stack i.stack = nil n := i.node + watch = n.mutateCh search := prefix for { // Check for key exhaution @@ -29,6 +31,9 @@ func (i *Iterator) SeekPrefix(prefix []byte) { return } + // Update to the finest granularity as the search makes progress + watch = n.mutateCh + // Consume the search prefix if bytes.HasPrefix(search, n.prefix) { search = search[len(n.prefix):] @@ -43,6 +48,11 @@ func (i *Iterator) SeekPrefix(prefix []byte) { } } +// SeekPrefix is used to seek the iterator to a given prefix +func (i *Iterator) SeekPrefix(prefix []byte) { + i.SeekPrefixWatch(prefix) +} + // Next returns the next node in order func (i *Iterator) Next() ([]byte, interface{}, bool) { // Initialize our stack if needed diff --git a/vendor/github.com/hashicorp/go-immutable-radix/node.go b/vendor/github.com/hashicorp/go-immutable-radix/node.go index fea6f63436..cf7137f93c 100644 --- a/vendor/github.com/hashicorp/go-immutable-radix/node.go +++ b/vendor/github.com/hashicorp/go-immutable-radix/node.go @@ -12,8 +12,9 @@ type WalkFn func(k []byte, v interface{}) bool // leafNode is used to represent a value type leafNode struct { - key []byte - val interface{} + mutateCh chan struct{} + key []byte + val interface{} } // edge is used to represent an edge node @@ -24,6 +25,9 @@ type edge struct { // Node is an immutable node in the radix tree type Node struct { + // mutateCh is closed if this node is modified + mutateCh chan struct{} + // leaf is used to store possible leaf leaf *leafNode @@ -105,13 +109,14 @@ func (n *Node) mergeChild() { } } -func (n *Node) Get(k []byte) (interface{}, bool) { +func (n *Node) GetWatch(k []byte) (<-chan struct{}, interface{}, bool) { search := k + watch := n.mutateCh for { - // Check for key exhaution + // Check for key exhaustion if len(search) == 0 { if n.isLeaf() { - return n.leaf.val, true + return n.leaf.mutateCh, n.leaf.val, true } break } @@ -122,6 +127,9 @@ func (n *Node) Get(k []byte) (interface{}, bool) { break } + // Update to the finest granularity as the search makes progress 
+ watch = n.mutateCh + // Consume the search prefix if bytes.HasPrefix(search, n.prefix) { search = search[len(n.prefix):] @@ -129,7 +137,12 @@ func (n *Node) Get(k []byte) (interface{}, bool) { break } } - return nil, false + return watch, nil, false +} + +func (n *Node) Get(k []byte) (interface{}, bool) { + _, val, ok := n.GetWatch(k) + return val, ok } // LongestPrefix is like Get, but instead of an @@ -204,6 +217,14 @@ func (n *Node) Iterator() *Iterator { return &Iterator{node: n} } +// rawIterator is used to return a raw iterator at the given node to walk the +// tree. +func (n *Node) rawIterator() *rawIterator { + iter := &rawIterator{node: n} + iter.Next() + return iter +} + // Walk is used to walk the tree func (n *Node) Walk(fn WalkFn) { recursiveWalk(n, fn) diff --git a/vendor/github.com/hashicorp/go-immutable-radix/raw_iter.go b/vendor/github.com/hashicorp/go-immutable-radix/raw_iter.go new file mode 100644 index 0000000000..04814c1323 --- /dev/null +++ b/vendor/github.com/hashicorp/go-immutable-radix/raw_iter.go @@ -0,0 +1,78 @@ +package iradix + +// rawIterator visits each of the nodes in the tree, even the ones that are not +// leaves. It keeps track of the effective path (what a leaf at a given node +// would be called), which is useful for comparing trees. +type rawIterator struct { + // node is the starting node in the tree for the iterator. + node *Node + + // stack keeps track of edges in the frontier. + stack []rawStackEntry + + // pos is the current position of the iterator. + pos *Node + + // path is the effective path of the current iterator position, + // regardless of whether the current node is a leaf. + path string +} + +// rawStackEntry is used to keep track of the cumulative common path as well as +// its associated edges in the frontier. +type rawStackEntry struct { + path string + edges edges +} + +// Front returns the current node that has been iterated to. +func (i *rawIterator) Front() *Node { + return i.pos +} + +// Path returns the effective path of the current node, even if it's not actually +// a leaf. +func (i *rawIterator) Path() string { + return i.path +} + +// Next advances the iterator to the next node. +func (i *rawIterator) Next() { + // Initialize our stack if needed. + if i.stack == nil && i.node != nil { + i.stack = []rawStackEntry{ + rawStackEntry{ + edges: edges{ + edge{node: i.node}, + }, + }, + } + } + + for len(i.stack) > 0 { + // Inspect the last element of the stack. + n := len(i.stack) + last := i.stack[n-1] + elem := last.edges[0].node + + // Update the stack. + if len(last.edges) > 1 { + i.stack[n-1].edges = last.edges[1:] + } else { + i.stack = i.stack[:n-1] + } + + // Push the edges onto the frontier. 
+ if len(elem.edges) > 0 { + path := last.path + string(elem.prefix) + i.stack = append(i.stack, rawStackEntry{path, elem.edges}) + } + + i.pos = elem + i.path = last.path + string(elem.prefix) + return + } + + i.pos = nil + i.path = "" +} diff --git a/vendor/github.com/hashicorp/go-memdb/memdb.go b/vendor/github.com/hashicorp/go-memdb/memdb.go index 1d708517db..13817547be 100644 --- a/vendor/github.com/hashicorp/go-memdb/memdb.go +++ b/vendor/github.com/hashicorp/go-memdb/memdb.go @@ -15,6 +15,7 @@ import ( type MemDB struct { schema *DBSchema root unsafe.Pointer // *iradix.Tree underneath + primary bool // There can only be a single writter at once writer sync.Mutex @@ -31,6 +32,7 @@ func NewMemDB(schema *DBSchema) (*MemDB, error) { db := &MemDB{ schema: schema, root: unsafe.Pointer(iradix.New()), + primary: true, } if err := db.initialize(); err != nil { return nil, err @@ -65,6 +67,7 @@ func (db *MemDB) Snapshot() *MemDB { clone := &MemDB{ schema: db.schema, root: unsafe.Pointer(db.getRoot()), + primary: false, } return clone } diff --git a/vendor/github.com/hashicorp/go-memdb/schema.go b/vendor/github.com/hashicorp/go-memdb/schema.go index 26d0fcb99f..d7210f91cd 100644 --- a/vendor/github.com/hashicorp/go-memdb/schema.go +++ b/vendor/github.com/hashicorp/go-memdb/schema.go @@ -38,7 +38,7 @@ func (s *TableSchema) Validate() error { return fmt.Errorf("missing table name") } if len(s.Indexes) == 0 { - return fmt.Errorf("missing table schemas for '%s'", s.Name) + return fmt.Errorf("missing table indexes for '%s'", s.Name) } if _, ok := s.Indexes["id"]; !ok { return fmt.Errorf("must have id index") diff --git a/vendor/github.com/hashicorp/go-memdb/txn.go b/vendor/github.com/hashicorp/go-memdb/txn.go index fa73c9a3f1..a069a9fd99 100644 --- a/vendor/github.com/hashicorp/go-memdb/txn.go +++ b/vendor/github.com/hashicorp/go-memdb/txn.go @@ -70,6 +70,11 @@ func (txn *Txn) writableIndex(table, index string) *iradix.Txn { raw, _ := txn.rootTxn.Get(path) indexTxn := raw.(*iradix.Tree).Txn() + // If we are the primary DB, enable mutation tracking. Snapshots should + // not notify, otherwise we will trigger watches on the primary DB when + // the writes will not be visible. + indexTxn.TrackMutate(txn.db.primary) + // Keep this open for the duration of the txn txn.modified[key] = indexTxn return indexTxn @@ -352,13 +357,13 @@ func (txn *Txn) DeleteAll(table, index string, args ...interface{}) (int, error) return num, nil } -// First is used to return the first matching object for -// the given constraints on the index -func (txn *Txn) First(table, index string, args ...interface{}) (interface{}, error) { +// FirstWatch is used to return the first matching object for +// the given constraints on the index along with the watch channel +func (txn *Txn) FirstWatch(table, index string, args ...interface{}) (<-chan struct{}, interface{}, error) { // Get the index value indexSchema, val, err := txn.getIndexValue(table, index, args...) 
if err != nil { - return nil, err + return nil, nil, err } // Get the index itself @@ -366,18 +371,25 @@ func (txn *Txn) First(table, index string, args ...interface{}) (interface{}, er // Do an exact lookup if indexSchema.Unique && val != nil && indexSchema.Name == index { - obj, ok := indexTxn.Get(val) + watch, obj, ok := indexTxn.GetWatch(val) if !ok { - return nil, nil + return watch, nil, nil } - return obj, nil + return watch, obj, nil } // Handle non-unique index by using an iterator and getting the first value iter := indexTxn.Root().Iterator() - iter.SeekPrefix(val) + watch := iter.SeekPrefixWatch(val) _, value, _ := iter.Next() - return value, nil + return watch, value, nil +} + +// First is used to return the first matching object for +// the given constraints on the index +func (txn *Txn) First(table, index string, args ...interface{}) (interface{}, error) { + _, val, err := txn.FirstWatch(table, index, args...) + return val, err } // LongestPrefix is used to fetch the longest prefix match for the given @@ -468,6 +480,7 @@ func (txn *Txn) getIndexValue(table, index string, args ...interface{}) (*IndexS // ResultIterator is used to iterate over a list of results // from a Get query on a table. type ResultIterator interface { + WatchCh() <-chan struct{} Next() interface{} } @@ -488,11 +501,12 @@ func (txn *Txn) Get(table, index string, args ...interface{}) (ResultIterator, e indexIter := indexRoot.Iterator() // Seek the iterator to the appropriate sub-set - indexIter.SeekPrefix(val) + watchCh := indexIter.SeekPrefixWatch(val) // Create an iterator iter := &radixIterator{ - iter: indexIter, + iter: indexIter, + watchCh: watchCh, } return iter, nil } @@ -506,10 +520,15 @@ func (txn *Txn) Defer(fn func()) { } // radixIterator is used to wrap an underlying iradix iterator. -// This is much mroe efficient than a sliceIterator as we are not +// This is much more efficient than a sliceIterator as we are not // materializing the entire view. type radixIterator struct { - iter *iradix.Iterator + iter *iradix.Iterator + watchCh <-chan struct{} +} + +func (r *radixIterator) WatchCh() <-chan struct{} { + return r.watchCh } func (r *radixIterator) Next() interface{} { diff --git a/vendor/github.com/hashicorp/go-memdb/watch.go b/vendor/github.com/hashicorp/go-memdb/watch.go new file mode 100644 index 0000000000..7c4a3ba6ee --- /dev/null +++ b/vendor/github.com/hashicorp/go-memdb/watch.go @@ -0,0 +1,108 @@ +package memdb + +import "time" + +// WatchSet is a collection of watch channels. +type WatchSet map[<-chan struct{}]struct{} + +// NewWatchSet constructs a new watch set. +func NewWatchSet() WatchSet { + return make(map[<-chan struct{}]struct{}) +} + +// Add appends a watchCh to the WatchSet if non-nil. +func (w WatchSet) Add(watchCh <-chan struct{}) { + if w == nil { + return + } + + if _, ok := w[watchCh]; !ok { + w[watchCh] = struct{}{} + } +} + +// AddWithLimit appends a watchCh to the WatchSet if non-nil, and if the given +// softLimit hasn't been exceeded. Otherwise, it will watch the given alternate +// channel. It's expected that the altCh will be the same on many calls to this +// function, so you will exceed the soft limit a little bit if you hit this, but +// not by much. +// +// This is useful if you want to track individual items up to some limit, after +// which you watch a higher-level channel (usually a channel from start start of +// an iterator higher up in the radix tree) that will watch a superset of items. 
+func (w WatchSet) AddWithLimit(softLimit int, watchCh <-chan struct{}, altCh <-chan struct{}) { + // This is safe for a nil WatchSet so we don't need to check that here. + if len(w) < softLimit { + w.Add(watchCh) + } else { + w.Add(altCh) + } +} + +// Watch is used to wait for either the watch set to trigger or a timeout. +// Returns true on timeout. +func (w WatchSet) Watch(timeoutCh <-chan time.Time) bool { + if w == nil { + return false + } + + if n := len(w); n <= aFew { + idx := 0 + chunk := make([]<-chan struct{}, aFew) + for watchCh := range w { + chunk[idx] = watchCh + idx++ + } + return watchFew(chunk, timeoutCh) + } else { + return w.watchMany(timeoutCh) + } +} + +// watchMany is used if there are many watchers. +func (w WatchSet) watchMany(timeoutCh <-chan time.Time) bool { + // Make a fake timeout channel we can feed into watchFew to cancel all + // the blocking goroutines. + doneCh := make(chan time.Time) + defer close(doneCh) + + // Set up a goroutine for each watcher. + triggerCh := make(chan struct{}, 1) + watcher := func(chunk []<-chan struct{}) { + if timeout := watchFew(chunk, doneCh); !timeout { + select { + case triggerCh <- struct{}{}: + default: + } + } + } + + // Apportion the watch channels into chunks we can feed into the + // watchFew helper. + idx := 0 + chunk := make([]<-chan struct{}, aFew) + for watchCh := range w { + subIdx := idx % aFew + chunk[subIdx] = watchCh + idx++ + + // Fire off this chunk and start a fresh one. + if idx%aFew == 0 { + go watcher(chunk) + chunk = make([]<-chan struct{}, aFew) + } + } + + // Make sure to watch any residual channels in the last chunk. + if idx%aFew != 0 { + go watcher(chunk) + } + + // Wait for a channel to trigger or timeout. + select { + case <-triggerCh: + return false + case <-timeoutCh: + return true + } +} diff --git a/vendor/github.com/hashicorp/go-memdb/watch_few.go b/vendor/github.com/hashicorp/go-memdb/watch_few.go new file mode 100644 index 0000000000..f2bb19db17 --- /dev/null +++ b/vendor/github.com/hashicorp/go-memdb/watch_few.go @@ -0,0 +1,116 @@ +//go:generate sh -c "go run watch-gen/main.go >watch_few.go" +package memdb + +import( + "time" +) + +// aFew gives how many watchers this function is wired to support. You must +// always pass a full slice of this length, but unused channels can be nil. +const aFew = 32 + +// watchFew is used if there are only a few watchers as a performance +// optimization. 
+func watchFew(ch []<-chan struct{}, timeoutCh <-chan time.Time) bool { + select { + + case <-ch[0]: + return false + + case <-ch[1]: + return false + + case <-ch[2]: + return false + + case <-ch[3]: + return false + + case <-ch[4]: + return false + + case <-ch[5]: + return false + + case <-ch[6]: + return false + + case <-ch[7]: + return false + + case <-ch[8]: + return false + + case <-ch[9]: + return false + + case <-ch[10]: + return false + + case <-ch[11]: + return false + + case <-ch[12]: + return false + + case <-ch[13]: + return false + + case <-ch[14]: + return false + + case <-ch[15]: + return false + + case <-ch[16]: + return false + + case <-ch[17]: + return false + + case <-ch[18]: + return false + + case <-ch[19]: + return false + + case <-ch[20]: + return false + + case <-ch[21]: + return false + + case <-ch[22]: + return false + + case <-ch[23]: + return false + + case <-ch[24]: + return false + + case <-ch[25]: + return false + + case <-ch[26]: + return false + + case <-ch[27]: + return false + + case <-ch[28]: + return false + + case <-ch[29]: + return false + + case <-ch[30]: + return false + + case <-ch[31]: + return false + + case <-timeoutCh: + return true + } +} diff --git a/vendor/vendor.json b/vendor/vendor.json index 47533cfa7a..7f732334c0 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -426,16 +426,16 @@ "revisionTime": "2016-04-07T17:41:26Z" }, { - "checksumSHA1": "qmE9mO0WW6ALLpUU81rXDyspP5M=", + "checksumSHA1": "jPxyofQxI1PRPq6LPc6VlcRn5fI=", "path": "github.com/hashicorp/go-immutable-radix", - "revision": "afc5a0dbb18abdf82c277a7bc01533e81fa1d6b8", - "revisionTime": "2016-06-09T02:05:29Z" + "revision": "76b5f4e390910df355bfb9b16b41899538594a05", + "revisionTime": "2017-01-13T02:29:29Z" }, { - "checksumSHA1": "ZpTDFeRvXFwIvSHRD8eDYHxaj4Y=", + "checksumSHA1": "K8Fsgt1llTXP0EwqdBzvSGdKOKc=", "path": "github.com/hashicorp/go-memdb", - "revision": "d2d2b77acab85aa635614ac17ea865969f56009e", - "revisionTime": "2017-01-07T16:22:14Z" + "revision": "c01f56b44823e8ba697e23c18d12dca984b85aca", + "revisionTime": "2017-01-23T15:32:28Z" }, { "checksumSHA1": "TNlVzNR1OaajcNi3CbQ3bGbaLGU=",
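The vendored go-immutable-radix changes above are what the rest of this diff builds on, so a small illustration may help reviewers. The following is a minimal sketch, not part of the change itself: the key "foo" and its values are made up, and only the Txn, TrackMutate, GetWatch, Insert, and Commit signatures are taken from the vendored iradix.go shown above.

package main

import (
	"fmt"

	iradix "github.com/hashicorp/go-immutable-radix"
)

func main() {
	// Build a tree with a single key.
	r := iradix.New()
	r, _, _ = r.Insert([]byte("foo"), 1)

	// Open a transaction with mutation tracking enabled, so Commit()
	// closes the mutate channels of any nodes and leaves it touched.
	txn := r.Txn()
	txn.TrackMutate(true)

	// GetWatch returns a channel that is closed once this leaf mutates.
	watch, val, ok := txn.GetWatch([]byte("foo"))
	fmt.Println(val, ok) // 1 true

	// Update the key and commit; the tracked channels are closed here.
	txn.Insert([]byte("foo"), 2)
	r = txn.Commit()

	select {
	case <-watch:
		fmt.Println("watch fired for foo")
	default:
		fmt.Println("watch did not fire")
	}
	fmt.Println(r.Len()) // still one key
}

If more than defaultModifiedCache (8192) channels get tracked in a single transaction, trackOverflow is set and Commit() falls back to slowNotify(), which walks the before and after trees with the new rawIterator and closes the channels of any node whose path was deleted or mutated.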
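On the go-memdb side, those radix channels surface through FirstWatch, ResultIterator.WatchCh, and the new WatchSet type. Below is a hedged sketch of how a caller might collect channels from a read transaction and block on them; the "nodes" table, the Node struct, and the "node1" key are hypothetical, and only the NewWatchSet/Add/Watch, FirstWatch, and WatchCh signatures come from the vendored code in this diff.

package main

import (
	"fmt"
	"time"

	memdb "github.com/hashicorp/go-memdb"
)

// Node is a made-up object for this example; it is not Consul's structs.Node.
type Node struct {
	Name string
}

func main() {
	// Hypothetical single-table schema keyed by Name.
	schema := &memdb.DBSchema{
		Tables: map[string]*memdb.TableSchema{
			"nodes": &memdb.TableSchema{
				Name: "nodes",
				Indexes: map[string]*memdb.IndexSchema{
					"id": &memdb.IndexSchema{
						Name:    "id",
						Unique:  true,
						Indexer: &memdb.StringFieldIndex{Field: "Name"},
					},
				},
			},
		},
	}
	db, err := memdb.NewMemDB(schema)
	if err != nil {
		panic(err)
	}

	// Read side: collect fine-grained watch channels into a WatchSet.
	ws := memdb.NewWatchSet()
	read := db.Txn(false)
	watchCh, _, err := read.FirstWatch("nodes", "id", "node1")
	if err != nil {
		panic(err)
	}
	ws.Add(watchCh)

	// List queries expose one channel covering the whole result set.
	iter, err := read.Get("nodes", "id")
	if err != nil {
		panic(err)
	}
	ws.Add(iter.WatchCh())
	read.Abort()

	// Write side: committing a change closes the affected channels.
	go func() {
		write := db.Txn(true)
		if err := write.Insert("nodes", &Node{Name: "node1"}); err != nil {
			panic(err)
		}
		write.Commit()
	}()

	// Block until something we watched changes or the timeout expires.
	if timedOut := ws.Watch(time.After(time.Second)); timedOut {
		fmt.Println("timed out")
	} else {
		fmt.Println("watch fired; re-run the query for fresh results")
	}
}

Note that WatchSet.Watch selects directly over up to aFew (32) channels via the generated watchFew, and for larger sets watchMany fans the channels out to one goroutine per 32-channel chunk, all cancelled through the shared doneCh once one of them fires or the timeout hits.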