diff --git a/consul/state/acl.go b/consul/state/acl.go new file mode 100644 index 0000000000..3ce94e9a13 --- /dev/null +++ b/consul/state/acl.go @@ -0,0 +1,172 @@ +package state + +import ( + "fmt" + + "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/go-memdb" +) + +// ACLs is used to pull all the ACLs from the snapshot. +func (s *StateSnapshot) ACLs() (memdb.ResultIterator, error) { + iter, err := s.tx.Get("acls", "id") + if err != nil { + return nil, err + } + return iter, nil +} + +// ACL is used when restoring from a snapshot. For general inserts, use ACLSet. +func (s *StateRestore) ACL(acl *structs.ACL) error { + if err := s.tx.Insert("acls", acl); err != nil { + return fmt.Errorf("failed restoring acl: %s", err) + } + + if err := indexUpdateMaxTxn(s.tx, acl.ModifyIndex, "acls"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + s.watches.Arm("acls") + return nil +} + +// ACLSet is used to insert an ACL rule into the state store. +func (s *StateStore) ACLSet(idx uint64, acl *structs.ACL) error { + tx := s.db.Txn(true) + defer tx.Abort() + + // Call set on the ACL + if err := s.aclSetTxn(tx, idx, acl); err != nil { + return err + } + + tx.Commit() + return nil +} + +// aclSetTxn is the inner method used to insert an ACL rule with the +// proper indexes into the state store. +func (s *StateStore) aclSetTxn(tx *memdb.Txn, idx uint64, acl *structs.ACL) error { + // Check that the ID is set + if acl.ID == "" { + return ErrMissingACLID + } + + // Check for an existing ACL + existing, err := tx.First("acls", "id", acl.ID) + if err != nil { + return fmt.Errorf("failed acl lookup: %s", err) + } + + // Set the indexes + if existing != nil { + acl.CreateIndex = existing.(*structs.ACL).CreateIndex + acl.ModifyIndex = idx + } else { + acl.CreateIndex = idx + acl.ModifyIndex = idx + } + + // Insert the ACL + if err := tx.Insert("acls", acl); err != nil { + return fmt.Errorf("failed inserting acl: %s", err) + } + if err := tx.Insert("index", &IndexEntry{"acls", idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Defer(func() { s.tableWatches["acls"].Notify() }) + return nil +} + +// ACLGet is used to look up an existing ACL by ID. +func (s *StateStore) ACLGet(aclID string) (uint64, *structs.ACL, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("ACLGet")...) + + // Query for the existing ACL + acl, err := tx.First("acls", "id", aclID) + if err != nil { + return 0, nil, fmt.Errorf("failed acl lookup: %s", err) + } + if acl != nil { + return idx, acl.(*structs.ACL), nil + } + return idx, nil, nil +} + +// ACLList is used to list out all of the ACLs in the state store. +func (s *StateStore) ACLList() (uint64, structs.ACLs, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("ACLList")...) + + // Return the ACLs. + acls, err := s.aclListTxn(tx) + if err != nil { + return 0, nil, fmt.Errorf("failed acl lookup: %s", err) + } + return idx, acls, nil +} + +// aclListTxn is used to list out all of the ACLs in the state store. This is a +// function vs. a method so it can be called from the snapshotter. +func (s *StateStore) aclListTxn(tx *memdb.Txn) (structs.ACLs, error) { + // Query all of the ACLs in the state store + acls, err := tx.Get("acls", "id") + if err != nil { + return nil, fmt.Errorf("failed acl lookup: %s", err) + } + + // Go over all of the ACLs and build the response + var result structs.ACLs + for acl := acls.Next(); acl != nil; acl = acls.Next() { + a := acl.(*structs.ACL) + result = append(result, a) + } + return result, nil +} + +// ACLDelete is used to remove an existing ACL from the state store. If +// the ACL does not exist this is a no-op and no error is returned. +func (s *StateStore) ACLDelete(idx uint64, aclID string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + // Call the ACL delete + if err := s.aclDeleteTxn(tx, idx, aclID); err != nil { + return err + } + + tx.Commit() + return nil +} + +// aclDeleteTxn is used to delete an ACL from the state store within +// an existing transaction. +func (s *StateStore) aclDeleteTxn(tx *memdb.Txn, idx uint64, aclID string) error { + // Look up the existing ACL + acl, err := tx.First("acls", "id", aclID) + if err != nil { + return fmt.Errorf("failed acl lookup: %s", err) + } + if acl == nil { + return nil + } + + // Delete the ACL from the state store and update indexes + if err := tx.Delete("acls", acl); err != nil { + return fmt.Errorf("failed deleting acl: %s", err) + } + if err := tx.Insert("index", &IndexEntry{"acls", idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Defer(func() { s.tableWatches["acls"].Notify() }) + return nil +} diff --git a/consul/state/acl_test.go b/consul/state/acl_test.go new file mode 100644 index 0000000000..94bab3fedd --- /dev/null +++ b/consul/state/acl_test.go @@ -0,0 +1,298 @@ +package state + +import ( + "reflect" + "testing" + + "github.com/hashicorp/consul/consul/structs" +) + +func TestStateStore_ACLSet_ACLGet(t *testing.T) { + s := testStateStore(t) + + // Querying ACLs with no results returns nil + idx, res, err := s.ACLGet("nope") + if idx != 0 || res != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) + } + + // Inserting an ACL with empty ID is disallowed + if err := s.ACLSet(1, &structs.ACL{}); err == nil { + t.Fatalf("expected %#v, got: %#v", ErrMissingACLID, err) + } + + // Index is not updated if nothing is saved + if idx := s.maxIndex("acls"); idx != 0 { + t.Fatalf("bad index: %d", idx) + } + + // Inserting valid ACL works + acl := &structs.ACL{ + ID: "acl1", + Name: "First ACL", + Type: structs.ACLTypeClient, + Rules: "rules1", + } + if err := s.ACLSet(1, acl); err != nil { + t.Fatalf("err: %s", err) + } + + // Check that the index was updated + if idx := s.maxIndex("acls"); idx != 1 { + t.Fatalf("bad index: %d", idx) + } + + // Retrieve the ACL again + idx, result, err := s.ACLGet("acl1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 1 { + t.Fatalf("bad index: %d", idx) + } + + // Check that the ACL matches the result + expect := &structs.ACL{ + ID: "acl1", + Name: "First ACL", + Type: structs.ACLTypeClient, + Rules: "rules1", + RaftIndex: structs.RaftIndex{ + CreateIndex: 1, + ModifyIndex: 1, + }, + } + if !reflect.DeepEqual(result, expect) { + t.Fatalf("bad: %#v", result) + } + + // Update the ACL + acl = &structs.ACL{ + ID: "acl1", + Name: "First ACL", + Type: structs.ACLTypeClient, + Rules: "rules2", + } + if err := s.ACLSet(2, acl); err != nil { + t.Fatalf("err: %s", err) + } + + // Index was updated + if idx := s.maxIndex("acls"); idx != 2 { + t.Fatalf("bad: %d", idx) + } + + // ACL was updated and matches expected value + expect = &structs.ACL{ + ID: "acl1", + Name: "First ACL", + Type: structs.ACLTypeClient, + Rules: "rules2", + RaftIndex: structs.RaftIndex{ + CreateIndex: 1, + ModifyIndex: 2, + }, + } + if !reflect.DeepEqual(acl, expect) { + t.Fatalf("bad: %#v", acl) + } +} + +func TestStateStore_ACLList(t *testing.T) { + s := testStateStore(t) + + // Listing when no ACLs exist returns nil + idx, res, err := s.ACLList() + if idx != 0 || res != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) + } + + // Insert some ACLs + acls := structs.ACLs{ + &structs.ACL{ + ID: "acl1", + Type: structs.ACLTypeClient, + Rules: "rules1", + RaftIndex: structs.RaftIndex{ + CreateIndex: 1, + ModifyIndex: 1, + }, + }, + &structs.ACL{ + ID: "acl2", + Type: structs.ACLTypeClient, + Rules: "rules2", + RaftIndex: structs.RaftIndex{ + CreateIndex: 2, + ModifyIndex: 2, + }, + }, + } + for _, acl := range acls { + if err := s.ACLSet(acl.ModifyIndex, acl); err != nil { + t.Fatalf("err: %s", err) + } + } + + // Query the ACLs + idx, res, err = s.ACLList() + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 2 { + t.Fatalf("bad index: %d", idx) + } + + // Check that the result matches + if !reflect.DeepEqual(res, acls) { + t.Fatalf("bad: %#v", res) + } +} + +func TestStateStore_ACLDelete(t *testing.T) { + s := testStateStore(t) + + // Calling delete on an ACL which doesn't exist returns nil + if err := s.ACLDelete(1, "nope"); err != nil { + t.Fatalf("err: %s", err) + } + + // Index isn't updated if nothing is deleted + if idx := s.maxIndex("acls"); idx != 0 { + t.Fatalf("bad index: %d", idx) + } + + // Insert an ACL + if err := s.ACLSet(1, &structs.ACL{ID: "acl1"}); err != nil { + t.Fatalf("err: %s", err) + } + + // Delete the ACL and check that the index was updated + if err := s.ACLDelete(2, "acl1"); err != nil { + t.Fatalf("err: %s", err) + } + if idx := s.maxIndex("acls"); idx != 2 { + t.Fatalf("bad index: %d", idx) + } + + tx := s.db.Txn(false) + defer tx.Abort() + + // Check that the ACL was really deleted + result, err := tx.First("acls", "id", "acl1") + if err != nil { + t.Fatalf("err: %s", err) + } + if result != nil { + t.Fatalf("expected nil, got: %#v", result) + } +} + +func TestStateStore_ACL_Snapshot_Restore(t *testing.T) { + s := testStateStore(t) + + // Insert some ACLs. + acls := structs.ACLs{ + &structs.ACL{ + ID: "acl1", + Type: structs.ACLTypeClient, + Rules: "rules1", + RaftIndex: structs.RaftIndex{ + CreateIndex: 1, + ModifyIndex: 1, + }, + }, + &structs.ACL{ + ID: "acl2", + Type: structs.ACLTypeClient, + Rules: "rules2", + RaftIndex: structs.RaftIndex{ + CreateIndex: 2, + ModifyIndex: 2, + }, + }, + } + for _, acl := range acls { + if err := s.ACLSet(acl.ModifyIndex, acl); err != nil { + t.Fatalf("err: %s", err) + } + } + + // Snapshot the ACLs. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. + if err := s.ACLDelete(3, "acl1"); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the snapshot. + if idx := snap.LastIndex(); idx != 2 { + t.Fatalf("bad index: %d", idx) + } + iter, err := snap.ACLs() + if err != nil { + t.Fatalf("err: %s", err) + } + var dump structs.ACLs + for acl := iter.Next(); acl != nil; acl = iter.Next() { + dump = append(dump, acl.(*structs.ACL)) + } + if !reflect.DeepEqual(dump, acls) { + t.Fatalf("bad: %#v", dump) + } + + // Restore the values into a new state store. + func() { + s := testStateStore(t) + restore := s.Restore() + for _, acl := range dump { + if err := restore.ACL(acl); err != nil { + t.Fatalf("err: %s", err) + } + } + restore.Commit() + + // Read the restored ACLs back out and verify that they match. + idx, res, err := s.ACLList() + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 2 { + t.Fatalf("bad index: %d", idx) + } + if !reflect.DeepEqual(res, acls) { + t.Fatalf("bad: %#v", res) + } + + // Check that the index was updated. + if idx := s.maxIndex("acls"); idx != 2 { + t.Fatalf("bad index: %d", idx) + } + }() +} + +func TestStateStore_ACL_Watches(t *testing.T) { + s := testStateStore(t) + + // Call functions that update the acls table and make sure a watch fires + // each time. + verifyWatch(t, s.getTableWatch("acls"), func() { + if err := s.ACLSet(1, &structs.ACL{ID: "acl1"}); err != nil { + t.Fatalf("err: %s", err) + } + }) + verifyWatch(t, s.getTableWatch("acls"), func() { + if err := s.ACLDelete(2, "acl1"); err != nil { + t.Fatalf("err: %s", err) + } + }) + verifyWatch(t, s.getTableWatch("acls"), func() { + restore := s.Restore() + if err := restore.ACL(&structs.ACL{ID: "acl1"}); err != nil { + t.Fatalf("err: %s", err) + } + restore.Commit() + }) +} diff --git a/consul/state/catalog.go b/consul/state/catalog.go new file mode 100644 index 0000000000..56191dc91f --- /dev/null +++ b/consul/state/catalog.go @@ -0,0 +1,1159 @@ +package state + +import ( + "fmt" + "strings" + + "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/consul/types" + "github.com/hashicorp/go-memdb" +) + +// Nodes is used to pull the full list of nodes for use during snapshots. +func (s *StateSnapshot) Nodes() (memdb.ResultIterator, error) { + iter, err := s.tx.Get("nodes", "id") + if err != nil { + return nil, err + } + return iter, nil +} + +// Services is used to pull the full list of services for a given node for use +// during snapshots. +func (s *StateSnapshot) Services(node string) (memdb.ResultIterator, error) { + iter, err := s.tx.Get("services", "node", node) + if err != nil { + return nil, err + } + return iter, nil +} + +// Checks is used to pull the full list of checks for a given node for use +// during snapshots. +func (s *StateSnapshot) Checks(node string) (memdb.ResultIterator, error) { + iter, err := s.tx.Get("checks", "node", node) + if err != nil { + return nil, err + } + return iter, nil +} + +// Registration is used to make sure a node, service, and check registration is +// performed within a single transaction to avoid race conditions on state +// updates. +func (s *StateRestore) Registration(idx uint64, req *structs.RegisterRequest) error { + if err := s.store.ensureRegistrationTxn(s.tx, idx, s.watches, req); err != nil { + return err + } + return nil +} + +// EnsureRegistration is used to make sure a node, service, and check +// registration is performed within a single transaction to avoid race +// conditions on state updates. +func (s *StateStore) EnsureRegistration(idx uint64, req *structs.RegisterRequest) error { + tx := s.db.Txn(true) + defer tx.Abort() + + watches := NewDumbWatchManager(s.tableWatches) + if err := s.ensureRegistrationTxn(tx, idx, watches, req); err != nil { + return err + } + + tx.Defer(func() { watches.Notify() }) + tx.Commit() + return nil +} + +// ensureRegistrationTxn is used to make sure a node, service, and check +// registration is performed within a single transaction to avoid race +// conditions on state updates. +func (s *StateStore) ensureRegistrationTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, + req *structs.RegisterRequest) error { + // Add the node. + node := &structs.Node{ + Node: req.Node, + Address: req.Address, + TaggedAddresses: req.TaggedAddresses, + Meta: req.NodeMeta, + } + if err := s.ensureNodeTxn(tx, idx, watches, node); err != nil { + return fmt.Errorf("failed inserting node: %s", err) + } + + // Add the service, if any. + if req.Service != nil { + if err := s.ensureServiceTxn(tx, idx, watches, req.Node, req.Service); err != nil { + return fmt.Errorf("failed inserting service: %s", err) + } + } + + // TODO (slackpad) In Consul 0.8 ban checks that don't have the same + // node as the top-level registration. This is just weird to be able to + // update unrelated nodes' checks from in here. In 0.7.2 we banned this + // up in the ACL check since that's guarded behind an opt-in flag until + // Consul 0.8. + + // Add the checks, if any. + if req.Check != nil { + if err := s.ensureCheckTxn(tx, idx, watches, req.Check); err != nil { + return fmt.Errorf("failed inserting check: %s", err) + } + } + for _, check := range req.Checks { + if err := s.ensureCheckTxn(tx, idx, watches, check); err != nil { + return fmt.Errorf("failed inserting check: %s", err) + } + } + + return nil +} + +// EnsureNode is used to upsert node registration or modification. +func (s *StateStore) EnsureNode(idx uint64, node *structs.Node) error { + tx := s.db.Txn(true) + defer tx.Abort() + + // Call the node upsert + watches := NewDumbWatchManager(s.tableWatches) + if err := s.ensureNodeTxn(tx, idx, watches, node); err != nil { + return err + } + + tx.Defer(func() { watches.Notify() }) + tx.Commit() + return nil +} + +// ensureNodeTxn is the inner function called to actually create a node +// registration or modify an existing one in the state store. It allows +// passing in a memdb transaction so it may be part of a larger txn. +func (s *StateStore) ensureNodeTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, + node *structs.Node) error { + // Check for an existing node + existing, err := tx.First("nodes", "id", node.Node) + if err != nil { + return fmt.Errorf("node lookup failed: %s", err) + } + + // Get the indexes + if existing != nil { + node.CreateIndex = existing.(*structs.Node).CreateIndex + node.ModifyIndex = idx + } else { + node.CreateIndex = idx + node.ModifyIndex = idx + } + + // Insert the node and update the index + if err := tx.Insert("nodes", node); err != nil { + return fmt.Errorf("failed inserting node: %s", err) + } + if err := tx.Insert("index", &IndexEntry{"nodes", idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + watches.Arm("nodes") + return nil +} + +// GetNode is used to retrieve a node registration by node ID. +func (s *StateStore) GetNode(id string) (uint64, *structs.Node, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("GetNode")...) + + // Retrieve the node from the state store + node, err := tx.First("nodes", "id", id) + if err != nil { + return 0, nil, fmt.Errorf("node lookup failed: %s", err) + } + if node != nil { + return idx, node.(*structs.Node), nil + } + return idx, nil, nil +} + +// Nodes is used to return all of the known nodes. +func (s *StateStore) Nodes() (uint64, structs.Nodes, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("Nodes")...) + + // Retrieve all of the nodes + nodes, err := tx.Get("nodes", "id") + if err != nil { + return 0, nil, fmt.Errorf("failed nodes lookup: %s", err) + } + + // Create and return the nodes list. + var results structs.Nodes + for node := nodes.Next(); node != nil; node = nodes.Next() { + results = append(results, node.(*structs.Node)) + } + return idx, results, nil +} + +// NodesByMeta is used to return all nodes with the given meta key/value pair. +func (s *StateStore) NodesByMeta(filters map[string]string) (uint64, structs.Nodes, error) { + if len(filters) > 1 { + return 0, nil, fmt.Errorf("multiple meta filters not supported") + } + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("Nodes")...) + + // Retrieve all of the nodes + var args []interface{} + for key, value := range filters { + args = append(args, key, value) + } + nodes, err := tx.Get("nodes", "meta", args...) + if err != nil { + return 0, nil, fmt.Errorf("failed nodes lookup: %s", err) + } + + // Create and return the nodes list. + var results structs.Nodes + for node := nodes.Next(); node != nil; node = nodes.Next() { + results = append(results, node.(*structs.Node)) + } + return idx, results, nil +} + +// DeleteNode is used to delete a given node by its ID. +func (s *StateStore) DeleteNode(idx uint64, nodeID string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + // Call the node deletion. + if err := s.deleteNodeTxn(tx, idx, nodeID); err != nil { + return err + } + + tx.Commit() + return nil +} + +// deleteNodeTxn is the inner method used for removing a node from +// the store within a given transaction. +func (s *StateStore) deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeID string) error { + // Look up the node. + node, err := tx.First("nodes", "id", nodeID) + if err != nil { + return fmt.Errorf("node lookup failed: %s", err) + } + if node == nil { + return nil + } + + // Use a watch manager since the inner functions can perform multiple + // ops per table. + watches := NewDumbWatchManager(s.tableWatches) + + // Delete all services associated with the node and update the service index. + services, err := tx.Get("services", "node", nodeID) + if err != nil { + return fmt.Errorf("failed service lookup: %s", err) + } + var sids []string + for service := services.Next(); service != nil; service = services.Next() { + sids = append(sids, service.(*structs.ServiceNode).ServiceID) + } + + // Do the delete in a separate loop so we don't trash the iterator. + for _, sid := range sids { + if err := s.deleteServiceTxn(tx, idx, watches, nodeID, sid); err != nil { + return err + } + } + + // Delete all checks associated with the node. This will invalidate + // sessions as necessary. + checks, err := tx.Get("checks", "node", nodeID) + if err != nil { + return fmt.Errorf("failed check lookup: %s", err) + } + var cids []types.CheckID + for check := checks.Next(); check != nil; check = checks.Next() { + cids = append(cids, check.(*structs.HealthCheck).CheckID) + } + + // Do the delete in a separate loop so we don't trash the iterator. + for _, cid := range cids { + if err := s.deleteCheckTxn(tx, idx, watches, nodeID, cid); err != nil { + return err + } + } + + // Delete any coordinate associated with this node. + coord, err := tx.First("coordinates", "id", nodeID) + if err != nil { + return fmt.Errorf("failed coordinate lookup: %s", err) + } + if coord != nil { + if err := tx.Delete("coordinates", coord); err != nil { + return fmt.Errorf("failed deleting coordinate: %s", err) + } + if err := tx.Insert("index", &IndexEntry{"coordinates", idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + watches.Arm("coordinates") + } + + // Delete the node and update the index. + if err := tx.Delete("nodes", node); err != nil { + return fmt.Errorf("failed deleting node: %s", err) + } + if err := tx.Insert("index", &IndexEntry{"nodes", idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + // Invalidate any sessions for this node. + sessions, err := tx.Get("sessions", "node", nodeID) + if err != nil { + return fmt.Errorf("failed session lookup: %s", err) + } + var ids []string + for sess := sessions.Next(); sess != nil; sess = sessions.Next() { + ids = append(ids, sess.(*structs.Session).ID) + } + + // Do the delete in a separate loop so we don't trash the iterator. + for _, id := range ids { + if err := s.deleteSessionTxn(tx, idx, watches, id); err != nil { + return fmt.Errorf("failed session delete: %s", err) + } + } + + watches.Arm("nodes") + tx.Defer(func() { watches.Notify() }) + return nil +} + +// EnsureService is called to upsert creation of a given NodeService. +func (s *StateStore) EnsureService(idx uint64, node string, svc *structs.NodeService) error { + tx := s.db.Txn(true) + defer tx.Abort() + + // Call the service registration upsert + watches := NewDumbWatchManager(s.tableWatches) + if err := s.ensureServiceTxn(tx, idx, watches, node, svc); err != nil { + return err + } + + tx.Defer(func() { watches.Notify() }) + tx.Commit() + return nil +} + +// ensureServiceTxn is used to upsert a service registration within an +// existing memdb transaction. +func (s *StateStore) ensureServiceTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, + node string, svc *structs.NodeService) error { + // Check for existing service + existing, err := tx.First("services", "id", node, svc.ID) + if err != nil { + return fmt.Errorf("failed service lookup: %s", err) + } + + // Create the service node entry and populate the indexes. Note that + // conversion doesn't populate any of the node-specific information + // (Address and TaggedAddresses). That's always populated when we read + // from the state store. + entry := svc.ToServiceNode(node) + if existing != nil { + entry.CreateIndex = existing.(*structs.ServiceNode).CreateIndex + entry.ModifyIndex = idx + } else { + entry.CreateIndex = idx + entry.ModifyIndex = idx + } + + // Get the node + n, err := tx.First("nodes", "id", node) + if err != nil { + return fmt.Errorf("failed node lookup: %s", err) + } + if n == nil { + return ErrMissingNode + } + + // Insert the service and update the index + if err := tx.Insert("services", entry); err != nil { + return fmt.Errorf("failed inserting service: %s", err) + } + if err := tx.Insert("index", &IndexEntry{"services", idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + watches.Arm("services") + return nil +} + +// Services returns all services along with a list of associated tags. +func (s *StateStore) Services() (uint64, structs.Services, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("Services")...) + + // List all the services. + services, err := tx.Get("services", "id") + if err != nil { + return 0, nil, fmt.Errorf("failed querying services: %s", err) + } + + // Rip through the services and enumerate them and their unique set of + // tags. + unique := make(map[string]map[string]struct{}) + for service := services.Next(); service != nil; service = services.Next() { + svc := service.(*structs.ServiceNode) + tags, ok := unique[svc.ServiceName] + if !ok { + unique[svc.ServiceName] = make(map[string]struct{}) + tags = unique[svc.ServiceName] + } + for _, tag := range svc.ServiceTags { + tags[tag] = struct{}{} + } + } + + // Generate the output structure. + var results = make(structs.Services) + for service, tags := range unique { + results[service] = make([]string, 0) + for tag, _ := range tags { + results[service] = append(results[service], tag) + } + } + return idx, results, nil +} + +// Services returns all services, filtered by the given node metadata. +func (s *StateStore) ServicesByNodeMeta(filters map[string]string) (uint64, structs.Services, error) { + if len(filters) > 1 { + return 0, nil, fmt.Errorf("multiple meta filters not supported") + } + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("ServiceNodes")...) + + // Retrieve all of the nodes with the meta k/v pair + var args []interface{} + for key, value := range filters { + args = append(args, key, value) + } + nodes, err := tx.Get("nodes", "meta", args...) + if err != nil { + return 0, nil, fmt.Errorf("failed nodes lookup: %s", err) + } + + // Populate the services map + unique := make(map[string]map[string]struct{}) + for node := nodes.Next(); node != nil; node = nodes.Next() { + n := node.(*structs.Node) + // List all the services on the node + services, err := tx.Get("services", "node", n.Node) + if err != nil { + return 0, nil, fmt.Errorf("failed querying services: %s", err) + } + + // Rip through the services and enumerate them and their unique set of + // tags. + for service := services.Next(); service != nil; service = services.Next() { + svc := service.(*structs.ServiceNode) + tags, ok := unique[svc.ServiceName] + if !ok { + unique[svc.ServiceName] = make(map[string]struct{}) + tags = unique[svc.ServiceName] + } + for _, tag := range svc.ServiceTags { + tags[tag] = struct{}{} + } + } + } + + // Generate the output structure. + var results = make(structs.Services) + for service, tags := range unique { + results[service] = make([]string, 0) + for tag, _ := range tags { + results[service] = append(results[service], tag) + } + } + return idx, results, nil +} + +// ServiceNodes returns the nodes associated with a given service name. +func (s *StateStore) ServiceNodes(serviceName string) (uint64, structs.ServiceNodes, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("ServiceNodes")...) + + // List all the services. + services, err := tx.Get("services", "service", serviceName) + if err != nil { + return 0, nil, fmt.Errorf("failed service lookup: %s", err) + } + var results structs.ServiceNodes + for service := services.Next(); service != nil; service = services.Next() { + results = append(results, service.(*structs.ServiceNode)) + } + + // Fill in the address details. + results, err = s.parseServiceNodes(tx, results) + if err != nil { + return 0, nil, fmt.Errorf("failed parsing service nodes: %s", err) + } + return idx, results, nil +} + +// ServiceTagNodes returns the nodes associated with a given service, filtering +// out services that don't contain the given tag. +func (s *StateStore) ServiceTagNodes(service, tag string) (uint64, structs.ServiceNodes, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("ServiceNodes")...) + + // List all the services. + services, err := tx.Get("services", "service", service) + if err != nil { + return 0, nil, fmt.Errorf("failed service lookup: %s", err) + } + + // Gather all the services and apply the tag filter. + var results structs.ServiceNodes + for service := services.Next(); service != nil; service = services.Next() { + svc := service.(*structs.ServiceNode) + if !serviceTagFilter(svc, tag) { + results = append(results, svc) + } + } + + // Fill in the address details. + results, err = s.parseServiceNodes(tx, results) + if err != nil { + return 0, nil, fmt.Errorf("failed parsing service nodes: %s", err) + } + return idx, results, nil +} + +// serviceTagFilter returns true (should filter) if the given service node +// doesn't contain the given tag. +func serviceTagFilter(sn *structs.ServiceNode, tag string) bool { + tag = strings.ToLower(tag) + + // Look for the lower cased version of the tag. + for _, t := range sn.ServiceTags { + if strings.ToLower(t) == tag { + return false + } + } + + // If we didn't hit the tag above then we should filter. + return true +} + +// parseServiceNodes iterates over a services query and fills in the node details, +// returning a ServiceNodes slice. +func (s *StateStore) parseServiceNodes(tx *memdb.Txn, services structs.ServiceNodes) (structs.ServiceNodes, error) { + var results structs.ServiceNodes + for _, sn := range services { + // Note that we have to clone here because we don't want to + // modify the node-related fields on the object in the database, + // which is what we are referencing. + s := sn.PartialClone() + + // Grab the corresponding node record. + n, err := tx.First("nodes", "id", sn.Node) + if err != nil { + return nil, fmt.Errorf("failed node lookup: %s", err) + } + + // Populate the node-related fields. The tagged addresses may be + // used by agents to perform address translation if they are + // configured to do that. + node := n.(*structs.Node) + s.Address = node.Address + s.TaggedAddresses = node.TaggedAddresses + s.NodeMeta = node.Meta + + results = append(results, s) + } + return results, nil +} + +// NodeService is used to retrieve a specific service associated with the given +// node. +func (s *StateStore) NodeService(nodeID string, serviceID string) (uint64, *structs.NodeService, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("NodeService")...) + + // Query the service + service, err := tx.First("services", "id", nodeID, serviceID) + if err != nil { + return 0, nil, fmt.Errorf("failed querying service for node %q: %s", nodeID, err) + } + + if service != nil { + return idx, service.(*structs.ServiceNode).ToNodeService(), nil + } else { + return idx, nil, nil + } +} + +// NodeServices is used to query service registrations by node ID. +func (s *StateStore) NodeServices(nodeID string) (uint64, *structs.NodeServices, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("NodeServices")...) + + // Query the node + n, err := tx.First("nodes", "id", nodeID) + if err != nil { + return 0, nil, fmt.Errorf("node lookup failed: %s", err) + } + if n == nil { + return 0, nil, nil + } + node := n.(*structs.Node) + + // Read all of the services + services, err := tx.Get("services", "node", nodeID) + if err != nil { + return 0, nil, fmt.Errorf("failed querying services for node %q: %s", nodeID, err) + } + + // Initialize the node services struct + ns := &structs.NodeServices{ + Node: node, + Services: make(map[string]*structs.NodeService), + } + + // Add all of the services to the map. + for service := services.Next(); service != nil; service = services.Next() { + svc := service.(*structs.ServiceNode).ToNodeService() + ns.Services[svc.ID] = svc + } + + return idx, ns, nil +} + +// DeleteService is used to delete a given service associated with a node. +func (s *StateStore) DeleteService(idx uint64, nodeID, serviceID string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + // Call the service deletion + watches := NewDumbWatchManager(s.tableWatches) + if err := s.deleteServiceTxn(tx, idx, watches, nodeID, serviceID); err != nil { + return err + } + + tx.Defer(func() { watches.Notify() }) + tx.Commit() + return nil +} + +// deleteServiceTxn is the inner method called to remove a service +// registration within an existing transaction. +func (s *StateStore) deleteServiceTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, nodeID, serviceID string) error { + // Look up the service. + service, err := tx.First("services", "id", nodeID, serviceID) + if err != nil { + return fmt.Errorf("failed service lookup: %s", err) + } + if service == nil { + return nil + } + + // Delete any checks associated with the service. This will invalidate + // sessions as necessary. + checks, err := tx.Get("checks", "node_service", nodeID, serviceID) + if err != nil { + return fmt.Errorf("failed service check lookup: %s", err) + } + var cids []types.CheckID + for check := checks.Next(); check != nil; check = checks.Next() { + cids = append(cids, check.(*structs.HealthCheck).CheckID) + } + + // Do the delete in a separate loop so we don't trash the iterator. + for _, cid := range cids { + if err := s.deleteCheckTxn(tx, idx, watches, nodeID, cid); err != nil { + return err + } + } + + // Update the index. + if err := tx.Insert("index", &IndexEntry{"checks", idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + // Delete the service and update the index + if err := tx.Delete("services", service); err != nil { + return fmt.Errorf("failed deleting service: %s", err) + } + if err := tx.Insert("index", &IndexEntry{"services", idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + watches.Arm("services") + return nil +} + +// EnsureCheck is used to store a check registration in the db. +func (s *StateStore) EnsureCheck(idx uint64, hc *structs.HealthCheck) error { + tx := s.db.Txn(true) + defer tx.Abort() + + // Call the check registration + watches := NewDumbWatchManager(s.tableWatches) + if err := s.ensureCheckTxn(tx, idx, watches, hc); err != nil { + return err + } + + tx.Defer(func() { watches.Notify() }) + tx.Commit() + return nil +} + +// ensureCheckTransaction is used as the inner method to handle inserting +// a health check into the state store. It ensures safety against inserting +// checks with no matching node or service. +func (s *StateStore) ensureCheckTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, + hc *structs.HealthCheck) error { + // Check if we have an existing health check + existing, err := tx.First("checks", "id", hc.Node, string(hc.CheckID)) + if err != nil { + return fmt.Errorf("failed health check lookup: %s", err) + } + + // Set the indexes + if existing != nil { + hc.CreateIndex = existing.(*structs.HealthCheck).CreateIndex + hc.ModifyIndex = idx + } else { + hc.CreateIndex = idx + hc.ModifyIndex = idx + } + + // Use the default check status if none was provided + if hc.Status == "" { + hc.Status = structs.HealthCritical + } + + // Get the node + node, err := tx.First("nodes", "id", hc.Node) + if err != nil { + return fmt.Errorf("failed node lookup: %s", err) + } + if node == nil { + return ErrMissingNode + } + + // If the check is associated with a service, check that we have + // a registration for the service. + if hc.ServiceID != "" { + service, err := tx.First("services", "id", hc.Node, hc.ServiceID) + if err != nil { + return fmt.Errorf("failed service lookup: %s", err) + } + if service == nil { + return ErrMissingService + } + + // Copy in the service name + hc.ServiceName = service.(*structs.ServiceNode).ServiceName + } + + // Delete any sessions for this check if the health is critical. + if hc.Status == structs.HealthCritical { + mappings, err := tx.Get("session_checks", "node_check", hc.Node, string(hc.CheckID)) + if err != nil { + return fmt.Errorf("failed session checks lookup: %s", err) + } + + var ids []string + for mapping := mappings.Next(); mapping != nil; mapping = mappings.Next() { + ids = append(ids, mapping.(*sessionCheck).Session) + } + + // Delete the session in a separate loop so we don't trash the + // iterator. + watches := NewDumbWatchManager(s.tableWatches) + for _, id := range ids { + if err := s.deleteSessionTxn(tx, idx, watches, id); err != nil { + return fmt.Errorf("failed deleting session: %s", err) + } + } + tx.Defer(func() { watches.Notify() }) + } + + // Persist the check registration in the db. + if err := tx.Insert("checks", hc); err != nil { + return fmt.Errorf("failed inserting check: %s", err) + } + if err := tx.Insert("index", &IndexEntry{"checks", idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + watches.Arm("checks") + return nil +} + +// NodeCheck is used to retrieve a specific check associated with the given +// node. +func (s *StateStore) NodeCheck(nodeID string, checkID types.CheckID) (uint64, *structs.HealthCheck, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("NodeCheck")...) + + // Return the check. + check, err := tx.First("checks", "id", nodeID, string(checkID)) + if err != nil { + return 0, nil, fmt.Errorf("failed check lookup: %s", err) + } + if check != nil { + return idx, check.(*structs.HealthCheck), nil + } else { + return idx, nil, nil + } +} + +// NodeChecks is used to retrieve checks associated with the +// given node from the state store. +func (s *StateStore) NodeChecks(nodeID string) (uint64, structs.HealthChecks, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("NodeChecks")...) + + // Return the checks. + checks, err := tx.Get("checks", "node", nodeID) + if err != nil { + return 0, nil, fmt.Errorf("failed check lookup: %s", err) + } + return s.parseChecks(idx, checks) +} + +// ServiceChecks is used to get all checks associated with a +// given service ID. The query is performed against a service +// _name_ instead of a service ID. +func (s *StateStore) ServiceChecks(serviceName string) (uint64, structs.HealthChecks, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("ServiceChecks")...) + + // Return the checks. + checks, err := tx.Get("checks", "service", serviceName) + if err != nil { + return 0, nil, fmt.Errorf("failed check lookup: %s", err) + } + return s.parseChecks(idx, checks) +} + +// ChecksInState is used to query the state store for all checks +// which are in the provided state. +func (s *StateStore) ChecksInState(state string) (uint64, structs.HealthChecks, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("ChecksInState")...) + + // Query all checks if HealthAny is passed + if state == structs.HealthAny { + checks, err := tx.Get("checks", "status") + if err != nil { + return 0, nil, fmt.Errorf("failed check lookup: %s", err) + } + return s.parseChecks(idx, checks) + } + + // Any other state we need to query for explicitly + checks, err := tx.Get("checks", "status", state) + if err != nil { + return 0, nil, fmt.Errorf("failed check lookup: %s", err) + } + return s.parseChecks(idx, checks) +} + +// parseChecks is a helper function used to deduplicate some +// repetitive code for returning health checks. +func (s *StateStore) parseChecks(idx uint64, iter memdb.ResultIterator) (uint64, structs.HealthChecks, error) { + // Gather the health checks and return them properly type casted. + var results structs.HealthChecks + for check := iter.Next(); check != nil; check = iter.Next() { + results = append(results, check.(*structs.HealthCheck)) + } + return idx, results, nil +} + +// DeleteCheck is used to delete a health check registration. +func (s *StateStore) DeleteCheck(idx uint64, node string, checkID types.CheckID) error { + tx := s.db.Txn(true) + defer tx.Abort() + + // Call the check deletion + watches := NewDumbWatchManager(s.tableWatches) + if err := s.deleteCheckTxn(tx, idx, watches, node, checkID); err != nil { + return err + } + + tx.Defer(func() { watches.Notify() }) + tx.Commit() + return nil +} + +// deleteCheckTxn is the inner method used to call a health +// check deletion within an existing transaction. +func (s *StateStore) deleteCheckTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, node string, checkID types.CheckID) error { + // Try to retrieve the existing health check. + hc, err := tx.First("checks", "id", node, string(checkID)) + if err != nil { + return fmt.Errorf("check lookup failed: %s", err) + } + if hc == nil { + return nil + } + + // Delete the check from the DB and update the index. + if err := tx.Delete("checks", hc); err != nil { + return fmt.Errorf("failed removing check: %s", err) + } + if err := tx.Insert("index", &IndexEntry{"checks", idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + // Delete any sessions for this check. + mappings, err := tx.Get("session_checks", "node_check", node, string(checkID)) + if err != nil { + return fmt.Errorf("failed session checks lookup: %s", err) + } + var ids []string + for mapping := mappings.Next(); mapping != nil; mapping = mappings.Next() { + ids = append(ids, mapping.(*sessionCheck).Session) + } + + // Do the delete in a separate loop so we don't trash the iterator. + for _, id := range ids { + if err := s.deleteSessionTxn(tx, idx, watches, id); err != nil { + return fmt.Errorf("failed deleting session: %s", err) + } + } + + watches.Arm("checks") + return nil +} + +// CheckServiceNodes is used to query all nodes and checks for a given service +// The results are compounded into a CheckServiceNodes, and the index returned +// is the maximum index observed over any node, check, or service in the result +// set. +func (s *StateStore) CheckServiceNodes(serviceName string) (uint64, structs.CheckServiceNodes, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("CheckServiceNodes")...) + + // Query the state store for the service. + services, err := tx.Get("services", "service", serviceName) + if err != nil { + return 0, nil, fmt.Errorf("failed service lookup: %s", err) + } + + // Return the results. + var results structs.ServiceNodes + for service := services.Next(); service != nil; service = services.Next() { + results = append(results, service.(*structs.ServiceNode)) + } + return s.parseCheckServiceNodes(tx, idx, results, err) +} + +// CheckServiceTagNodes is used to query all nodes and checks for a given +// service, filtering out services that don't contain the given tag. The results +// are compounded into a CheckServiceNodes, and the index returned is the maximum +// index observed over any node, check, or service in the result set. +func (s *StateStore) CheckServiceTagNodes(serviceName, tag string) (uint64, structs.CheckServiceNodes, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("CheckServiceNodes")...) + + // Query the state store for the service. + services, err := tx.Get("services", "service", serviceName) + if err != nil { + return 0, nil, fmt.Errorf("failed service lookup: %s", err) + } + + // Return the results, filtering by tag. + var results structs.ServiceNodes + for service := services.Next(); service != nil; service = services.Next() { + svc := service.(*structs.ServiceNode) + if !serviceTagFilter(svc, tag) { + results = append(results, svc) + } + } + return s.parseCheckServiceNodes(tx, idx, results, err) +} + +// parseCheckServiceNodes is used to parse through a given set of services, +// and query for an associated node and a set of checks. This is the inner +// method used to return a rich set of results from a more simple query. +func (s *StateStore) parseCheckServiceNodes( + tx *memdb.Txn, idx uint64, services structs.ServiceNodes, + err error) (uint64, structs.CheckServiceNodes, error) { + if err != nil { + return 0, nil, err + } + + // Special-case the zero return value to nil, since this ends up in + // external APIs. + if len(services) == 0 { + return idx, nil, nil + } + + results := make(structs.CheckServiceNodes, 0, len(services)) + for _, sn := range services { + // Retrieve the node. + n, err := tx.First("nodes", "id", sn.Node) + if err != nil { + return 0, nil, fmt.Errorf("failed node lookup: %s", err) + } + if n == nil { + return 0, nil, ErrMissingNode + } + node := n.(*structs.Node) + + // We need to return the checks specific to the given service + // as well as the node itself. Unfortunately, memdb won't let + // us use the index to do the latter query so we have to pull + // them all and filter. + var checks structs.HealthChecks + iter, err := tx.Get("checks", "node", sn.Node) + if err != nil { + return 0, nil, err + } + for check := iter.Next(); check != nil; check = iter.Next() { + hc := check.(*structs.HealthCheck) + if hc.ServiceID == "" || hc.ServiceID == sn.ServiceID { + checks = append(checks, hc) + } + } + + // Append to the results. + results = append(results, structs.CheckServiceNode{ + Node: node, + Service: sn.ToNodeService(), + Checks: checks, + }) + } + + return idx, results, nil +} + +// NodeInfo is used to generate a dump of a single node. The dump includes +// all services and checks which are registered against the node. +func (s *StateStore) NodeInfo(node string) (uint64, structs.NodeDump, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("NodeInfo")...) + + // Query the node by the passed node + nodes, err := tx.Get("nodes", "id", node) + if err != nil { + return 0, nil, fmt.Errorf("failed node lookup: %s", err) + } + return s.parseNodes(tx, idx, nodes) +} + +// NodeDump is used to generate a dump of all nodes. This call is expensive +// as it has to query every node, service, and check. The response can also +// be quite large since there is currently no filtering applied. +func (s *StateStore) NodeDump() (uint64, structs.NodeDump, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("NodeDump")...) + + // Fetch all of the registered nodes + nodes, err := tx.Get("nodes", "id") + if err != nil { + return 0, nil, fmt.Errorf("failed node lookup: %s", err) + } + return s.parseNodes(tx, idx, nodes) +} + +// parseNodes takes an iterator over a set of nodes and returns a struct +// containing the nodes along with all of their associated services +// and/or health checks. +func (s *StateStore) parseNodes(tx *memdb.Txn, idx uint64, + iter memdb.ResultIterator) (uint64, structs.NodeDump, error) { + + var results structs.NodeDump + for n := iter.Next(); n != nil; n = iter.Next() { + node := n.(*structs.Node) + + // Create the wrapped node + dump := &structs.NodeInfo{ + Node: node.Node, + Address: node.Address, + TaggedAddresses: node.TaggedAddresses, + Meta: node.Meta, + } + + // Query the node services + services, err := tx.Get("services", "node", node.Node) + if err != nil { + return 0, nil, fmt.Errorf("failed services lookup: %s", err) + } + for service := services.Next(); service != nil; service = services.Next() { + ns := service.(*structs.ServiceNode).ToNodeService() + dump.Services = append(dump.Services, ns) + } + + // Query the node checks + checks, err := tx.Get("checks", "node", node.Node) + if err != nil { + return 0, nil, fmt.Errorf("failed node lookup: %s", err) + } + for check := checks.Next(); check != nil; check = checks.Next() { + hc := check.(*structs.HealthCheck) + dump.Checks = append(dump.Checks, hc) + } + + // Add the result to the slice + results = append(results, dump) + } + return idx, results, nil +} diff --git a/consul/state/catalog_test.go b/consul/state/catalog_test.go new file mode 100644 index 0000000000..43631e19a5 --- /dev/null +++ b/consul/state/catalog_test.go @@ -0,0 +1,2043 @@ +package state + +import ( + "fmt" + "reflect" + "sort" + "testing" + + "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/consul/lib" +) + +func TestStateStore_EnsureRegistration(t *testing.T) { + s := testStateStore(t) + + // Start with just a node. + req := &structs.RegisterRequest{ + Node: "node1", + Address: "1.2.3.4", + TaggedAddresses: map[string]string{ + "hello": "world", + }, + NodeMeta: map[string]string{ + "somekey": "somevalue", + }, + } + if err := s.EnsureRegistration(1, req); err != nil { + t.Fatalf("err: %s", err) + } + + // Retrieve the node and verify its contents. + verifyNode := func(created, modified uint64) { + _, out, err := s.GetNode("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if out.Node != "node1" || out.Address != "1.2.3.4" || + len(out.TaggedAddresses) != 1 || + out.TaggedAddresses["hello"] != "world" || + out.Meta["somekey"] != "somevalue" || + out.CreateIndex != created || out.ModifyIndex != modified { + t.Fatalf("bad node returned: %#v", out) + } + } + verifyNode(1, 1) + + // Add in a service definition. + req.Service = &structs.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + } + if err := s.EnsureRegistration(2, req); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify that the service got registered. + verifyService := func(created, modified uint64) { + idx, out, err := s.NodeServices("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != modified { + t.Fatalf("bad index: %d", idx) + } + if len(out.Services) != 1 { + t.Fatalf("bad: %#v", out.Services) + } + r := out.Services["redis1"] + if r == nil || r.ID != "redis1" || r.Service != "redis" || + r.Address != "1.1.1.1" || r.Port != 8080 || + r.CreateIndex != created || r.ModifyIndex != modified { + t.Fatalf("bad service returned: %#v", r) + } + + idx, r, err = s.NodeService("node1", "redis1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != modified { + t.Fatalf("bad index: %d", idx) + } + if r == nil || r.ID != "redis1" || r.Service != "redis" || + r.Address != "1.1.1.1" || r.Port != 8080 || + r.CreateIndex != created || r.ModifyIndex != modified { + t.Fatalf("bad service returned: %#v", r) + } + } + verifyNode(1, 2) + verifyService(2, 2) + + // Add in a top-level check. + req.Check = &structs.HealthCheck{ + Node: "node1", + CheckID: "check1", + Name: "check", + } + if err := s.EnsureRegistration(3, req); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify that the check got registered. + verifyCheck := func(created, modified uint64) { + idx, out, err := s.NodeChecks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != modified { + t.Fatalf("bad index: %d", idx) + } + if len(out) != 1 { + t.Fatalf("bad: %#v", out) + } + c := out[0] + if c.Node != "node1" || c.CheckID != "check1" || c.Name != "check" || + c.CreateIndex != created || c.ModifyIndex != modified { + t.Fatalf("bad check returned: %#v", c) + } + + idx, c, err = s.NodeCheck("node1", "check1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != modified { + t.Fatalf("bad index: %d", idx) + } + if c.Node != "node1" || c.CheckID != "check1" || c.Name != "check" || + c.CreateIndex != created || c.ModifyIndex != modified { + t.Fatalf("bad check returned: %#v", c) + } + } + verifyNode(1, 3) + verifyService(2, 3) + verifyCheck(3, 3) + + // Add in another check via the slice. + req.Checks = structs.HealthChecks{ + &structs.HealthCheck{ + Node: "node1", + CheckID: "check2", + Name: "check", + }, + } + if err := s.EnsureRegistration(4, req); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify that the additional check got registered. + verifyNode(1, 4) + verifyService(2, 4) + func() { + idx, out, err := s.NodeChecks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 4 { + t.Fatalf("bad index: %d", idx) + } + if len(out) != 2 { + t.Fatalf("bad: %#v", out) + } + c1 := out[0] + if c1.Node != "node1" || c1.CheckID != "check1" || c1.Name != "check" || + c1.CreateIndex != 3 || c1.ModifyIndex != 4 { + t.Fatalf("bad check returned: %#v", c1) + } + + c2 := out[1] + if c2.Node != "node1" || c2.CheckID != "check2" || c2.Name != "check" || + c2.CreateIndex != 4 || c2.ModifyIndex != 4 { + t.Fatalf("bad check returned: %#v", c2) + } + }() +} + +func TestStateStore_EnsureRegistration_Restore(t *testing.T) { + s := testStateStore(t) + + // Start with just a node. + req := &structs.RegisterRequest{ + Node: "node1", + Address: "1.2.3.4", + } + restore := s.Restore() + if err := restore.Registration(1, req); err != nil { + t.Fatalf("err: %s", err) + } + restore.Commit() + + // Retrieve the node and verify its contents. + verifyNode := func(created, modified uint64) { + _, out, err := s.GetNode("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if out.Node != "node1" || out.Address != "1.2.3.4" || + out.CreateIndex != created || out.ModifyIndex != modified { + t.Fatalf("bad node returned: %#v", out) + } + } + verifyNode(1, 1) + + // Add in a service definition. + req.Service = &structs.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + } + restore = s.Restore() + if err := restore.Registration(2, req); err != nil { + t.Fatalf("err: %s", err) + } + restore.Commit() + + // Verify that the service got registered. + verifyService := func(created, modified uint64) { + idx, out, err := s.NodeServices("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != modified { + t.Fatalf("bad index: %d", idx) + } + if len(out.Services) != 1 { + t.Fatalf("bad: %#v", out.Services) + } + s := out.Services["redis1"] + if s.ID != "redis1" || s.Service != "redis" || + s.Address != "1.1.1.1" || s.Port != 8080 || + s.CreateIndex != created || s.ModifyIndex != modified { + t.Fatalf("bad service returned: %#v", s) + } + } + verifyNode(1, 2) + verifyService(2, 2) + + // Add in a top-level check. + req.Check = &structs.HealthCheck{ + Node: "node1", + CheckID: "check1", + Name: "check", + } + restore = s.Restore() + if err := restore.Registration(3, req); err != nil { + t.Fatalf("err: %s", err) + } + restore.Commit() + + // Verify that the check got registered. + verifyCheck := func(created, modified uint64) { + idx, out, err := s.NodeChecks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != modified { + t.Fatalf("bad index: %d", idx) + } + if len(out) != 1 { + t.Fatalf("bad: %#v", out) + } + c := out[0] + if c.Node != "node1" || c.CheckID != "check1" || c.Name != "check" || + c.CreateIndex != created || c.ModifyIndex != modified { + t.Fatalf("bad check returned: %#v", c) + } + } + verifyNode(1, 3) + verifyService(2, 3) + verifyCheck(3, 3) + + // Add in another check via the slice. + req.Checks = structs.HealthChecks{ + &structs.HealthCheck{ + Node: "node1", + CheckID: "check2", + Name: "check", + }, + } + restore = s.Restore() + if err := restore.Registration(4, req); err != nil { + t.Fatalf("err: %s", err) + } + restore.Commit() + + // Verify that the additional check got registered. + verifyNode(1, 4) + verifyService(2, 4) + func() { + idx, out, err := s.NodeChecks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 4 { + t.Fatalf("bad index: %d", idx) + } + if len(out) != 2 { + t.Fatalf("bad: %#v", out) + } + c1 := out[0] + if c1.Node != "node1" || c1.CheckID != "check1" || c1.Name != "check" || + c1.CreateIndex != 3 || c1.ModifyIndex != 4 { + t.Fatalf("bad check returned: %#v", c1) + } + + c2 := out[1] + if c2.Node != "node1" || c2.CheckID != "check2" || c2.Name != "check" || + c2.CreateIndex != 4 || c2.ModifyIndex != 4 { + t.Fatalf("bad check returned: %#v", c2) + } + }() +} + +func TestStateStore_EnsureRegistration_Watches(t *testing.T) { + s := testStateStore(t) + + req := &structs.RegisterRequest{ + Node: "node1", + Address: "1.2.3.4", + } + + // The nodes watch should fire for this one. + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyNoWatch(t, s.getTableWatch("services"), func() { + verifyNoWatch(t, s.getTableWatch("checks"), func() { + if err := s.EnsureRegistration(1, req); err != nil { + t.Fatalf("err: %s", err) + } + }) + }) + }) + // The nodes watch should fire for this one. + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyNoWatch(t, s.getTableWatch("services"), func() { + verifyNoWatch(t, s.getTableWatch("checks"), func() { + restore := s.Restore() + if err := restore.Registration(1, req); err != nil { + t.Fatalf("err: %s", err) + } + restore.Commit() + }) + }) + }) + + // With a service definition added it should fire nodes and + // services. + req.Service = &structs.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + } + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyWatch(t, s.getTableWatch("services"), func() { + verifyNoWatch(t, s.getTableWatch("checks"), func() { + if err := s.EnsureRegistration(2, req); err != nil { + t.Fatalf("err: %s", err) + } + }) + }) + }) + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyWatch(t, s.getTableWatch("services"), func() { + verifyNoWatch(t, s.getTableWatch("checks"), func() { + restore := s.Restore() + if err := restore.Registration(2, req); err != nil { + t.Fatalf("err: %s", err) + } + restore.Commit() + }) + }) + }) + + // Now with a check it should hit all three. + req.Check = &structs.HealthCheck{ + Node: "node1", + CheckID: "check1", + Name: "check", + } + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyWatch(t, s.getTableWatch("services"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { + if err := s.EnsureRegistration(3, req); err != nil { + t.Fatalf("err: %s", err) + } + }) + }) + }) + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyWatch(t, s.getTableWatch("services"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { + restore := s.Restore() + if err := restore.Registration(3, req); err != nil { + t.Fatalf("err: %s", err) + } + restore.Commit() + }) + }) + }) +} + +func TestStateStore_EnsureNode(t *testing.T) { + s := testStateStore(t) + + // Fetching a non-existent node returns nil + if _, node, err := s.GetNode("node1"); node != nil || err != nil { + t.Fatalf("expected (nil, nil), got: (%#v, %#v)", node, err) + } + + // Create a node registration request + in := &structs.Node{ + Node: "node1", + Address: "1.1.1.1", + } + + // Ensure the node is registered in the db + if err := s.EnsureNode(1, in); err != nil { + t.Fatalf("err: %s", err) + } + + // Retrieve the node again + idx, out, err := s.GetNode("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + + // Correct node was returned + if out.Node != "node1" || out.Address != "1.1.1.1" { + t.Fatalf("bad node returned: %#v", out) + } + + // Indexes are set properly + if out.CreateIndex != 1 || out.ModifyIndex != 1 { + t.Fatalf("bad node index: %#v", out) + } + if idx != 1 { + t.Fatalf("bad index: %d", idx) + } + + // Update the node registration + in.Address = "1.1.1.2" + if err := s.EnsureNode(2, in); err != nil { + t.Fatalf("err: %s", err) + } + + // Retrieve the node + idx, out, err = s.GetNode("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + + // Node and indexes were updated + if out.CreateIndex != 1 || out.ModifyIndex != 2 || out.Address != "1.1.1.2" { + t.Fatalf("bad: %#v", out) + } + if idx != 2 { + t.Fatalf("bad index: %d", idx) + } + + // Node upsert preserves the create index + if err := s.EnsureNode(3, in); err != nil { + t.Fatalf("err: %s", err) + } + idx, out, err = s.GetNode("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if out.CreateIndex != 1 || out.ModifyIndex != 3 || out.Address != "1.1.1.2" { + t.Fatalf("node was modified: %#v", out) + } + if idx != 3 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_GetNodes(t *testing.T) { + s := testStateStore(t) + + // Listing with no results returns nil + idx, res, err := s.Nodes() + if idx != 0 || res != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) + } + + // Create some nodes in the state store + testRegisterNode(t, s, 0, "node0") + testRegisterNode(t, s, 1, "node1") + testRegisterNode(t, s, 2, "node2") + + // Retrieve the nodes + idx, nodes, err := s.Nodes() + if err != nil { + t.Fatalf("err: %s", err) + } + + // Highest index was returned + if idx != 2 { + t.Fatalf("bad index: %d", idx) + } + + // All nodes were returned + if n := len(nodes); n != 3 { + t.Fatalf("bad node count: %d", n) + } + + // Make sure the nodes match + for i, node := range nodes { + if node.CreateIndex != uint64(i) || node.ModifyIndex != uint64(i) { + t.Fatalf("bad node index: %d, %d", node.CreateIndex, node.ModifyIndex) + } + name := fmt.Sprintf("node%d", i) + if node.Node != name { + t.Fatalf("bad: %#v", node) + } + } +} + +func BenchmarkGetNodes(b *testing.B) { + s, err := NewStateStore(nil) + if err != nil { + b.Fatalf("err: %s", err) + } + + if err := s.EnsureNode(100, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + b.Fatalf("err: %v", err) + } + if err := s.EnsureNode(101, &structs.Node{Node: "bar", Address: "127.0.0.2"}); err != nil { + b.Fatalf("err: %v", err) + } + + for i := 0; i < b.N; i++ { + s.Nodes() + } +} + +func TestStateStore_GetNodesByMeta(t *testing.T) { + s := testStateStore(t) + + // Listing with no results returns nil + idx, res, err := s.NodesByMeta(map[string]string{"somekey": "somevalue"}) + if idx != 0 || res != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) + } + + // Create some nodes in the state store + node0 := &structs.Node{Node: "node0", Address: "127.0.0.1", Meta: map[string]string{"role": "client", "common": "1"}} + if err := s.EnsureNode(0, node0); err != nil { + t.Fatalf("err: %v", err) + } + node1 := &structs.Node{Node: "node1", Address: "127.0.0.1", Meta: map[string]string{"role": "server", "common": "1"}} + if err := s.EnsureNode(1, node1); err != nil { + t.Fatalf("err: %v", err) + } + + // Retrieve the node with role=client + idx, nodes, err := s.NodesByMeta(map[string]string{"role": "client"}) + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 1 { + t.Fatalf("bad index: %d", idx) + } + + // Only one node was returned + if n := len(nodes); n != 1 { + t.Fatalf("bad node count: %d", n) + } + + // Make sure the node is correct + if nodes[0].CreateIndex != 0 || nodes[0].ModifyIndex != 0 { + t.Fatalf("bad node index: %d, %d", nodes[0].CreateIndex, nodes[0].ModifyIndex) + } + if nodes[0].Node != "node0" { + t.Fatalf("bad: %#v", nodes[0]) + } + if !reflect.DeepEqual(nodes[0].Meta, node0.Meta) { + t.Fatalf("bad: %v != %v", nodes[0].Meta, node0.Meta) + } + + // Retrieve both nodes via their common meta field + idx, nodes, err = s.NodesByMeta(map[string]string{"common": "1"}) + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 1 { + t.Fatalf("bad index: %d", idx) + } + + // All nodes were returned + if n := len(nodes); n != 2 { + t.Fatalf("bad node count: %d", n) + } + + // Make sure the nodes match + for i, node := range nodes { + if node.CreateIndex != uint64(i) || node.ModifyIndex != uint64(i) { + t.Fatalf("bad node index: %d, %d", node.CreateIndex, node.ModifyIndex) + } + name := fmt.Sprintf("node%d", i) + if node.Node != name { + t.Fatalf("bad: %#v", node) + } + if v, ok := node.Meta["common"]; !ok || v != "1" { + t.Fatalf("bad: %v", node.Meta) + } + } +} + +func BenchmarkGetNodesByMeta(b *testing.B) { + s, err := NewStateStore(nil) + if err != nil { + b.Fatalf("err: %s", err) + } + + if err := s.EnsureNode(100, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + b.Fatalf("err: %v", err) + } + if err := s.EnsureNode(101, &structs.Node{Node: "bar", Address: "127.0.0.2"}); err != nil { + b.Fatalf("err: %v", err) + } + + for i := 0; i < b.N; i++ { + s.Nodes() + } +} + +func TestStateStore_DeleteNode(t *testing.T) { + s := testStateStore(t) + + // Create a node and register a service and health check with it. + testRegisterNode(t, s, 0, "node1") + testRegisterService(t, s, 1, "node1", "service1") + testRegisterCheck(t, s, 2, "node1", "", "check1", structs.HealthPassing) + + // Delete the node + if err := s.DeleteNode(3, "node1"); err != nil { + t.Fatalf("err: %s", err) + } + + // The node was removed + if idx, n, err := s.GetNode("node1"); err != nil || n != nil || idx != 3 { + t.Fatalf("bad: %#v %d (err: %#v)", n, idx, err) + } + + // Associated service was removed. Need to query this directly out of + // the DB to make sure it is actually gone. + tx := s.db.Txn(false) + defer tx.Abort() + services, err := tx.Get("services", "id", "node1", "service1") + if err != nil { + t.Fatalf("err: %s", err) + } + if service := services.Next(); service != nil { + t.Fatalf("bad: %#v", service) + } + + // Associated health check was removed. + checks, err := tx.Get("checks", "id", "node1", "check1") + if err != nil { + t.Fatalf("err: %s", err) + } + if check := checks.Next(); check != nil { + t.Fatalf("bad: %#v", check) + } + + // Indexes were updated. + for _, tbl := range []string{"nodes", "services", "checks"} { + if idx := s.maxIndex(tbl); idx != 3 { + t.Fatalf("bad index: %d (%s)", idx, tbl) + } + } + + // Deleting a nonexistent node should be idempotent and not return + // an error + if err := s.DeleteNode(4, "node1"); err != nil { + t.Fatalf("err: %s", err) + } + if idx := s.maxIndex("nodes"); idx != 3 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_Node_Snapshot(t *testing.T) { + s := testStateStore(t) + + // Create some nodes in the state store. + testRegisterNode(t, s, 0, "node0") + testRegisterNode(t, s, 1, "node1") + testRegisterNode(t, s, 2, "node2") + + // Snapshot the nodes. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. + testRegisterNode(t, s, 3, "node3") + + // Verify the snapshot. + if idx := snap.LastIndex(); idx != 2 { + t.Fatalf("bad index: %d", idx) + } + nodes, err := snap.Nodes() + if err != nil { + t.Fatalf("err: %s", err) + } + for i := 0; i < 3; i++ { + node := nodes.Next().(*structs.Node) + if node == nil { + t.Fatalf("unexpected end of nodes") + } + + if node.CreateIndex != uint64(i) || node.ModifyIndex != uint64(i) { + t.Fatalf("bad node index: %d, %d", node.CreateIndex, node.ModifyIndex) + } + if node.Node != fmt.Sprintf("node%d", i) { + t.Fatalf("bad: %#v", node) + } + } + if nodes.Next() != nil { + t.Fatalf("unexpected extra nodes") + } +} + +func TestStateStore_Node_Watches(t *testing.T) { + s := testStateStore(t) + + // Call functions that update the nodes table and make sure a watch fires + // each time. + verifyWatch(t, s.getTableWatch("nodes"), func() { + req := &structs.RegisterRequest{ + Node: "node1", + } + if err := s.EnsureRegistration(1, req); err != nil { + t.Fatalf("err: %s", err) + } + }) + verifyWatch(t, s.getTableWatch("nodes"), func() { + node := &structs.Node{Node: "node2"} + if err := s.EnsureNode(2, node); err != nil { + t.Fatalf("err: %s", err) + } + }) + verifyWatch(t, s.getTableWatch("nodes"), func() { + if err := s.DeleteNode(3, "node2"); err != nil { + t.Fatalf("err: %s", err) + } + }) + + // Check that a delete of a node + service + check + coordinate triggers + // all tables in one shot. + testRegisterNode(t, s, 4, "node1") + testRegisterService(t, s, 5, "node1", "service1") + testRegisterCheck(t, s, 6, "node1", "service1", "check3", structs.HealthPassing) + updates := structs.Coordinates{ + &structs.Coordinate{ + Node: "node1", + Coord: generateRandomCoordinate(), + }, + } + if err := s.CoordinateBatchUpdate(7, updates); err != nil { + t.Fatalf("err: %s", err) + } + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyWatch(t, s.getTableWatch("services"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { + verifyWatch(t, s.getTableWatch("coordinates"), func() { + if err := s.DeleteNode(7, "node1"); err != nil { + t.Fatalf("err: %s", err) + } + }) + }) + }) + }) +} + +func TestStateStore_EnsureService(t *testing.T) { + s := testStateStore(t) + + // Fetching services for a node with none returns nil + idx, res, err := s.NodeServices("node1") + if err != nil || res != nil || idx != 0 { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) + } + + // Create the service registration + ns1 := &structs.NodeService{ + ID: "service1", + Service: "redis", + Tags: []string{"prod"}, + Address: "1.1.1.1", + Port: 1111, + } + + // Creating a service without a node returns an error + if err := s.EnsureService(1, "node1", ns1); err != ErrMissingNode { + t.Fatalf("expected %#v, got: %#v", ErrMissingNode, err) + } + + // Register the nodes + testRegisterNode(t, s, 0, "node1") + testRegisterNode(t, s, 1, "node2") + + // Service successfully registers into the state store + if err = s.EnsureService(10, "node1", ns1); err != nil { + t.Fatalf("err: %s", err) + } + + // Register a similar service against both nodes + ns2 := *ns1 + ns2.ID = "service2" + for _, n := range []string{"node1", "node2"} { + if err := s.EnsureService(20, n, &ns2); err != nil { + t.Fatalf("err: %s", err) + } + } + + // Register a different service on the bad node + ns3 := *ns1 + ns3.ID = "service3" + if err := s.EnsureService(30, "node2", &ns3); err != nil { + t.Fatalf("err: %s", err) + } + + // Retrieve the services + idx, out, err := s.NodeServices("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 30 { + t.Fatalf("bad index: %d", idx) + } + + // Only the services for the requested node are returned + if out == nil || len(out.Services) != 2 { + t.Fatalf("bad services: %#v", out) + } + + // Results match the inserted services and have the proper indexes set + expect1 := *ns1 + expect1.CreateIndex, expect1.ModifyIndex = 10, 10 + if svc := out.Services["service1"]; !reflect.DeepEqual(&expect1, svc) { + t.Fatalf("bad: %#v", svc) + } + + expect2 := ns2 + expect2.CreateIndex, expect2.ModifyIndex = 20, 20 + if svc := out.Services["service2"]; !reflect.DeepEqual(&expect2, svc) { + t.Fatalf("bad: %#v %#v", ns2, svc) + } + + // Index tables were updated + if idx := s.maxIndex("services"); idx != 30 { + t.Fatalf("bad index: %d", idx) + } + + // Update a service registration + ns1.Address = "1.1.1.2" + if err := s.EnsureService(40, "node1", ns1); err != nil { + t.Fatalf("err: %s", err) + } + + // Retrieve the service again and ensure it matches + idx, out, err = s.NodeServices("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 40 { + t.Fatalf("bad index: %d", idx) + } + if out == nil || len(out.Services) != 2 { + t.Fatalf("bad: %#v", out) + } + expect1.Address = "1.1.1.2" + expect1.ModifyIndex = 40 + if svc := out.Services["service1"]; !reflect.DeepEqual(&expect1, svc) { + t.Fatalf("bad: %#v", svc) + } + + // Index tables were updated + if idx := s.maxIndex("services"); idx != 40 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_Services(t *testing.T) { + s := testStateStore(t) + + // Register several nodes and services. + testRegisterNode(t, s, 1, "node1") + ns1 := &structs.NodeService{ + ID: "service1", + Service: "redis", + Tags: []string{"prod", "master"}, + Address: "1.1.1.1", + Port: 1111, + } + if err := s.EnsureService(2, "node1", ns1); err != nil { + t.Fatalf("err: %s", err) + } + testRegisterService(t, s, 3, "node1", "dogs") + testRegisterNode(t, s, 4, "node2") + ns2 := &structs.NodeService{ + ID: "service3", + Service: "redis", + Tags: []string{"prod", "slave"}, + Address: "1.1.1.1", + Port: 1111, + } + if err := s.EnsureService(5, "node2", ns2); err != nil { + t.Fatalf("err: %s", err) + } + + // Pull all the services. + idx, services, err := s.Services() + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 5 { + t.Fatalf("bad index: %d", idx) + } + + // Verify the result. We sort the lists since the order is + // non-deterministic (it's built using a map internally). + expected := structs.Services{ + "redis": []string{"prod", "master", "slave"}, + "dogs": []string{}, + } + sort.Strings(expected["redis"]) + for _, tags := range services { + sort.Strings(tags) + } + if !reflect.DeepEqual(expected, services) { + t.Fatalf("bad: %#v", services) + } +} + +func TestStateStore_ServicesByNodeMeta(t *testing.T) { + s := testStateStore(t) + + // Listing with no results returns nil + idx, res, err := s.ServicesByNodeMeta(map[string]string{"somekey": "somevalue"}) + if idx != 0 || len(res) != 0 || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) + } + + // Create some nodes and services in the state store + node0 := &structs.Node{Node: "node0", Address: "127.0.0.1", Meta: map[string]string{"role": "client", "common": "1"}} + if err := s.EnsureNode(0, node0); err != nil { + t.Fatalf("err: %v", err) + } + node1 := &structs.Node{Node: "node1", Address: "127.0.0.1", Meta: map[string]string{"role": "server", "common": "1"}} + if err := s.EnsureNode(1, node1); err != nil { + t.Fatalf("err: %v", err) + } + ns1 := &structs.NodeService{ + ID: "service1", + Service: "redis", + Tags: []string{"prod", "master"}, + Address: "1.1.1.1", + Port: 1111, + } + if err := s.EnsureService(2, "node0", ns1); err != nil { + t.Fatalf("err: %s", err) + } + ns2 := &structs.NodeService{ + ID: "service1", + Service: "redis", + Tags: []string{"prod", "slave"}, + Address: "1.1.1.1", + Port: 1111, + } + if err := s.EnsureService(3, "node1", ns2); err != nil { + t.Fatalf("err: %s", err) + } + + // Filter the services by the first node's meta value + idx, res, err = s.ServicesByNodeMeta(map[string]string{"role": "client"}) + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 3 { + t.Fatalf("bad index: %d", idx) + } + expected := structs.Services{ + "redis": []string{"master", "prod"}, + } + sort.Strings(res["redis"]) + if !reflect.DeepEqual(res, expected) { + t.Fatalf("bad: %v %v", res, expected) + } + + // Get all services using the common meta value + idx, res, err = s.ServicesByNodeMeta(map[string]string{"common": "1"}) + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 3 { + t.Fatalf("bad index: %d", idx) + } + expected = structs.Services{ + "redis": []string{"master", "prod", "slave"}, + } + sort.Strings(res["redis"]) + if !reflect.DeepEqual(res, expected) { + t.Fatalf("bad: %v %v", res, expected) + } +} + +func TestStateStore_ServiceNodes(t *testing.T) { + s := testStateStore(t) + + if err := s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(12, "foo", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(14, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"master"}, Address: "", Port: 8000}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(15, "bar", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8000}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(16, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8001}); err != nil { + t.Fatalf("err: %v", err) + } + + idx, nodes, err := s.ServiceNodes("db") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 16 { + t.Fatalf("bad: %v", 16) + } + if len(nodes) != 3 { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].Node != "bar" { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].Address != "127.0.0.2" { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].ServiceID != "db" { + t.Fatalf("bad: %v", nodes) + } + if !lib.StrContains(nodes[0].ServiceTags, "slave") { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].ServicePort != 8000 { + t.Fatalf("bad: %v", nodes) + } + + if nodes[1].Node != "bar" { + t.Fatalf("bad: %v", nodes) + } + if nodes[1].Address != "127.0.0.2" { + t.Fatalf("bad: %v", nodes) + } + if nodes[1].ServiceID != "db2" { + t.Fatalf("bad: %v", nodes) + } + if !lib.StrContains(nodes[1].ServiceTags, "slave") { + t.Fatalf("bad: %v", nodes) + } + if nodes[1].ServicePort != 8001 { + t.Fatalf("bad: %v", nodes) + } + + if nodes[2].Node != "foo" { + t.Fatalf("bad: %v", nodes) + } + if nodes[2].Address != "127.0.0.1" { + t.Fatalf("bad: %v", nodes) + } + if nodes[2].ServiceID != "db" { + t.Fatalf("bad: %v", nodes) + } + if !lib.StrContains(nodes[2].ServiceTags, "master") { + t.Fatalf("bad: %v", nodes) + } + if nodes[2].ServicePort != 8000 { + t.Fatalf("bad: %v", nodes) + } +} + +func TestStateStore_ServiceTagNodes(t *testing.T) { + s := testStateStore(t) + + if err := s.EnsureNode(15, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureNode(16, &structs.Node{Node: "bar", Address: "127.0.0.2"}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(17, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"master"}, Address: "", Port: 8000}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(18, "foo", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8001}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(19, "bar", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8000}); err != nil { + t.Fatalf("err: %v", err) + } + + idx, nodes, err := s.ServiceTagNodes("db", "master") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 19 { + t.Fatalf("bad: %v", idx) + } + if len(nodes) != 1 { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].Node != "foo" { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].Address != "127.0.0.1" { + t.Fatalf("bad: %v", nodes) + } + if !lib.StrContains(nodes[0].ServiceTags, "master") { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].ServicePort != 8000 { + t.Fatalf("bad: %v", nodes) + } +} + +func TestStateStore_ServiceTagNodes_MultipleTags(t *testing.T) { + s := testStateStore(t) + + if err := s.EnsureNode(15, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureNode(16, &structs.Node{Node: "bar", Address: "127.0.0.2"}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(17, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"master", "v2"}, Address: "", Port: 8000}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(18, "foo", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave", "v2", "dev"}, Address: "", Port: 8001}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(19, "bar", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"slave", "v2"}, Address: "", Port: 8000}); err != nil { + t.Fatalf("err: %v", err) + } + + idx, nodes, err := s.ServiceTagNodes("db", "master") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 19 { + t.Fatalf("bad: %v", idx) + } + if len(nodes) != 1 { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].Node != "foo" { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].Address != "127.0.0.1" { + t.Fatalf("bad: %v", nodes) + } + if !lib.StrContains(nodes[0].ServiceTags, "master") { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].ServicePort != 8000 { + t.Fatalf("bad: %v", nodes) + } + + idx, nodes, err = s.ServiceTagNodes("db", "v2") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 19 { + t.Fatalf("bad: %v", idx) + } + if len(nodes) != 3 { + t.Fatalf("bad: %v", nodes) + } + + idx, nodes, err = s.ServiceTagNodes("db", "dev") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 19 { + t.Fatalf("bad: %v", idx) + } + if len(nodes) != 1 { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].Node != "foo" { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].Address != "127.0.0.1" { + t.Fatalf("bad: %v", nodes) + } + if !lib.StrContains(nodes[0].ServiceTags, "dev") { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].ServicePort != 8001 { + t.Fatalf("bad: %v", nodes) + } +} + +func TestStateStore_DeleteService(t *testing.T) { + s := testStateStore(t) + + // Register a node with one service and a check + testRegisterNode(t, s, 1, "node1") + testRegisterService(t, s, 2, "node1", "service1") + testRegisterCheck(t, s, 3, "node1", "service1", "check1", structs.HealthPassing) + + // Delete the service + if err := s.DeleteService(4, "node1", "service1"); err != nil { + t.Fatalf("err: %s", err) + } + + // Service doesn't exist. + _, ns, err := s.NodeServices("node1") + if err != nil || ns == nil || len(ns.Services) != 0 { + t.Fatalf("bad: %#v (err: %#v)", ns, err) + } + + // Check doesn't exist. Check using the raw DB so we can test + // that it actually is removed in the state store. + tx := s.db.Txn(false) + defer tx.Abort() + check, err := tx.First("checks", "id", "node1", "check1") + if err != nil || check != nil { + t.Fatalf("bad: %#v (err: %s)", check, err) + } + + // Index tables were updated + if idx := s.maxIndex("services"); idx != 4 { + t.Fatalf("bad index: %d", idx) + } + if idx := s.maxIndex("checks"); idx != 4 { + t.Fatalf("bad index: %d", idx) + } + + // Deleting a nonexistent service should be idempotent and not return an + // error + if err := s.DeleteService(5, "node1", "service1"); err != nil { + t.Fatalf("err: %s", err) + } + if idx := s.maxIndex("services"); idx != 4 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_Service_Snapshot(t *testing.T) { + s := testStateStore(t) + + // Register a node with two services. + testRegisterNode(t, s, 0, "node1") + ns := []*structs.NodeService{ + &structs.NodeService{ + ID: "service1", + Service: "redis", + Tags: []string{"prod"}, + Address: "1.1.1.1", + Port: 1111, + }, + &structs.NodeService{ + ID: "service2", + Service: "nomad", + Tags: []string{"dev"}, + Address: "1.1.1.2", + Port: 1112, + }, + } + for i, svc := range ns { + if err := s.EnsureService(uint64(i+1), "node1", svc); err != nil { + t.Fatalf("err: %s", err) + } + } + + // Create a second node/service to make sure node filtering works. This + // will affect the index but not the dump. + testRegisterNode(t, s, 3, "node2") + testRegisterService(t, s, 4, "node2", "service2") + + // Snapshot the service. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. + testRegisterService(t, s, 5, "node2", "service3") + + // Verify the snapshot. + if idx := snap.LastIndex(); idx != 4 { + t.Fatalf("bad index: %d", idx) + } + services, err := snap.Services("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + for i := 0; i < len(ns); i++ { + svc := services.Next().(*structs.ServiceNode) + if svc == nil { + t.Fatalf("unexpected end of services") + } + + ns[i].CreateIndex, ns[i].ModifyIndex = uint64(i+1), uint64(i+1) + if !reflect.DeepEqual(ns[i], svc.ToNodeService()) { + t.Fatalf("bad: %#v != %#v", svc, ns[i]) + } + } + if services.Next() != nil { + t.Fatalf("unexpected extra services") + } +} + +func TestStateStore_Service_Watches(t *testing.T) { + s := testStateStore(t) + + testRegisterNode(t, s, 0, "node1") + ns := &structs.NodeService{ + ID: "service2", + Service: "nomad", + Address: "1.1.1.2", + Port: 8000, + } + + // Call functions that update the services table and make sure a watch + // fires each time. + verifyWatch(t, s.getTableWatch("services"), func() { + if err := s.EnsureService(2, "node1", ns); err != nil { + t.Fatalf("err: %s", err) + } + }) + verifyWatch(t, s.getTableWatch("services"), func() { + if err := s.DeleteService(3, "node1", "service2"); err != nil { + t.Fatalf("err: %s", err) + } + }) + + // Check that a delete of a service + check triggers both tables in one + // shot. + testRegisterService(t, s, 4, "node1", "service1") + testRegisterCheck(t, s, 5, "node1", "service1", "check3", structs.HealthPassing) + verifyWatch(t, s.getTableWatch("services"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { + if err := s.DeleteService(6, "node1", "service1"); err != nil { + t.Fatalf("err: %s", err) + } + }) + }) +} + +func TestStateStore_EnsureCheck(t *testing.T) { + s := testStateStore(t) + + // Create a check associated with the node + check := &structs.HealthCheck{ + Node: "node1", + CheckID: "check1", + Name: "redis check", + Status: structs.HealthPassing, + Notes: "test check", + Output: "aaa", + ServiceID: "service1", + ServiceName: "redis", + } + + // Creating a check without a node returns error + if err := s.EnsureCheck(1, check); err != ErrMissingNode { + t.Fatalf("expected %#v, got: %#v", ErrMissingNode, err) + } + + // Register the node + testRegisterNode(t, s, 1, "node1") + + // Creating a check with a bad services returns error + if err := s.EnsureCheck(1, check); err != ErrMissingService { + t.Fatalf("expected: %#v, got: %#v", ErrMissingService, err) + } + + // Register the service + testRegisterService(t, s, 2, "node1", "service1") + + // Inserting the check with the prerequisites succeeds + if err := s.EnsureCheck(3, check); err != nil { + t.Fatalf("err: %s", err) + } + + // Retrieve the check and make sure it matches + idx, checks, err := s.NodeChecks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 3 { + t.Fatalf("bad index: %d", idx) + } + if len(checks) != 1 { + t.Fatalf("wrong number of checks: %d", len(checks)) + } + if !reflect.DeepEqual(checks[0], check) { + t.Fatalf("bad: %#v", checks[0]) + } + + // Modify the health check + check.Output = "bbb" + if err := s.EnsureCheck(4, check); err != nil { + t.Fatalf("err: %s", err) + } + + // Check that we successfully updated + idx, checks, err = s.NodeChecks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 4 { + t.Fatalf("bad index: %d", idx) + } + if len(checks) != 1 { + t.Fatalf("wrong number of checks: %d", len(checks)) + } + if checks[0].Output != "bbb" { + t.Fatalf("wrong check output: %#v", checks[0]) + } + if checks[0].CreateIndex != 3 || checks[0].ModifyIndex != 4 { + t.Fatalf("bad index: %#v", checks[0]) + } + + // Index tables were updated + if idx := s.maxIndex("checks"); idx != 4 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_EnsureCheck_defaultStatus(t *testing.T) { + s := testStateStore(t) + + // Register a node + testRegisterNode(t, s, 1, "node1") + + // Create and register a check with no health status + check := &structs.HealthCheck{ + Node: "node1", + CheckID: "check1", + Status: "", + } + if err := s.EnsureCheck(2, check); err != nil { + t.Fatalf("err: %s", err) + } + + // Get the check again + _, result, err := s.NodeChecks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + + // Check that the status was set to the proper default + if len(result) != 1 || result[0].Status != structs.HealthCritical { + t.Fatalf("bad: %#v", result) + } +} + +func TestStateStore_NodeChecks(t *testing.T) { + s := testStateStore(t) + + // Create the first node and service with some checks + testRegisterNode(t, s, 0, "node1") + testRegisterService(t, s, 1, "node1", "service1") + testRegisterCheck(t, s, 2, "node1", "service1", "check1", structs.HealthPassing) + testRegisterCheck(t, s, 3, "node1", "service1", "check2", structs.HealthPassing) + + // Create a second node/service with a different set of checks + testRegisterNode(t, s, 4, "node2") + testRegisterService(t, s, 5, "node2", "service2") + testRegisterCheck(t, s, 6, "node2", "service2", "check3", structs.HealthPassing) + + // Try querying for all checks associated with node1 + idx, checks, err := s.NodeChecks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + if len(checks) != 2 || checks[0].CheckID != "check1" || checks[1].CheckID != "check2" { + t.Fatalf("bad checks: %#v", checks) + } + + // Try querying for all checks associated with node2 + idx, checks, err = s.NodeChecks("node2") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + if len(checks) != 1 || checks[0].CheckID != "check3" { + t.Fatalf("bad checks: %#v", checks) + } +} + +func TestStateStore_ServiceChecks(t *testing.T) { + s := testStateStore(t) + + // Create the first node and service with some checks + testRegisterNode(t, s, 0, "node1") + testRegisterService(t, s, 1, "node1", "service1") + testRegisterCheck(t, s, 2, "node1", "service1", "check1", structs.HealthPassing) + testRegisterCheck(t, s, 3, "node1", "service1", "check2", structs.HealthPassing) + + // Create a second node/service with a different set of checks + testRegisterNode(t, s, 4, "node2") + testRegisterService(t, s, 5, "node2", "service2") + testRegisterCheck(t, s, 6, "node2", "service2", "check3", structs.HealthPassing) + + // Try querying for all checks associated with service1 + idx, checks, err := s.ServiceChecks("service1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + if len(checks) != 2 || checks[0].CheckID != "check1" || checks[1].CheckID != "check2" { + t.Fatalf("bad checks: %#v", checks) + } +} + +func TestStateStore_ChecksInState(t *testing.T) { + s := testStateStore(t) + + // Querying with no results returns nil + idx, res, err := s.ChecksInState(structs.HealthPassing) + if idx != 0 || res != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) + } + + // Register a node with checks in varied states + testRegisterNode(t, s, 0, "node1") + testRegisterCheck(t, s, 1, "node1", "", "check1", structs.HealthPassing) + testRegisterCheck(t, s, 2, "node1", "", "check2", structs.HealthCritical) + testRegisterCheck(t, s, 3, "node1", "", "check3", structs.HealthPassing) + + // Query the state store for passing checks. + _, checks, err := s.ChecksInState(structs.HealthPassing) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Make sure we only get the checks which match the state + if n := len(checks); n != 2 { + t.Fatalf("expected 2 checks, got: %d", n) + } + if checks[0].CheckID != "check1" || checks[1].CheckID != "check3" { + t.Fatalf("bad: %#v", checks) + } + + // HealthAny just returns everything. + _, checks, err = s.ChecksInState(structs.HealthAny) + if err != nil { + t.Fatalf("err: %s", err) + } + if n := len(checks); n != 3 { + t.Fatalf("expected 3 checks, got: %d", n) + } +} + +func TestStateStore_DeleteCheck(t *testing.T) { + s := testStateStore(t) + + // Register a node and a node-level health check + testRegisterNode(t, s, 1, "node1") + testRegisterCheck(t, s, 2, "node1", "", "check1", structs.HealthPassing) + + // Delete the check + if err := s.DeleteCheck(3, "node1", "check1"); err != nil { + t.Fatalf("err: %s", err) + } + + // Check is gone + _, checks, err := s.NodeChecks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if len(checks) != 0 { + t.Fatalf("bad: %#v", checks) + } + + // Index tables were updated + if idx := s.maxIndex("checks"); idx != 3 { + t.Fatalf("bad index: %d", idx) + } + + // Deleting a nonexistent check should be idempotent and not return an + // error + if err := s.DeleteCheck(4, "node1", "check1"); err != nil { + t.Fatalf("err: %s", err) + } + if idx := s.maxIndex("checks"); idx != 3 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_CheckServiceNodes(t *testing.T) { + s := testStateStore(t) + + // Querying with no matches gives an empty response + idx, res, err := s.CheckServiceNodes("service1") + if idx != 0 || res != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) + } + + // Register some nodes + testRegisterNode(t, s, 0, "node1") + testRegisterNode(t, s, 1, "node2") + + // Register node-level checks. These should not be returned + // in the final result. + testRegisterCheck(t, s, 2, "node1", "", "check1", structs.HealthPassing) + testRegisterCheck(t, s, 3, "node2", "", "check2", structs.HealthPassing) + + // Register a service against the nodes + testRegisterService(t, s, 4, "node1", "service1") + testRegisterService(t, s, 5, "node2", "service2") + + // Register checks against the services + testRegisterCheck(t, s, 6, "node1", "service1", "check3", structs.HealthPassing) + testRegisterCheck(t, s, 7, "node2", "service2", "check4", structs.HealthPassing) + + // Query the state store for nodes and checks which + // have been registered with a specific service. + idx, results, err := s.CheckServiceNodes("service1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 7 { + t.Fatalf("bad index: %d", idx) + } + + // Make sure we get the expected result (service check + node check) + if n := len(results); n != 1 { + t.Fatalf("expected 1 result, got: %d", n) + } + csn := results[0] + if csn.Node == nil || csn.Service == nil || len(csn.Checks) != 2 { + t.Fatalf("bad output: %#v", csn) + } + + // Node updates alter the returned index + testRegisterNode(t, s, 8, "node1") + idx, results, err = s.CheckServiceNodes("service1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 8 { + t.Fatalf("bad index: %d", idx) + } + + // Service updates alter the returned index + testRegisterService(t, s, 9, "node1", "service1") + idx, results, err = s.CheckServiceNodes("service1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 9 { + t.Fatalf("bad index: %d", idx) + } + + // Check updates alter the returned index + testRegisterCheck(t, s, 10, "node1", "service1", "check1", structs.HealthCritical) + idx, results, err = s.CheckServiceNodes("service1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 10 { + t.Fatalf("bad index: %d", idx) + } +} + +func BenchmarkCheckServiceNodes(b *testing.B) { + s, err := NewStateStore(nil) + if err != nil { + b.Fatalf("err: %s", err) + } + + if err := s.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + b.Fatalf("err: %v", err) + } + if err := s.EnsureService(2, "foo", &structs.NodeService{ID: "db1", Service: "db", Tags: []string{"master"}, Address: "", Port: 8000}); err != nil { + b.Fatalf("err: %v", err) + } + check := &structs.HealthCheck{ + Node: "foo", + CheckID: "db", + Name: "can connect", + Status: structs.HealthPassing, + ServiceID: "db1", + } + if err := s.EnsureCheck(3, check); err != nil { + b.Fatalf("err: %v", err) + } + check = &structs.HealthCheck{ + Node: "foo", + CheckID: "check1", + Name: "check1", + Status: structs.HealthPassing, + } + if err := s.EnsureCheck(4, check); err != nil { + b.Fatalf("err: %v", err) + } + + for i := 0; i < b.N; i++ { + s.CheckServiceNodes("db") + } +} + +func TestStateStore_CheckServiceTagNodes(t *testing.T) { + s := testStateStore(t) + + if err := s.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + if err := s.EnsureService(2, "foo", &structs.NodeService{ID: "db1", Service: "db", Tags: []string{"master"}, Address: "", Port: 8000}); err != nil { + t.Fatalf("err: %v", err) + } + check := &structs.HealthCheck{ + Node: "foo", + CheckID: "db", + Name: "can connect", + Status: structs.HealthPassing, + ServiceID: "db1", + } + if err := s.EnsureCheck(3, check); err != nil { + t.Fatalf("err: %v", err) + } + check = &structs.HealthCheck{ + Node: "foo", + CheckID: "check1", + Name: "another check", + Status: structs.HealthPassing, + } + if err := s.EnsureCheck(4, check); err != nil { + t.Fatalf("err: %v", err) + } + + idx, nodes, err := s.CheckServiceTagNodes("db", "master") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 4 { + t.Fatalf("bad: %v", idx) + } + if len(nodes) != 1 { + t.Fatalf("Bad: %v", nodes) + } + if nodes[0].Node.Node != "foo" { + t.Fatalf("Bad: %v", nodes[0]) + } + if nodes[0].Service.ID != "db1" { + t.Fatalf("Bad: %v", nodes[0]) + } + if len(nodes[0].Checks) != 2 { + t.Fatalf("Bad: %v", nodes[0]) + } + if nodes[0].Checks[0].CheckID != "check1" { + t.Fatalf("Bad: %v", nodes[0]) + } + if nodes[0].Checks[1].CheckID != "db" { + t.Fatalf("Bad: %v", nodes[0]) + } +} + +func TestStateStore_Check_Snapshot(t *testing.T) { + s := testStateStore(t) + + // Create a node, a service, and a service check as well as a node check. + testRegisterNode(t, s, 0, "node1") + testRegisterService(t, s, 1, "node1", "service1") + checks := structs.HealthChecks{ + &structs.HealthCheck{ + Node: "node1", + CheckID: "check1", + Name: "node check", + Status: structs.HealthPassing, + }, + &structs.HealthCheck{ + Node: "node1", + CheckID: "check2", + Name: "service check", + Status: structs.HealthCritical, + ServiceID: "service1", + }, + } + for i, hc := range checks { + if err := s.EnsureCheck(uint64(i+1), hc); err != nil { + t.Fatalf("err: %s", err) + } + } + + // Create a second node/service to make sure node filtering works. This + // will affect the index but not the dump. + testRegisterNode(t, s, 3, "node2") + testRegisterService(t, s, 4, "node2", "service2") + testRegisterCheck(t, s, 5, "node2", "service2", "check3", structs.HealthPassing) + + // Snapshot the checks. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. + testRegisterCheck(t, s, 6, "node2", "service2", "check4", structs.HealthPassing) + + // Verify the snapshot. + if idx := snap.LastIndex(); idx != 5 { + t.Fatalf("bad index: %d", idx) + } + iter, err := snap.Checks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + for i := 0; i < len(checks); i++ { + check := iter.Next().(*structs.HealthCheck) + if check == nil { + t.Fatalf("unexpected end of checks") + } + + checks[i].CreateIndex, checks[i].ModifyIndex = uint64(i+1), uint64(i+1) + if !reflect.DeepEqual(check, checks[i]) { + t.Fatalf("bad: %#v != %#v", check, checks[i]) + } + } + if iter.Next() != nil { + t.Fatalf("unexpected extra checks") + } +} + +func TestStateStore_Check_Watches(t *testing.T) { + s := testStateStore(t) + + testRegisterNode(t, s, 0, "node1") + hc := &structs.HealthCheck{ + Node: "node1", + CheckID: "check1", + Status: structs.HealthPassing, + } + + // Call functions that update the checks table and make sure a watch fires + // each time. + verifyWatch(t, s.getTableWatch("checks"), func() { + if err := s.EnsureCheck(1, hc); err != nil { + t.Fatalf("err: %s", err) + } + }) + verifyWatch(t, s.getTableWatch("checks"), func() { + hc.Status = structs.HealthCritical + if err := s.EnsureCheck(2, hc); err != nil { + t.Fatalf("err: %s", err) + } + }) + verifyWatch(t, s.getTableWatch("checks"), func() { + if err := s.DeleteCheck(3, "node1", "check1"); err != nil { + t.Fatalf("err: %s", err) + } + }) +} + +func TestStateStore_NodeInfo_NodeDump(t *testing.T) { + s := testStateStore(t) + + // Generating a node dump that matches nothing returns empty + idx, dump, err := s.NodeInfo("node1") + if idx != 0 || dump != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, dump, err) + } + idx, dump, err = s.NodeDump() + if idx != 0 || dump != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, dump, err) + } + + // Register some nodes + testRegisterNode(t, s, 0, "node1") + testRegisterNode(t, s, 1, "node2") + + // Register services against them + testRegisterService(t, s, 2, "node1", "service1") + testRegisterService(t, s, 3, "node1", "service2") + testRegisterService(t, s, 4, "node2", "service1") + testRegisterService(t, s, 5, "node2", "service2") + + // Register service-level checks + testRegisterCheck(t, s, 6, "node1", "service1", "check1", structs.HealthPassing) + testRegisterCheck(t, s, 7, "node2", "service1", "check1", structs.HealthPassing) + + // Register node-level checks + testRegisterCheck(t, s, 8, "node1", "", "check2", structs.HealthPassing) + testRegisterCheck(t, s, 9, "node2", "", "check2", structs.HealthPassing) + + // Check that our result matches what we expect. + expect := structs.NodeDump{ + &structs.NodeInfo{ + Node: "node1", + Checks: structs.HealthChecks{ + &structs.HealthCheck{ + Node: "node1", + CheckID: "check1", + ServiceID: "service1", + ServiceName: "service1", + Status: structs.HealthPassing, + RaftIndex: structs.RaftIndex{ + CreateIndex: 6, + ModifyIndex: 6, + }, + }, + &structs.HealthCheck{ + Node: "node1", + CheckID: "check2", + ServiceID: "", + ServiceName: "", + Status: structs.HealthPassing, + RaftIndex: structs.RaftIndex{ + CreateIndex: 8, + ModifyIndex: 8, + }, + }, + }, + Services: []*structs.NodeService{ + &structs.NodeService{ + ID: "service1", + Service: "service1", + Address: "1.1.1.1", + Port: 1111, + RaftIndex: structs.RaftIndex{ + CreateIndex: 2, + ModifyIndex: 2, + }, + }, + &structs.NodeService{ + ID: "service2", + Service: "service2", + Address: "1.1.1.1", + Port: 1111, + RaftIndex: structs.RaftIndex{ + CreateIndex: 3, + ModifyIndex: 3, + }, + }, + }, + }, + &structs.NodeInfo{ + Node: "node2", + Checks: structs.HealthChecks{ + &structs.HealthCheck{ + Node: "node2", + CheckID: "check1", + ServiceID: "service1", + ServiceName: "service1", + Status: structs.HealthPassing, + RaftIndex: structs.RaftIndex{ + CreateIndex: 7, + ModifyIndex: 7, + }, + }, + &structs.HealthCheck{ + Node: "node2", + CheckID: "check2", + ServiceID: "", + ServiceName: "", + Status: structs.HealthPassing, + RaftIndex: structs.RaftIndex{ + CreateIndex: 9, + ModifyIndex: 9, + }, + }, + }, + Services: []*structs.NodeService{ + &structs.NodeService{ + ID: "service1", + Service: "service1", + Address: "1.1.1.1", + Port: 1111, + RaftIndex: structs.RaftIndex{ + CreateIndex: 4, + ModifyIndex: 4, + }, + }, + &structs.NodeService{ + ID: "service2", + Service: "service2", + Address: "1.1.1.1", + Port: 1111, + RaftIndex: structs.RaftIndex{ + CreateIndex: 5, + ModifyIndex: 5, + }, + }, + }, + }, + } + + // Get a dump of just a single node + idx, dump, err = s.NodeInfo("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 9 { + t.Fatalf("bad index: %d", idx) + } + if len(dump) != 1 || !reflect.DeepEqual(dump[0], expect[0]) { + t.Fatalf("bad: %#v", dump) + } + + // Generate a dump of all the nodes + idx, dump, err = s.NodeDump() + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 9 { + t.Fatalf("bad index: %d", 9) + } + if !reflect.DeepEqual(dump, expect) { + t.Fatalf("bad: %#v", dump[0].Services[0]) + } +} diff --git a/consul/state/coordinate.go b/consul/state/coordinate.go new file mode 100644 index 0000000000..376d02b6c5 --- /dev/null +++ b/consul/state/coordinate.go @@ -0,0 +1,117 @@ +package state + +import ( + "fmt" + + "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/go-memdb" + "github.com/hashicorp/serf/coordinate" +) + +// Coordinates is used to pull all the coordinates from the snapshot. +func (s *StateSnapshot) Coordinates() (memdb.ResultIterator, error) { + iter, err := s.tx.Get("coordinates", "id") + if err != nil { + return nil, err + } + return iter, nil +} + +// Coordinates is used when restoring from a snapshot. For general inserts, use +// CoordinateBatchUpdate. We do less vetting of the updates here because they +// already got checked on the way in during a batch update. +func (s *StateRestore) Coordinates(idx uint64, updates structs.Coordinates) error { + for _, update := range updates { + if err := s.tx.Insert("coordinates", update); err != nil { + return fmt.Errorf("failed restoring coordinate: %s", err) + } + } + + if err := indexUpdateMaxTxn(s.tx, idx, "coordinates"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + s.watches.Arm("coordinates") + return nil +} + +// CoordinateGetRaw queries for the coordinate of the given node. This is an +// unusual state store method because it just returns the raw coordinate or +// nil, none of the Raft or node information is returned. This hits the 90% +// internal-to-Consul use case for this data, and this isn't exposed via an +// endpoint, so it doesn't matter that the Raft info isn't available. +func (s *StateStore) CoordinateGetRaw(node string) (*coordinate.Coordinate, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Pull the full coordinate entry. + coord, err := tx.First("coordinates", "id", node) + if err != nil { + return nil, fmt.Errorf("failed coordinate lookup: %s", err) + } + + // Pick out just the raw coordinate. + if coord != nil { + return coord.(*structs.Coordinate).Coord, nil + } + return nil, nil +} + +// Coordinates queries for all nodes with coordinates. +func (s *StateStore) Coordinates() (uint64, structs.Coordinates, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("Coordinates")...) + + // Pull all the coordinates. + coords, err := tx.Get("coordinates", "id") + if err != nil { + return 0, nil, fmt.Errorf("failed coordinate lookup: %s", err) + } + var results structs.Coordinates + for coord := coords.Next(); coord != nil; coord = coords.Next() { + results = append(results, coord.(*structs.Coordinate)) + } + return idx, results, nil +} + +// CoordinateBatchUpdate processes a batch of coordinate updates and applies +// them in a single transaction. +func (s *StateStore) CoordinateBatchUpdate(idx uint64, updates structs.Coordinates) error { + tx := s.db.Txn(true) + defer tx.Abort() + + // Upsert the coordinates. + for _, update := range updates { + // Since the cleanup of coordinates is tied to deletion of + // nodes, we silently drop any updates for nodes that we don't + // know about. This might be possible during normal operation + // if we happen to get a coordinate update for a node that + // hasn't been able to add itself to the catalog yet. Since we + // don't carefully sequence this, and since it will fix itself + // on the next coordinate update from that node, we don't return + // an error or log anything. + node, err := tx.First("nodes", "id", update.Node) + if err != nil { + return fmt.Errorf("failed node lookup: %s", err) + } + if node == nil { + continue + } + + if err := tx.Insert("coordinates", update); err != nil { + return fmt.Errorf("failed inserting coordinate: %s", err) + } + } + + // Update the index. + if err := tx.Insert("index", &IndexEntry{"coordinates", idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Defer(func() { s.tableWatches["coordinates"].Notify() }) + tx.Commit() + return nil +} diff --git a/consul/state/coordinate_test.go b/consul/state/coordinate_test.go new file mode 100644 index 0000000000..1998333845 --- /dev/null +++ b/consul/state/coordinate_test.go @@ -0,0 +1,298 @@ +package state + +import ( + "math/rand" + "reflect" + "testing" + + "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/serf/coordinate" +) + +// generateRandomCoordinate creates a random coordinate. This mucks with the +// underlying structure directly, so it's not really useful for any particular +// position in the network, but it's a good payload to send through to make +// sure things come out the other side or get stored correctly. +func generateRandomCoordinate() *coordinate.Coordinate { + config := coordinate.DefaultConfig() + coord := coordinate.NewCoordinate(config) + for i := range coord.Vec { + coord.Vec[i] = rand.NormFloat64() + } + coord.Error = rand.NormFloat64() + coord.Adjustment = rand.NormFloat64() + return coord +} + +func TestStateStore_Coordinate_Updates(t *testing.T) { + s := testStateStore(t) + + // Make sure the coordinates list starts out empty, and that a query for + // a raw coordinate for a nonexistent node doesn't do anything bad. + idx, coords, err := s.Coordinates() + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 0 { + t.Fatalf("bad index: %d", idx) + } + if coords != nil { + t.Fatalf("bad: %#v", coords) + } + coord, err := s.CoordinateGetRaw("nope") + if err != nil { + t.Fatalf("err: %s", err) + } + if coord != nil { + t.Fatalf("bad: %#v", coord) + } + + // Make an update for nodes that don't exist and make sure they get + // ignored. + updates := structs.Coordinates{ + &structs.Coordinate{ + Node: "node1", + Coord: generateRandomCoordinate(), + }, + &structs.Coordinate{ + Node: "node2", + Coord: generateRandomCoordinate(), + }, + } + if err := s.CoordinateBatchUpdate(1, updates); err != nil { + t.Fatalf("err: %s", err) + } + + // Should still be empty, though applying an empty batch does bump + // the table index. + idx, coords, err = s.Coordinates() + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 1 { + t.Fatalf("bad index: %d", idx) + } + if coords != nil { + t.Fatalf("bad: %#v", coords) + } + + // Register the nodes then do the update again. + testRegisterNode(t, s, 1, "node1") + testRegisterNode(t, s, 2, "node2") + if err := s.CoordinateBatchUpdate(3, updates); err != nil { + t.Fatalf("err: %s", err) + } + + // Should go through now. + idx, coords, err = s.Coordinates() + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 3 { + t.Fatalf("bad index: %d", idx) + } + if !reflect.DeepEqual(coords, updates) { + t.Fatalf("bad: %#v", coords) + } + + // Also verify the raw coordinate interface. + for _, update := range updates { + coord, err := s.CoordinateGetRaw(update.Node) + if err != nil { + t.Fatalf("err: %s", err) + } + if !reflect.DeepEqual(coord, update.Coord) { + t.Fatalf("bad: %#v", coord) + } + } + + // Update the coordinate for one of the nodes. + updates[1].Coord = generateRandomCoordinate() + if err := s.CoordinateBatchUpdate(4, updates); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify it got applied. + idx, coords, err = s.Coordinates() + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 4 { + t.Fatalf("bad index: %d", idx) + } + if !reflect.DeepEqual(coords, updates) { + t.Fatalf("bad: %#v", coords) + } + + // And check the raw coordinate version of the same thing. + for _, update := range updates { + coord, err := s.CoordinateGetRaw(update.Node) + if err != nil { + t.Fatalf("err: %s", err) + } + if !reflect.DeepEqual(coord, update.Coord) { + t.Fatalf("bad: %#v", coord) + } + } +} + +func TestStateStore_Coordinate_Cleanup(t *testing.T) { + s := testStateStore(t) + + // Register a node and update its coordinate. + testRegisterNode(t, s, 1, "node1") + updates := structs.Coordinates{ + &structs.Coordinate{ + Node: "node1", + Coord: generateRandomCoordinate(), + }, + } + if err := s.CoordinateBatchUpdate(2, updates); err != nil { + t.Fatalf("err: %s", err) + } + + // Make sure it's in there. + coord, err := s.CoordinateGetRaw("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if !reflect.DeepEqual(coord, updates[0].Coord) { + t.Fatalf("bad: %#v", coord) + } + + // Now delete the node. + if err := s.DeleteNode(3, "node1"); err != nil { + t.Fatalf("err: %s", err) + } + + // Make sure the coordinate is gone. + coord, err = s.CoordinateGetRaw("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if coord != nil { + t.Fatalf("bad: %#v", coord) + } + + // Make sure the index got updated. + idx, coords, err := s.Coordinates() + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 3 { + t.Fatalf("bad index: %d", idx) + } + if coords != nil { + t.Fatalf("bad: %#v", coords) + } +} + +func TestStateStore_Coordinate_Snapshot_Restore(t *testing.T) { + s := testStateStore(t) + + // Register two nodes and update their coordinates. + testRegisterNode(t, s, 1, "node1") + testRegisterNode(t, s, 2, "node2") + updates := structs.Coordinates{ + &structs.Coordinate{ + Node: "node1", + Coord: generateRandomCoordinate(), + }, + &structs.Coordinate{ + Node: "node2", + Coord: generateRandomCoordinate(), + }, + } + if err := s.CoordinateBatchUpdate(3, updates); err != nil { + t.Fatalf("err: %s", err) + } + + // Snapshot the coordinates. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. + trash := structs.Coordinates{ + &structs.Coordinate{ + Node: "node1", + Coord: generateRandomCoordinate(), + }, + &structs.Coordinate{ + Node: "node2", + Coord: generateRandomCoordinate(), + }, + } + if err := s.CoordinateBatchUpdate(4, trash); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the snapshot. + if idx := snap.LastIndex(); idx != 3 { + t.Fatalf("bad index: %d", idx) + } + iter, err := snap.Coordinates() + if err != nil { + t.Fatalf("err: %s", err) + } + var dump structs.Coordinates + for coord := iter.Next(); coord != nil; coord = iter.Next() { + dump = append(dump, coord.(*structs.Coordinate)) + } + if !reflect.DeepEqual(dump, updates) { + t.Fatalf("bad: %#v", dump) + } + + // Restore the values into a new state store. + func() { + s := testStateStore(t) + restore := s.Restore() + if err := restore.Coordinates(5, dump); err != nil { + t.Fatalf("err: %s", err) + } + restore.Commit() + + // Read the restored coordinates back out and verify that they match. + idx, res, err := s.Coordinates() + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 5 { + t.Fatalf("bad index: %d", idx) + } + if !reflect.DeepEqual(res, updates) { + t.Fatalf("bad: %#v", res) + } + + // Check that the index was updated (note that it got passed + // in during the restore). + if idx := s.maxIndex("coordinates"); idx != 5 { + t.Fatalf("bad index: %d", idx) + } + }() + +} + +func TestStateStore_Coordinate_Watches(t *testing.T) { + s := testStateStore(t) + + testRegisterNode(t, s, 1, "node1") + + // Call functions that update the coordinates table and make sure a watch fires + // each time. + verifyWatch(t, s.getTableWatch("coordinates"), func() { + updates := structs.Coordinates{ + &structs.Coordinate{ + Node: "node1", + Coord: generateRandomCoordinate(), + }, + } + if err := s.CoordinateBatchUpdate(2, updates); err != nil { + t.Fatalf("err: %s", err) + } + }) + verifyWatch(t, s.getTableWatch("coordinates"), func() { + if err := s.DeleteNode(3, "node1"); err != nil { + t.Fatalf("err: %s", err) + } + }) +} diff --git a/consul/state/session.go b/consul/state/session.go new file mode 100644 index 0000000000..08e6c521df --- /dev/null +++ b/consul/state/session.go @@ -0,0 +1,345 @@ +package state + +import ( + "fmt" + "time" + + "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/go-memdb" +) + +// Sessions is used to pull the full list of sessions for use during snapshots. +func (s *StateSnapshot) Sessions() (memdb.ResultIterator, error) { + iter, err := s.tx.Get("sessions", "id") + if err != nil { + return nil, err + } + return iter, nil +} + +// Session is used when restoring from a snapshot. For general inserts, use +// SessionCreate. +func (s *StateRestore) Session(sess *structs.Session) error { + // Insert the session. + if err := s.tx.Insert("sessions", sess); err != nil { + return fmt.Errorf("failed inserting session: %s", err) + } + + // Insert the check mappings. + for _, checkID := range sess.Checks { + mapping := &sessionCheck{ + Node: sess.Node, + CheckID: checkID, + Session: sess.ID, + } + if err := s.tx.Insert("session_checks", mapping); err != nil { + return fmt.Errorf("failed inserting session check mapping: %s", err) + } + } + + // Update the index. + if err := indexUpdateMaxTxn(s.tx, sess.ModifyIndex, "sessions"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + s.watches.Arm("sessions") + return nil +} + +// SessionCreate is used to register a new session in the state store. +func (s *StateStore) SessionCreate(idx uint64, sess *structs.Session) error { + tx := s.db.Txn(true) + defer tx.Abort() + + // This code is technically able to (incorrectly) update an existing + // session but we never do that in practice. The upstream endpoint code + // always adds a unique ID when doing a create operation so we never hit + // an existing session again. It isn't worth the overhead to verify + // that here, but it's worth noting that we should never do this in the + // future. + + // Call the session creation + if err := s.sessionCreateTxn(tx, idx, sess); err != nil { + return err + } + + tx.Commit() + return nil +} + +// sessionCreateTxn is the inner method used for creating session entries in +// an open transaction. Any health checks registered with the session will be +// checked for failing status. Returns any error encountered. +func (s *StateStore) sessionCreateTxn(tx *memdb.Txn, idx uint64, sess *structs.Session) error { + // Check that we have a session ID + if sess.ID == "" { + return ErrMissingSessionID + } + + // Verify the session behavior is valid + switch sess.Behavior { + case "": + // Release by default to preserve backwards compatibility + sess.Behavior = structs.SessionKeysRelease + case structs.SessionKeysRelease: + case structs.SessionKeysDelete: + default: + return fmt.Errorf("Invalid session behavior: %s", sess.Behavior) + } + + // Assign the indexes. ModifyIndex likely will not be used but + // we set it here anyways for sanity. + sess.CreateIndex = idx + sess.ModifyIndex = idx + + // Check that the node exists + node, err := tx.First("nodes", "id", sess.Node) + if err != nil { + return fmt.Errorf("failed node lookup: %s", err) + } + if node == nil { + return ErrMissingNode + } + + // Go over the session checks and ensure they exist. + for _, checkID := range sess.Checks { + check, err := tx.First("checks", "id", sess.Node, string(checkID)) + if err != nil { + return fmt.Errorf("failed check lookup: %s", err) + } + if check == nil { + return fmt.Errorf("Missing check '%s' registration", checkID) + } + + // Check that the check is not in critical state + status := check.(*structs.HealthCheck).Status + if status == structs.HealthCritical { + return fmt.Errorf("Check '%s' is in %s state", checkID, status) + } + } + + // Insert the session + if err := tx.Insert("sessions", sess); err != nil { + return fmt.Errorf("failed inserting session: %s", err) + } + + // Insert the check mappings + for _, checkID := range sess.Checks { + mapping := &sessionCheck{ + Node: sess.Node, + CheckID: checkID, + Session: sess.ID, + } + if err := tx.Insert("session_checks", mapping); err != nil { + return fmt.Errorf("failed inserting session check mapping: %s", err) + } + } + + // Update the index + if err := tx.Insert("index", &IndexEntry{"sessions", idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Defer(func() { s.tableWatches["sessions"].Notify() }) + return nil +} + +// SessionGet is used to retrieve an active session from the state store. +func (s *StateStore) SessionGet(sessionID string) (uint64, *structs.Session, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("SessionGet")...) + + // Look up the session by its ID + session, err := tx.First("sessions", "id", sessionID) + if err != nil { + return 0, nil, fmt.Errorf("failed session lookup: %s", err) + } + if session != nil { + return idx, session.(*structs.Session), nil + } + return idx, nil, nil +} + +// SessionList returns a slice containing all of the active sessions. +func (s *StateStore) SessionList() (uint64, structs.Sessions, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("SessionList")...) + + // Query all of the active sessions. + sessions, err := tx.Get("sessions", "id") + if err != nil { + return 0, nil, fmt.Errorf("failed session lookup: %s", err) + } + + // Go over the sessions and create a slice of them. + var result structs.Sessions + for session := sessions.Next(); session != nil; session = sessions.Next() { + result = append(result, session.(*structs.Session)) + } + return idx, result, nil +} + +// NodeSessions returns a set of active sessions associated +// with the given node ID. The returned index is the highest +// index seen from the result set. +func (s *StateStore) NodeSessions(nodeID string) (uint64, structs.Sessions, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("NodeSessions")...) + + // Get all of the sessions which belong to the node + sessions, err := tx.Get("sessions", "node", nodeID) + if err != nil { + return 0, nil, fmt.Errorf("failed session lookup: %s", err) + } + + // Go over all of the sessions and return them as a slice + var result structs.Sessions + for session := sessions.Next(); session != nil; session = sessions.Next() { + result = append(result, session.(*structs.Session)) + } + return idx, result, nil +} + +// SessionDestroy is used to remove an active session. This will +// implicitly invalidate the session and invoke the specified +// session destroy behavior. +func (s *StateStore) SessionDestroy(idx uint64, sessionID string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + // Call the session deletion. + watches := NewDumbWatchManager(s.tableWatches) + if err := s.deleteSessionTxn(tx, idx, watches, sessionID); err != nil { + return err + } + + tx.Defer(func() { watches.Notify() }) + tx.Commit() + return nil +} + +// deleteSessionTxn is the inner method, which is used to do the actual +// session deletion and handle session invalidation, watch triggers, etc. +func (s *StateStore) deleteSessionTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, sessionID string) error { + // Look up the session. + sess, err := tx.First("sessions", "id", sessionID) + if err != nil { + return fmt.Errorf("failed session lookup: %s", err) + } + if sess == nil { + return nil + } + + // Delete the session and write the new index. + if err := tx.Delete("sessions", sess); err != nil { + return fmt.Errorf("failed deleting session: %s", err) + } + if err := tx.Insert("index", &IndexEntry{"sessions", idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + // Enforce the max lock delay. + session := sess.(*structs.Session) + delay := session.LockDelay + if delay > structs.MaxLockDelay { + delay = structs.MaxLockDelay + } + + // Snag the current now time so that all the expirations get calculated + // the same way. + now := time.Now() + + // Get an iterator over all of the keys with the given session. + entries, err := tx.Get("kvs", "session", sessionID) + if err != nil { + return fmt.Errorf("failed kvs lookup: %s", err) + } + var kvs []interface{} + for entry := entries.Next(); entry != nil; entry = entries.Next() { + kvs = append(kvs, entry) + } + + // Invalidate any held locks. + switch session.Behavior { + case structs.SessionKeysRelease: + for _, obj := range kvs { + // Note that we clone here since we are modifying the + // returned object and want to make sure our set op + // respects the transaction we are in. + e := obj.(*structs.DirEntry).Clone() + e.Session = "" + if err := s.kvsSetTxn(tx, idx, e, true); err != nil { + return fmt.Errorf("failed kvs update: %s", err) + } + + // Apply the lock delay if present. + if delay > 0 { + s.lockDelay.SetExpiration(e.Key, now, delay) + } + } + case structs.SessionKeysDelete: + for _, obj := range kvs { + e := obj.(*structs.DirEntry) + if err := s.kvsDeleteTxn(tx, idx, e.Key); err != nil { + return fmt.Errorf("failed kvs delete: %s", err) + } + + // Apply the lock delay if present. + if delay > 0 { + s.lockDelay.SetExpiration(e.Key, now, delay) + } + } + default: + return fmt.Errorf("unknown session behavior %#v", session.Behavior) + } + + // Delete any check mappings. + mappings, err := tx.Get("session_checks", "session", sessionID) + if err != nil { + return fmt.Errorf("failed session checks lookup: %s", err) + } + { + var objs []interface{} + for mapping := mappings.Next(); mapping != nil; mapping = mappings.Next() { + objs = append(objs, mapping) + } + + // Do the delete in a separate loop so we don't trash the iterator. + for _, obj := range objs { + if err := tx.Delete("session_checks", obj); err != nil { + return fmt.Errorf("failed deleting session check: %s", err) + } + } + } + + // Delete any prepared queries. + queries, err := tx.Get("prepared-queries", "session", sessionID) + if err != nil { + return fmt.Errorf("failed prepared query lookup: %s", err) + } + { + var ids []string + for wrapped := queries.Next(); wrapped != nil; wrapped = queries.Next() { + ids = append(ids, toPreparedQuery(wrapped).ID) + } + + // Do the delete in a separate loop so we don't trash the iterator. + for _, id := range ids { + if err := s.preparedQueryDeleteTxn(tx, idx, watches, id); err != nil { + return fmt.Errorf("failed prepared query delete: %s", err) + } + } + } + + watches.Arm("sessions") + return nil +} diff --git a/consul/state/session_test.go b/consul/state/session_test.go new file mode 100644 index 0000000000..3e435e7e16 --- /dev/null +++ b/consul/state/session_test.go @@ -0,0 +1,911 @@ +package state + +import ( + "fmt" + "reflect" + "strings" + "testing" + "time" + + "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/consul/types" +) + +func TestStateStore_SessionCreate_SessionGet(t *testing.T) { + s := testStateStore(t) + + // SessionGet returns nil if the session doesn't exist + idx, session, err := s.SessionGet(testUUID()) + if session != nil || err != nil { + t.Fatalf("expected (nil, nil), got: (%#v, %#v)", session, err) + } + if idx != 0 { + t.Fatalf("bad index: %d", idx) + } + + // Registering without a session ID is disallowed + err = s.SessionCreate(1, &structs.Session{}) + if err != ErrMissingSessionID { + t.Fatalf("expected %#v, got: %#v", ErrMissingSessionID, err) + } + + // Invalid session behavior throws error + sess := &structs.Session{ + ID: testUUID(), + Behavior: "nope", + } + err = s.SessionCreate(1, sess) + if err == nil || !strings.Contains(err.Error(), "session behavior") { + t.Fatalf("expected session behavior error, got: %#v", err) + } + + // Registering with an unknown node is disallowed + sess = &structs.Session{ID: testUUID()} + if err := s.SessionCreate(1, sess); err != ErrMissingNode { + t.Fatalf("expected %#v, got: %#v", ErrMissingNode, err) + } + + // None of the errored operations modified the index + if idx := s.maxIndex("sessions"); idx != 0 { + t.Fatalf("bad index: %d", idx) + } + + // Valid session is able to register + testRegisterNode(t, s, 1, "node1") + sess = &structs.Session{ + ID: testUUID(), + Node: "node1", + } + if err := s.SessionCreate(2, sess); err != nil { + t.Fatalf("err: %s", err) + } + if idx := s.maxIndex("sessions"); idx != 2 { + t.Fatalf("bad index: %s", err) + } + + // Retrieve the session again + idx, session, err = s.SessionGet(sess.ID) + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 2 { + t.Fatalf("bad index: %d", idx) + } + + // Ensure the session looks correct and was assigned the + // proper default value for session behavior. + expect := &structs.Session{ + ID: sess.ID, + Behavior: structs.SessionKeysRelease, + Node: "node1", + RaftIndex: structs.RaftIndex{ + CreateIndex: 2, + ModifyIndex: 2, + }, + } + if !reflect.DeepEqual(expect, session) { + t.Fatalf("bad session: %#v", session) + } + + // Registering with a non-existent check is disallowed + sess = &structs.Session{ + ID: testUUID(), + Node: "node1", + Checks: []types.CheckID{"check1"}, + } + err = s.SessionCreate(3, sess) + if err == nil || !strings.Contains(err.Error(), "Missing check") { + t.Fatalf("expected missing check error, got: %#v", err) + } + + // Registering with a critical check is disallowed + testRegisterCheck(t, s, 3, "node1", "", "check1", structs.HealthCritical) + err = s.SessionCreate(4, sess) + if err == nil || !strings.Contains(err.Error(), structs.HealthCritical) { + t.Fatalf("expected critical state error, got: %#v", err) + } + + // Registering with a healthy check succeeds + testRegisterCheck(t, s, 4, "node1", "", "check1", structs.HealthPassing) + if err := s.SessionCreate(5, sess); err != nil { + t.Fatalf("err: %s", err) + } + + // Register a session against two checks. + testRegisterCheck(t, s, 5, "node1", "", "check2", structs.HealthPassing) + sess2 := &structs.Session{ + ID: testUUID(), + Node: "node1", + Checks: []types.CheckID{"check1", "check2"}, + } + if err := s.SessionCreate(6, sess2); err != nil { + t.Fatalf("err: %s", err) + } + + tx := s.db.Txn(false) + defer tx.Abort() + + // Check mappings were inserted + { + check, err := tx.First("session_checks", "session", sess.ID) + if err != nil { + t.Fatalf("err: %s", err) + } + if check == nil { + t.Fatalf("missing session check") + } + expectCheck := &sessionCheck{ + Node: "node1", + CheckID: "check1", + Session: sess.ID, + } + if actual := check.(*sessionCheck); !reflect.DeepEqual(actual, expectCheck) { + t.Fatalf("expected %#v, got: %#v", expectCheck, actual) + } + } + checks, err := tx.Get("session_checks", "session", sess2.ID) + if err != nil { + t.Fatalf("err: %s", err) + } + for i, check := 0, checks.Next(); check != nil; i, check = i+1, checks.Next() { + expectCheck := &sessionCheck{ + Node: "node1", + CheckID: types.CheckID(fmt.Sprintf("check%d", i+1)), + Session: sess2.ID, + } + if actual := check.(*sessionCheck); !reflect.DeepEqual(actual, expectCheck) { + t.Fatalf("expected %#v, got: %#v", expectCheck, actual) + } + } + + // Pulling a nonexistent session gives the table index. + idx, session, err = s.SessionGet(testUUID()) + if err != nil { + t.Fatalf("err: %s", err) + } + if session != nil { + t.Fatalf("expected not to get a session: %v", session) + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } +} + +func TegstStateStore_SessionList(t *testing.T) { + s := testStateStore(t) + + // Listing when no sessions exist returns nil + idx, res, err := s.SessionList() + if idx != 0 || res != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) + } + + // Register some nodes + testRegisterNode(t, s, 1, "node1") + testRegisterNode(t, s, 2, "node2") + testRegisterNode(t, s, 3, "node3") + + // Create some sessions in the state store + sessions := structs.Sessions{ + &structs.Session{ + ID: testUUID(), + Node: "node1", + Behavior: structs.SessionKeysDelete, + }, + &structs.Session{ + ID: testUUID(), + Node: "node2", + Behavior: structs.SessionKeysRelease, + }, + &structs.Session{ + ID: testUUID(), + Node: "node3", + Behavior: structs.SessionKeysDelete, + }, + } + for i, session := range sessions { + if err := s.SessionCreate(uint64(4+i), session); err != nil { + t.Fatalf("err: %s", err) + } + } + + // List out all of the sessions + idx, sessionList, err := s.SessionList() + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + if !reflect.DeepEqual(sessionList, sessions) { + t.Fatalf("bad: %#v", sessions) + } +} + +func TestStateStore_NodeSessions(t *testing.T) { + s := testStateStore(t) + + // Listing sessions with no results returns nil + idx, res, err := s.NodeSessions("node1") + if idx != 0 || res != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) + } + + // Create the nodes + testRegisterNode(t, s, 1, "node1") + testRegisterNode(t, s, 2, "node2") + + // Register some sessions with the nodes + sessions1 := structs.Sessions{ + &structs.Session{ + ID: testUUID(), + Node: "node1", + }, + &structs.Session{ + ID: testUUID(), + Node: "node1", + }, + } + sessions2 := []*structs.Session{ + &structs.Session{ + ID: testUUID(), + Node: "node2", + }, + &structs.Session{ + ID: testUUID(), + Node: "node2", + }, + } + for i, sess := range append(sessions1, sessions2...) { + if err := s.SessionCreate(uint64(3+i), sess); err != nil { + t.Fatalf("err: %s", err) + } + } + + // Query all of the sessions associated with a specific + // node in the state store. + idx, res, err = s.NodeSessions("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if len(res) != len(sessions1) { + t.Fatalf("bad: %#v", res) + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + + idx, res, err = s.NodeSessions("node2") + if err != nil { + t.Fatalf("err: %s", err) + } + if len(res) != len(sessions2) { + t.Fatalf("bad: %#v", res) + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_SessionDestroy(t *testing.T) { + s := testStateStore(t) + + // Session destroy is idempotent and returns no error + // if the session doesn't exist. + if err := s.SessionDestroy(1, testUUID()); err != nil { + t.Fatalf("err: %s", err) + } + + // Ensure the index was not updated if nothing was destroyed. + if idx := s.maxIndex("sessions"); idx != 0 { + t.Fatalf("bad index: %d", idx) + } + + // Register a node. + testRegisterNode(t, s, 1, "node1") + + // Register a new session + sess := &structs.Session{ + ID: testUUID(), + Node: "node1", + } + if err := s.SessionCreate(2, sess); err != nil { + t.Fatalf("err: %s", err) + } + + // Destroy the session. + if err := s.SessionDestroy(3, sess.ID); err != nil { + t.Fatalf("err: %s", err) + } + + // Check that the index was updated + if idx := s.maxIndex("sessions"); idx != 3 { + t.Fatalf("bad index: %d", idx) + } + + // Make sure the session is really gone. + tx := s.db.Txn(false) + sessions, err := tx.Get("sessions", "id") + if err != nil || sessions.Next() != nil { + t.Fatalf("session should not exist") + } + tx.Abort() +} + +func TestStateStore_Session_Snapshot_Restore(t *testing.T) { + s := testStateStore(t) + + // Register some nodes and checks. + testRegisterNode(t, s, 1, "node1") + testRegisterNode(t, s, 2, "node2") + testRegisterNode(t, s, 3, "node3") + testRegisterCheck(t, s, 4, "node1", "", "check1", structs.HealthPassing) + + // Create some sessions in the state store. + session1 := testUUID() + sessions := structs.Sessions{ + &structs.Session{ + ID: session1, + Node: "node1", + Behavior: structs.SessionKeysDelete, + Checks: []types.CheckID{"check1"}, + }, + &structs.Session{ + ID: testUUID(), + Node: "node2", + Behavior: structs.SessionKeysRelease, + LockDelay: 10 * time.Second, + }, + &structs.Session{ + ID: testUUID(), + Node: "node3", + Behavior: structs.SessionKeysDelete, + TTL: "1.5s", + }, + } + for i, session := range sessions { + if err := s.SessionCreate(uint64(5+i), session); err != nil { + t.Fatalf("err: %s", err) + } + } + + // Snapshot the sessions. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. + if err := s.SessionDestroy(8, session1); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the snapshot. + if idx := snap.LastIndex(); idx != 7 { + t.Fatalf("bad index: %d", idx) + } + iter, err := snap.Sessions() + if err != nil { + t.Fatalf("err: %s", err) + } + var dump structs.Sessions + for session := iter.Next(); session != nil; session = iter.Next() { + sess := session.(*structs.Session) + dump = append(dump, sess) + + found := false + for i, _ := range sessions { + if sess.ID == sessions[i].ID { + if !reflect.DeepEqual(sess, sessions[i]) { + t.Fatalf("bad: %#v", sess) + } + found = true + } + } + if !found { + t.Fatalf("bad: %#v", sess) + } + } + + // Restore the sessions into a new state store. + func() { + s := testStateStore(t) + restore := s.Restore() + for _, session := range dump { + if err := restore.Session(session); err != nil { + t.Fatalf("err: %s", err) + } + } + restore.Commit() + + // Read the restored sessions back out and verify that they + // match. + idx, res, err := s.SessionList() + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 7 { + t.Fatalf("bad index: %d", idx) + } + for _, sess := range res { + found := false + for i, _ := range sessions { + if sess.ID == sessions[i].ID { + if !reflect.DeepEqual(sess, sessions[i]) { + t.Fatalf("bad: %#v", sess) + } + found = true + } + } + if !found { + t.Fatalf("bad: %#v", sess) + } + } + + // Check that the index was updated. + if idx := s.maxIndex("sessions"); idx != 7 { + t.Fatalf("bad index: %d", idx) + } + + // Manually verify that the session check mapping got restored. + tx := s.db.Txn(false) + defer tx.Abort() + + check, err := tx.First("session_checks", "session", session1) + if err != nil { + t.Fatalf("err: %s", err) + } + if check == nil { + t.Fatalf("missing session check") + } + expectCheck := &sessionCheck{ + Node: "node1", + CheckID: "check1", + Session: session1, + } + if actual := check.(*sessionCheck); !reflect.DeepEqual(actual, expectCheck) { + t.Fatalf("expected %#v, got: %#v", expectCheck, actual) + } + }() +} + +func TestStateStore_Session_Watches(t *testing.T) { + s := testStateStore(t) + + // Register a test node. + testRegisterNode(t, s, 1, "node1") + + // This just covers the basics. The session invalidation tests above + // cover the more nuanced multiple table watches. + session := testUUID() + verifyWatch(t, s.getTableWatch("sessions"), func() { + sess := &structs.Session{ + ID: session, + Node: "node1", + Behavior: structs.SessionKeysDelete, + } + if err := s.SessionCreate(2, sess); err != nil { + t.Fatalf("err: %s", err) + } + }) + verifyWatch(t, s.getTableWatch("sessions"), func() { + if err := s.SessionDestroy(3, session); err != nil { + t.Fatalf("err: %s", err) + } + }) + verifyWatch(t, s.getTableWatch("sessions"), func() { + restore := s.Restore() + sess := &structs.Session{ + ID: session, + Node: "node1", + Behavior: structs.SessionKeysDelete, + } + if err := restore.Session(sess); err != nil { + t.Fatalf("err: %s", err) + } + restore.Commit() + }) +} + +func TestStateStore_Session_Invalidate_DeleteNode(t *testing.T) { + s := testStateStore(t) + + // Set up our test environment. + if err := s.EnsureNode(3, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + session := &structs.Session{ + ID: testUUID(), + Node: "foo", + } + if err := s.SessionCreate(14, session); err != nil { + t.Fatalf("err: %v", err) + } + + // Delete the node and make sure the watches fire. + verifyWatch(t, s.getTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("nodes"), func() { + if err := s.DeleteNode(15, "foo"); err != nil { + t.Fatalf("err: %v", err) + } + }) + }) + + // Lookup by ID, should be nil. + idx, s2, err := s.SessionGet(session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if s2 != nil { + t.Fatalf("session should be invalidated") + } + if idx != 15 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_Session_Invalidate_DeleteService(t *testing.T) { + s := testStateStore(t) + + // Set up our test environment. + if err := s.EnsureNode(11, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + if err := s.EnsureService(12, "foo", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}); err != nil { + t.Fatalf("err: %v", err) + } + check := &structs.HealthCheck{ + Node: "foo", + CheckID: "api", + Name: "Can connect", + Status: structs.HealthPassing, + ServiceID: "api", + } + if err := s.EnsureCheck(13, check); err != nil { + t.Fatalf("err: %v", err) + } + session := &structs.Session{ + ID: testUUID(), + Node: "foo", + Checks: []types.CheckID{"api"}, + } + if err := s.SessionCreate(14, session); err != nil { + t.Fatalf("err: %v", err) + } + + // Delete the service and make sure the watches fire. + verifyWatch(t, s.getTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("services"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { + if err := s.DeleteService(15, "foo", "api"); err != nil { + t.Fatalf("err: %v", err) + } + }) + }) + }) + + // Lookup by ID, should be nil. + idx, s2, err := s.SessionGet(session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if s2 != nil { + t.Fatalf("session should be invalidated") + } + if idx != 15 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_Session_Invalidate_Critical_Check(t *testing.T) { + s := testStateStore(t) + + // Set up our test environment. + if err := s.EnsureNode(3, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + check := &structs.HealthCheck{ + Node: "foo", + CheckID: "bar", + Status: structs.HealthPassing, + } + if err := s.EnsureCheck(13, check); err != nil { + t.Fatalf("err: %v", err) + } + session := &structs.Session{ + ID: testUUID(), + Node: "foo", + Checks: []types.CheckID{"bar"}, + } + if err := s.SessionCreate(14, session); err != nil { + t.Fatalf("err: %v", err) + } + + // Invalidate the check and make sure the watches fire. + verifyWatch(t, s.getTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { + check.Status = structs.HealthCritical + if err := s.EnsureCheck(15, check); err != nil { + t.Fatalf("err: %v", err) + } + }) + }) + + // Lookup by ID, should be nil. + idx, s2, err := s.SessionGet(session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if s2 != nil { + t.Fatalf("session should be invalidated") + } + if idx != 15 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_Session_Invalidate_DeleteCheck(t *testing.T) { + s := testStateStore(t) + + // Set up our test environment. + if err := s.EnsureNode(3, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + check := &structs.HealthCheck{ + Node: "foo", + CheckID: "bar", + Status: structs.HealthPassing, + } + if err := s.EnsureCheck(13, check); err != nil { + t.Fatalf("err: %v", err) + } + session := &structs.Session{ + ID: testUUID(), + Node: "foo", + Checks: []types.CheckID{"bar"}, + } + if err := s.SessionCreate(14, session); err != nil { + t.Fatalf("err: %v", err) + } + + // Delete the check and make sure the watches fire. + verifyWatch(t, s.getTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { + if err := s.DeleteCheck(15, "foo", "bar"); err != nil { + t.Fatalf("err: %v", err) + } + }) + }) + + // Lookup by ID, should be nil. + idx, s2, err := s.SessionGet(session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if s2 != nil { + t.Fatalf("session should be invalidated") + } + if idx != 15 { + t.Fatalf("bad index: %d", idx) + } + + // Manually make sure the session checks mapping is clear. + tx := s.db.Txn(false) + mapping, err := tx.First("session_checks", "session", session.ID) + if err != nil { + t.Fatalf("err: %s", err) + } + if mapping != nil { + t.Fatalf("unexpected session check") + } + tx.Abort() +} + +func TestStateStore_Session_Invalidate_Key_Unlock_Behavior(t *testing.T) { + s := testStateStore(t) + + // Set up our test environment. + if err := s.EnsureNode(3, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + session := &structs.Session{ + ID: testUUID(), + Node: "foo", + LockDelay: 50 * time.Millisecond, + } + if err := s.SessionCreate(4, session); err != nil { + t.Fatalf("err: %v", err) + } + + // Lock a key with the session. + d := &structs.DirEntry{ + Key: "/foo", + Flags: 42, + Value: []byte("test"), + Session: session.ID, + } + ok, err := s.KVSLock(5, d) + if err != nil { + t.Fatalf("err: %v", err) + } + if !ok { + t.Fatalf("unexpected fail") + } + + // Delete the node and make sure the watches fire. + verifyWatch(t, s.getTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyWatch(t, s.GetKVSWatch("/f"), func() { + if err := s.DeleteNode(6, "foo"); err != nil { + t.Fatalf("err: %v", err) + } + }) + }) + }) + + // Lookup by ID, should be nil. + idx, s2, err := s.SessionGet(session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if s2 != nil { + t.Fatalf("session should be invalidated") + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + + // Key should be unlocked. + idx, d2, err := s.KVSGet("/foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if d2.ModifyIndex != 6 { + t.Fatalf("bad index: %v", d2.ModifyIndex) + } + if d2.LockIndex != 1 { + t.Fatalf("bad: %v", *d2) + } + if d2.Session != "" { + t.Fatalf("bad: %v", *d2) + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + + // Key should have a lock delay. + expires := s.KVSLockDelay("/foo") + if expires.Before(time.Now().Add(30 * time.Millisecond)) { + t.Fatalf("Bad: %v", expires) + } +} + +func TestStateStore_Session_Invalidate_Key_Delete_Behavior(t *testing.T) { + s := testStateStore(t) + + // Set up our test environment. + if err := s.EnsureNode(3, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + session := &structs.Session{ + ID: testUUID(), + Node: "foo", + LockDelay: 50 * time.Millisecond, + Behavior: structs.SessionKeysDelete, + } + if err := s.SessionCreate(4, session); err != nil { + t.Fatalf("err: %v", err) + } + + // Lock a key with the session. + d := &structs.DirEntry{ + Key: "/bar", + Flags: 42, + Value: []byte("test"), + Session: session.ID, + } + ok, err := s.KVSLock(5, d) + if err != nil { + t.Fatalf("err: %v", err) + } + if !ok { + t.Fatalf("unexpected fail") + } + + // Delete the node and make sure the watches fire. + verifyWatch(t, s.getTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyWatch(t, s.GetKVSWatch("/b"), func() { + if err := s.DeleteNode(6, "foo"); err != nil { + t.Fatalf("err: %v", err) + } + }) + }) + }) + + // Lookup by ID, should be nil. + idx, s2, err := s.SessionGet(session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if s2 != nil { + t.Fatalf("session should be invalidated") + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + + // Key should be deleted. + idx, d2, err := s.KVSGet("/bar") + if err != nil { + t.Fatalf("err: %s", err) + } + if d2 != nil { + t.Fatalf("unexpected deleted key") + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + + // Key should have a lock delay. + expires := s.KVSLockDelay("/bar") + if expires.Before(time.Now().Add(30 * time.Millisecond)) { + t.Fatalf("Bad: %v", expires) + } +} + +func TestStateStore_Session_Invalidate_PreparedQuery_Delete(t *testing.T) { + s := testStateStore(t) + + // Set up our test environment. + testRegisterNode(t, s, 1, "foo") + testRegisterService(t, s, 2, "foo", "redis") + session := &structs.Session{ + ID: testUUID(), + Node: "foo", + } + if err := s.SessionCreate(3, session); err != nil { + t.Fatalf("err: %v", err) + } + query := &structs.PreparedQuery{ + ID: testUUID(), + Session: session.ID, + Service: structs.ServiceQuery{ + Service: "redis", + }, + } + if err := s.PreparedQuerySet(4, query); err != nil { + t.Fatalf("err: %s", err) + } + + // Invalidate the session and make sure the watches fire. + verifyWatch(t, s.getTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("prepared-queries"), func() { + if err := s.SessionDestroy(5, session.ID); err != nil { + t.Fatalf("err: %v", err) + } + }) + }) + + // Make sure the session is gone. + idx, s2, err := s.SessionGet(session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if s2 != nil { + t.Fatalf("session should be invalidated") + } + if idx != 5 { + t.Fatalf("bad index: %d", idx) + } + + // Make sure the query is gone and the index is updated. + idx, q2, err := s.PreparedQueryGet(query.ID) + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 5 { + t.Fatalf("bad index: %d", idx) + } + if q2 != nil { + t.Fatalf("bad: %v", q2) + } +} diff --git a/consul/state/state_store.go b/consul/state/state_store.go index b11ac91ee6..b9b6467cbc 100644 --- a/consul/state/state_store.go +++ b/consul/state/state_store.go @@ -3,13 +3,9 @@ package state import ( "errors" "fmt" - "strings" - "time" - "github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/consul/types" "github.com/hashicorp/go-memdb" - "github.com/hashicorp/serf/coordinate" ) var ( @@ -141,62 +137,6 @@ func (s *StateSnapshot) Close() { s.tx.Abort() } -// Nodes is used to pull the full list of nodes for use during snapshots. -func (s *StateSnapshot) Nodes() (memdb.ResultIterator, error) { - iter, err := s.tx.Get("nodes", "id") - if err != nil { - return nil, err - } - return iter, nil -} - -// Services is used to pull the full list of services for a given node for use -// during snapshots. -func (s *StateSnapshot) Services(node string) (memdb.ResultIterator, error) { - iter, err := s.tx.Get("services", "node", node) - if err != nil { - return nil, err - } - return iter, nil -} - -// Checks is used to pull the full list of checks for a given node for use -// during snapshots. -func (s *StateSnapshot) Checks(node string) (memdb.ResultIterator, error) { - iter, err := s.tx.Get("checks", "node", node) - if err != nil { - return nil, err - } - return iter, nil -} - -// Sessions is used to pull the full list of sessions for use during snapshots. -func (s *StateSnapshot) Sessions() (memdb.ResultIterator, error) { - iter, err := s.tx.Get("sessions", "id") - if err != nil { - return nil, err - } - return iter, nil -} - -// ACLs is used to pull all the ACLs from the snapshot. -func (s *StateSnapshot) ACLs() (memdb.ResultIterator, error) { - iter, err := s.tx.Get("acls", "id") - if err != nil { - return nil, err - } - return iter, nil -} - -// Coordinates is used to pull all the coordinates from the snapshot. -func (s *StateSnapshot) Coordinates() (memdb.ResultIterator, error) { - iter, err := s.tx.Get("coordinates", "id") - if err != nil { - return nil, err - } - return iter, nil -} - // Restore is used to efficiently manage restoring a large amount of data into // the state store. It works by doing all the restores inside of a single // transaction. @@ -223,77 +163,6 @@ func (s *StateRestore) Commit() { s.tx.Commit() } -// Registration is used to make sure a node, service, and check registration is -// performed within a single transaction to avoid race conditions on state -// updates. -func (s *StateRestore) Registration(idx uint64, req *structs.RegisterRequest) error { - if err := s.store.ensureRegistrationTxn(s.tx, idx, s.watches, req); err != nil { - return err - } - return nil -} - -// Session is used when restoring from a snapshot. For general inserts, use -// SessionCreate. -func (s *StateRestore) Session(sess *structs.Session) error { - // Insert the session. - if err := s.tx.Insert("sessions", sess); err != nil { - return fmt.Errorf("failed inserting session: %s", err) - } - - // Insert the check mappings. - for _, checkID := range sess.Checks { - mapping := &sessionCheck{ - Node: sess.Node, - CheckID: checkID, - Session: sess.ID, - } - if err := s.tx.Insert("session_checks", mapping); err != nil { - return fmt.Errorf("failed inserting session check mapping: %s", err) - } - } - - // Update the index. - if err := indexUpdateMaxTxn(s.tx, sess.ModifyIndex, "sessions"); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - - s.watches.Arm("sessions") - return nil -} - -// ACL is used when restoring from a snapshot. For general inserts, use ACLSet. -func (s *StateRestore) ACL(acl *structs.ACL) error { - if err := s.tx.Insert("acls", acl); err != nil { - return fmt.Errorf("failed restoring acl: %s", err) - } - - if err := indexUpdateMaxTxn(s.tx, acl.ModifyIndex, "acls"); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - - s.watches.Arm("acls") - return nil -} - -// Coordinates is used when restoring from a snapshot. For general inserts, use -// CoordinateBatchUpdate. We do less vetting of the updates here because they -// already got checked on the way in during a batch update. -func (s *StateRestore) Coordinates(idx uint64, updates structs.Coordinates) error { - for _, update := range updates { - if err := s.tx.Insert("coordinates", update); err != nil { - return fmt.Errorf("failed restoring coordinate: %s", err) - } - } - - if err := indexUpdateMaxTxn(s.tx, idx, "coordinates"); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - - s.watches.Arm("coordinates") - return nil -} - // maxIndex is a helper used to retrieve the highest known index // amongst a set of tables in the db. func (s *StateStore) maxIndex(tables ...string) uint64 { @@ -398,1633 +267,3 @@ func (s *StateStore) GetQueryWatch(method string) Watch { func (s *StateStore) GetKVSWatch(prefix string) Watch { return s.kvsWatch.NewPrefixWatch(prefix) } - -// EnsureRegistration is used to make sure a node, service, and check -// registration is performed within a single transaction to avoid race -// conditions on state updates. -func (s *StateStore) EnsureRegistration(idx uint64, req *structs.RegisterRequest) error { - tx := s.db.Txn(true) - defer tx.Abort() - - watches := NewDumbWatchManager(s.tableWatches) - if err := s.ensureRegistrationTxn(tx, idx, watches, req); err != nil { - return err - } - - tx.Defer(func() { watches.Notify() }) - tx.Commit() - return nil -} - -// ensureRegistrationTxn is used to make sure a node, service, and check -// registration is performed within a single transaction to avoid race -// conditions on state updates. -func (s *StateStore) ensureRegistrationTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, - req *structs.RegisterRequest) error { - // Add the node. - node := &structs.Node{ - Node: req.Node, - Address: req.Address, - TaggedAddresses: req.TaggedAddresses, - Meta: req.NodeMeta, - } - if err := s.ensureNodeTxn(tx, idx, watches, node); err != nil { - return fmt.Errorf("failed inserting node: %s", err) - } - - // Add the service, if any. - if req.Service != nil { - if err := s.ensureServiceTxn(tx, idx, watches, req.Node, req.Service); err != nil { - return fmt.Errorf("failed inserting service: %s", err) - } - } - - // TODO (slackpad) In Consul 0.8 ban checks that don't have the same - // node as the top-level registration. This is just weird to be able to - // update unrelated nodes' checks from in here. In 0.7.2 we banned this - // up in the ACL check since that's guarded behind an opt-in flag until - // Consul 0.8. - - // Add the checks, if any. - if req.Check != nil { - if err := s.ensureCheckTxn(tx, idx, watches, req.Check); err != nil { - return fmt.Errorf("failed inserting check: %s", err) - } - } - for _, check := range req.Checks { - if err := s.ensureCheckTxn(tx, idx, watches, check); err != nil { - return fmt.Errorf("failed inserting check: %s", err) - } - } - - return nil -} - -// EnsureNode is used to upsert node registration or modification. -func (s *StateStore) EnsureNode(idx uint64, node *structs.Node) error { - tx := s.db.Txn(true) - defer tx.Abort() - - // Call the node upsert - watches := NewDumbWatchManager(s.tableWatches) - if err := s.ensureNodeTxn(tx, idx, watches, node); err != nil { - return err - } - - tx.Defer(func() { watches.Notify() }) - tx.Commit() - return nil -} - -// ensureNodeTxn is the inner function called to actually create a node -// registration or modify an existing one in the state store. It allows -// passing in a memdb transaction so it may be part of a larger txn. -func (s *StateStore) ensureNodeTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, - node *structs.Node) error { - // Check for an existing node - existing, err := tx.First("nodes", "id", node.Node) - if err != nil { - return fmt.Errorf("node lookup failed: %s", err) - } - - // Get the indexes - if existing != nil { - node.CreateIndex = existing.(*structs.Node).CreateIndex - node.ModifyIndex = idx - } else { - node.CreateIndex = idx - node.ModifyIndex = idx - } - - // Insert the node and update the index - if err := tx.Insert("nodes", node); err != nil { - return fmt.Errorf("failed inserting node: %s", err) - } - if err := tx.Insert("index", &IndexEntry{"nodes", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - - watches.Arm("nodes") - return nil -} - -// GetNode is used to retrieve a node registration by node ID. -func (s *StateStore) GetNode(id string) (uint64, *structs.Node, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("GetNode")...) - - // Retrieve the node from the state store - node, err := tx.First("nodes", "id", id) - if err != nil { - return 0, nil, fmt.Errorf("node lookup failed: %s", err) - } - if node != nil { - return idx, node.(*structs.Node), nil - } - return idx, nil, nil -} - -// Nodes is used to return all of the known nodes. -func (s *StateStore) Nodes() (uint64, structs.Nodes, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("Nodes")...) - - // Retrieve all of the nodes - nodes, err := tx.Get("nodes", "id") - if err != nil { - return 0, nil, fmt.Errorf("failed nodes lookup: %s", err) - } - - // Create and return the nodes list. - var results structs.Nodes - for node := nodes.Next(); node != nil; node = nodes.Next() { - results = append(results, node.(*structs.Node)) - } - return idx, results, nil -} - -// NodesByMeta is used to return all nodes with the given meta key/value pair. -func (s *StateStore) NodesByMeta(filters map[string]string) (uint64, structs.Nodes, error) { - if len(filters) > 1 { - return 0, nil, fmt.Errorf("multiple meta filters not supported") - } - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("Nodes")...) - - // Retrieve all of the nodes - var args []interface{} - for key, value := range filters { - args = append(args, key, value) - } - nodes, err := tx.Get("nodes", "meta", args...) - if err != nil { - return 0, nil, fmt.Errorf("failed nodes lookup: %s", err) - } - - // Create and return the nodes list. - var results structs.Nodes - for node := nodes.Next(); node != nil; node = nodes.Next() { - results = append(results, node.(*structs.Node)) - } - return idx, results, nil -} - -// DeleteNode is used to delete a given node by its ID. -func (s *StateStore) DeleteNode(idx uint64, nodeID string) error { - tx := s.db.Txn(true) - defer tx.Abort() - - // Call the node deletion. - if err := s.deleteNodeTxn(tx, idx, nodeID); err != nil { - return err - } - - tx.Commit() - return nil -} - -// deleteNodeTxn is the inner method used for removing a node from -// the store within a given transaction. -func (s *StateStore) deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeID string) error { - // Look up the node. - node, err := tx.First("nodes", "id", nodeID) - if err != nil { - return fmt.Errorf("node lookup failed: %s", err) - } - if node == nil { - return nil - } - - // Use a watch manager since the inner functions can perform multiple - // ops per table. - watches := NewDumbWatchManager(s.tableWatches) - - // Delete all services associated with the node and update the service index. - services, err := tx.Get("services", "node", nodeID) - if err != nil { - return fmt.Errorf("failed service lookup: %s", err) - } - var sids []string - for service := services.Next(); service != nil; service = services.Next() { - sids = append(sids, service.(*structs.ServiceNode).ServiceID) - } - - // Do the delete in a separate loop so we don't trash the iterator. - for _, sid := range sids { - if err := s.deleteServiceTxn(tx, idx, watches, nodeID, sid); err != nil { - return err - } - } - - // Delete all checks associated with the node. This will invalidate - // sessions as necessary. - checks, err := tx.Get("checks", "node", nodeID) - if err != nil { - return fmt.Errorf("failed check lookup: %s", err) - } - var cids []types.CheckID - for check := checks.Next(); check != nil; check = checks.Next() { - cids = append(cids, check.(*structs.HealthCheck).CheckID) - } - - // Do the delete in a separate loop so we don't trash the iterator. - for _, cid := range cids { - if err := s.deleteCheckTxn(tx, idx, watches, nodeID, cid); err != nil { - return err - } - } - - // Delete any coordinate associated with this node. - coord, err := tx.First("coordinates", "id", nodeID) - if err != nil { - return fmt.Errorf("failed coordinate lookup: %s", err) - } - if coord != nil { - if err := tx.Delete("coordinates", coord); err != nil { - return fmt.Errorf("failed deleting coordinate: %s", err) - } - if err := tx.Insert("index", &IndexEntry{"coordinates", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - watches.Arm("coordinates") - } - - // Delete the node and update the index. - if err := tx.Delete("nodes", node); err != nil { - return fmt.Errorf("failed deleting node: %s", err) - } - if err := tx.Insert("index", &IndexEntry{"nodes", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - - // Invalidate any sessions for this node. - sessions, err := tx.Get("sessions", "node", nodeID) - if err != nil { - return fmt.Errorf("failed session lookup: %s", err) - } - var ids []string - for sess := sessions.Next(); sess != nil; sess = sessions.Next() { - ids = append(ids, sess.(*structs.Session).ID) - } - - // Do the delete in a separate loop so we don't trash the iterator. - for _, id := range ids { - if err := s.deleteSessionTxn(tx, idx, watches, id); err != nil { - return fmt.Errorf("failed session delete: %s", err) - } - } - - watches.Arm("nodes") - tx.Defer(func() { watches.Notify() }) - return nil -} - -// EnsureService is called to upsert creation of a given NodeService. -func (s *StateStore) EnsureService(idx uint64, node string, svc *structs.NodeService) error { - tx := s.db.Txn(true) - defer tx.Abort() - - // Call the service registration upsert - watches := NewDumbWatchManager(s.tableWatches) - if err := s.ensureServiceTxn(tx, idx, watches, node, svc); err != nil { - return err - } - - tx.Defer(func() { watches.Notify() }) - tx.Commit() - return nil -} - -// ensureServiceTxn is used to upsert a service registration within an -// existing memdb transaction. -func (s *StateStore) ensureServiceTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, - node string, svc *structs.NodeService) error { - // Check for existing service - existing, err := tx.First("services", "id", node, svc.ID) - if err != nil { - return fmt.Errorf("failed service lookup: %s", err) - } - - // Create the service node entry and populate the indexes. Note that - // conversion doesn't populate any of the node-specific information - // (Address and TaggedAddresses). That's always populated when we read - // from the state store. - entry := svc.ToServiceNode(node) - if existing != nil { - entry.CreateIndex = existing.(*structs.ServiceNode).CreateIndex - entry.ModifyIndex = idx - } else { - entry.CreateIndex = idx - entry.ModifyIndex = idx - } - - // Get the node - n, err := tx.First("nodes", "id", node) - if err != nil { - return fmt.Errorf("failed node lookup: %s", err) - } - if n == nil { - return ErrMissingNode - } - - // Insert the service and update the index - if err := tx.Insert("services", entry); err != nil { - return fmt.Errorf("failed inserting service: %s", err) - } - if err := tx.Insert("index", &IndexEntry{"services", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - - watches.Arm("services") - return nil -} - -// Services returns all services along with a list of associated tags. -func (s *StateStore) Services() (uint64, structs.Services, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("Services")...) - - // List all the services. - services, err := tx.Get("services", "id") - if err != nil { - return 0, nil, fmt.Errorf("failed querying services: %s", err) - } - - // Rip through the services and enumerate them and their unique set of - // tags. - unique := make(map[string]map[string]struct{}) - for service := services.Next(); service != nil; service = services.Next() { - svc := service.(*structs.ServiceNode) - tags, ok := unique[svc.ServiceName] - if !ok { - unique[svc.ServiceName] = make(map[string]struct{}) - tags = unique[svc.ServiceName] - } - for _, tag := range svc.ServiceTags { - tags[tag] = struct{}{} - } - } - - // Generate the output structure. - var results = make(structs.Services) - for service, tags := range unique { - results[service] = make([]string, 0) - for tag, _ := range tags { - results[service] = append(results[service], tag) - } - } - return idx, results, nil -} - -// Services returns all services, filtered by the given node metadata. -func (s *StateStore) ServicesByNodeMeta(filters map[string]string) (uint64, structs.Services, error) { - if len(filters) > 1 { - return 0, nil, fmt.Errorf("multiple meta filters not supported") - } - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("ServiceNodes")...) - - // Retrieve all of the nodes with the meta k/v pair - var args []interface{} - for key, value := range filters { - args = append(args, key, value) - } - nodes, err := tx.Get("nodes", "meta", args...) - if err != nil { - return 0, nil, fmt.Errorf("failed nodes lookup: %s", err) - } - - // Populate the services map - unique := make(map[string]map[string]struct{}) - for node := nodes.Next(); node != nil; node = nodes.Next() { - n := node.(*structs.Node) - // List all the services on the node - services, err := tx.Get("services", "node", n.Node) - if err != nil { - return 0, nil, fmt.Errorf("failed querying services: %s", err) - } - - // Rip through the services and enumerate them and their unique set of - // tags. - for service := services.Next(); service != nil; service = services.Next() { - svc := service.(*structs.ServiceNode) - tags, ok := unique[svc.ServiceName] - if !ok { - unique[svc.ServiceName] = make(map[string]struct{}) - tags = unique[svc.ServiceName] - } - for _, tag := range svc.ServiceTags { - tags[tag] = struct{}{} - } - } - } - - // Generate the output structure. - var results = make(structs.Services) - for service, tags := range unique { - results[service] = make([]string, 0) - for tag, _ := range tags { - results[service] = append(results[service], tag) - } - } - return idx, results, nil -} - -// ServiceNodes returns the nodes associated with a given service name. -func (s *StateStore) ServiceNodes(serviceName string) (uint64, structs.ServiceNodes, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("ServiceNodes")...) - - // List all the services. - services, err := tx.Get("services", "service", serviceName) - if err != nil { - return 0, nil, fmt.Errorf("failed service lookup: %s", err) - } - var results structs.ServiceNodes - for service := services.Next(); service != nil; service = services.Next() { - results = append(results, service.(*structs.ServiceNode)) - } - - // Fill in the address details. - results, err = s.parseServiceNodes(tx, results) - if err != nil { - return 0, nil, fmt.Errorf("failed parsing service nodes: %s", err) - } - return idx, results, nil -} - -// ServiceTagNodes returns the nodes associated with a given service, filtering -// out services that don't contain the given tag. -func (s *StateStore) ServiceTagNodes(service, tag string) (uint64, structs.ServiceNodes, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("ServiceNodes")...) - - // List all the services. - services, err := tx.Get("services", "service", service) - if err != nil { - return 0, nil, fmt.Errorf("failed service lookup: %s", err) - } - - // Gather all the services and apply the tag filter. - var results structs.ServiceNodes - for service := services.Next(); service != nil; service = services.Next() { - svc := service.(*structs.ServiceNode) - if !serviceTagFilter(svc, tag) { - results = append(results, svc) - } - } - - // Fill in the address details. - results, err = s.parseServiceNodes(tx, results) - if err != nil { - return 0, nil, fmt.Errorf("failed parsing service nodes: %s", err) - } - return idx, results, nil -} - -// serviceTagFilter returns true (should filter) if the given service node -// doesn't contain the given tag. -func serviceTagFilter(sn *structs.ServiceNode, tag string) bool { - tag = strings.ToLower(tag) - - // Look for the lower cased version of the tag. - for _, t := range sn.ServiceTags { - if strings.ToLower(t) == tag { - return false - } - } - - // If we didn't hit the tag above then we should filter. - return true -} - -// parseServiceNodes iterates over a services query and fills in the node details, -// returning a ServiceNodes slice. -func (s *StateStore) parseServiceNodes(tx *memdb.Txn, services structs.ServiceNodes) (structs.ServiceNodes, error) { - var results structs.ServiceNodes - for _, sn := range services { - // Note that we have to clone here because we don't want to - // modify the node-related fields on the object in the database, - // which is what we are referencing. - s := sn.PartialClone() - - // Grab the corresponding node record. - n, err := tx.First("nodes", "id", sn.Node) - if err != nil { - return nil, fmt.Errorf("failed node lookup: %s", err) - } - - // Populate the node-related fields. The tagged addresses may be - // used by agents to perform address translation if they are - // configured to do that. - node := n.(*structs.Node) - s.Address = node.Address - s.TaggedAddresses = node.TaggedAddresses - s.NodeMeta = node.Meta - - results = append(results, s) - } - return results, nil -} - -// NodeService is used to retrieve a specific service associated with the given -// node. -func (s *StateStore) NodeService(nodeID string, serviceID string) (uint64, *structs.NodeService, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("NodeService")...) - - // Query the service - service, err := tx.First("services", "id", nodeID, serviceID) - if err != nil { - return 0, nil, fmt.Errorf("failed querying service for node %q: %s", nodeID, err) - } - - if service != nil { - return idx, service.(*structs.ServiceNode).ToNodeService(), nil - } else { - return idx, nil, nil - } -} - -// NodeServices is used to query service registrations by node ID. -func (s *StateStore) NodeServices(nodeID string) (uint64, *structs.NodeServices, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("NodeServices")...) - - // Query the node - n, err := tx.First("nodes", "id", nodeID) - if err != nil { - return 0, nil, fmt.Errorf("node lookup failed: %s", err) - } - if n == nil { - return 0, nil, nil - } - node := n.(*structs.Node) - - // Read all of the services - services, err := tx.Get("services", "node", nodeID) - if err != nil { - return 0, nil, fmt.Errorf("failed querying services for node %q: %s", nodeID, err) - } - - // Initialize the node services struct - ns := &structs.NodeServices{ - Node: node, - Services: make(map[string]*structs.NodeService), - } - - // Add all of the services to the map. - for service := services.Next(); service != nil; service = services.Next() { - svc := service.(*structs.ServiceNode).ToNodeService() - ns.Services[svc.ID] = svc - } - - return idx, ns, nil -} - -// DeleteService is used to delete a given service associated with a node. -func (s *StateStore) DeleteService(idx uint64, nodeID, serviceID string) error { - tx := s.db.Txn(true) - defer tx.Abort() - - // Call the service deletion - watches := NewDumbWatchManager(s.tableWatches) - if err := s.deleteServiceTxn(tx, idx, watches, nodeID, serviceID); err != nil { - return err - } - - tx.Defer(func() { watches.Notify() }) - tx.Commit() - return nil -} - -// deleteServiceTxn is the inner method called to remove a service -// registration within an existing transaction. -func (s *StateStore) deleteServiceTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, nodeID, serviceID string) error { - // Look up the service. - service, err := tx.First("services", "id", nodeID, serviceID) - if err != nil { - return fmt.Errorf("failed service lookup: %s", err) - } - if service == nil { - return nil - } - - // Delete any checks associated with the service. This will invalidate - // sessions as necessary. - checks, err := tx.Get("checks", "node_service", nodeID, serviceID) - if err != nil { - return fmt.Errorf("failed service check lookup: %s", err) - } - var cids []types.CheckID - for check := checks.Next(); check != nil; check = checks.Next() { - cids = append(cids, check.(*structs.HealthCheck).CheckID) - } - - // Do the delete in a separate loop so we don't trash the iterator. - for _, cid := range cids { - if err := s.deleteCheckTxn(tx, idx, watches, nodeID, cid); err != nil { - return err - } - } - - // Update the index. - if err := tx.Insert("index", &IndexEntry{"checks", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - - // Delete the service and update the index - if err := tx.Delete("services", service); err != nil { - return fmt.Errorf("failed deleting service: %s", err) - } - if err := tx.Insert("index", &IndexEntry{"services", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - - watches.Arm("services") - return nil -} - -// EnsureCheck is used to store a check registration in the db. -func (s *StateStore) EnsureCheck(idx uint64, hc *structs.HealthCheck) error { - tx := s.db.Txn(true) - defer tx.Abort() - - // Call the check registration - watches := NewDumbWatchManager(s.tableWatches) - if err := s.ensureCheckTxn(tx, idx, watches, hc); err != nil { - return err - } - - tx.Defer(func() { watches.Notify() }) - tx.Commit() - return nil -} - -// ensureCheckTransaction is used as the inner method to handle inserting -// a health check into the state store. It ensures safety against inserting -// checks with no matching node or service. -func (s *StateStore) ensureCheckTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, - hc *structs.HealthCheck) error { - // Check if we have an existing health check - existing, err := tx.First("checks", "id", hc.Node, string(hc.CheckID)) - if err != nil { - return fmt.Errorf("failed health check lookup: %s", err) - } - - // Set the indexes - if existing != nil { - hc.CreateIndex = existing.(*structs.HealthCheck).CreateIndex - hc.ModifyIndex = idx - } else { - hc.CreateIndex = idx - hc.ModifyIndex = idx - } - - // Use the default check status if none was provided - if hc.Status == "" { - hc.Status = structs.HealthCritical - } - - // Get the node - node, err := tx.First("nodes", "id", hc.Node) - if err != nil { - return fmt.Errorf("failed node lookup: %s", err) - } - if node == nil { - return ErrMissingNode - } - - // If the check is associated with a service, check that we have - // a registration for the service. - if hc.ServiceID != "" { - service, err := tx.First("services", "id", hc.Node, hc.ServiceID) - if err != nil { - return fmt.Errorf("failed service lookup: %s", err) - } - if service == nil { - return ErrMissingService - } - - // Copy in the service name - hc.ServiceName = service.(*structs.ServiceNode).ServiceName - } - - // Delete any sessions for this check if the health is critical. - if hc.Status == structs.HealthCritical { - mappings, err := tx.Get("session_checks", "node_check", hc.Node, string(hc.CheckID)) - if err != nil { - return fmt.Errorf("failed session checks lookup: %s", err) - } - - var ids []string - for mapping := mappings.Next(); mapping != nil; mapping = mappings.Next() { - ids = append(ids, mapping.(*sessionCheck).Session) - } - - // Delete the session in a separate loop so we don't trash the - // iterator. - watches := NewDumbWatchManager(s.tableWatches) - for _, id := range ids { - if err := s.deleteSessionTxn(tx, idx, watches, id); err != nil { - return fmt.Errorf("failed deleting session: %s", err) - } - } - tx.Defer(func() { watches.Notify() }) - } - - // Persist the check registration in the db. - if err := tx.Insert("checks", hc); err != nil { - return fmt.Errorf("failed inserting check: %s", err) - } - if err := tx.Insert("index", &IndexEntry{"checks", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - - watches.Arm("checks") - return nil -} - -// NodeCheck is used to retrieve a specific check associated with the given -// node. -func (s *StateStore) NodeCheck(nodeID string, checkID types.CheckID) (uint64, *structs.HealthCheck, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("NodeCheck")...) - - // Return the check. - check, err := tx.First("checks", "id", nodeID, string(checkID)) - if err != nil { - return 0, nil, fmt.Errorf("failed check lookup: %s", err) - } - if check != nil { - return idx, check.(*structs.HealthCheck), nil - } else { - return idx, nil, nil - } -} - -// NodeChecks is used to retrieve checks associated with the -// given node from the state store. -func (s *StateStore) NodeChecks(nodeID string) (uint64, structs.HealthChecks, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("NodeChecks")...) - - // Return the checks. - checks, err := tx.Get("checks", "node", nodeID) - if err != nil { - return 0, nil, fmt.Errorf("failed check lookup: %s", err) - } - return s.parseChecks(idx, checks) -} - -// ServiceChecks is used to get all checks associated with a -// given service ID. The query is performed against a service -// _name_ instead of a service ID. -func (s *StateStore) ServiceChecks(serviceName string) (uint64, structs.HealthChecks, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("ServiceChecks")...) - - // Return the checks. - checks, err := tx.Get("checks", "service", serviceName) - if err != nil { - return 0, nil, fmt.Errorf("failed check lookup: %s", err) - } - return s.parseChecks(idx, checks) -} - -// ChecksInState is used to query the state store for all checks -// which are in the provided state. -func (s *StateStore) ChecksInState(state string) (uint64, structs.HealthChecks, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("ChecksInState")...) - - // Query all checks if HealthAny is passed - if state == structs.HealthAny { - checks, err := tx.Get("checks", "status") - if err != nil { - return 0, nil, fmt.Errorf("failed check lookup: %s", err) - } - return s.parseChecks(idx, checks) - } - - // Any other state we need to query for explicitly - checks, err := tx.Get("checks", "status", state) - if err != nil { - return 0, nil, fmt.Errorf("failed check lookup: %s", err) - } - return s.parseChecks(idx, checks) -} - -// parseChecks is a helper function used to deduplicate some -// repetitive code for returning health checks. -func (s *StateStore) parseChecks(idx uint64, iter memdb.ResultIterator) (uint64, structs.HealthChecks, error) { - // Gather the health checks and return them properly type casted. - var results structs.HealthChecks - for check := iter.Next(); check != nil; check = iter.Next() { - results = append(results, check.(*structs.HealthCheck)) - } - return idx, results, nil -} - -// DeleteCheck is used to delete a health check registration. -func (s *StateStore) DeleteCheck(idx uint64, node string, checkID types.CheckID) error { - tx := s.db.Txn(true) - defer tx.Abort() - - // Call the check deletion - watches := NewDumbWatchManager(s.tableWatches) - if err := s.deleteCheckTxn(tx, idx, watches, node, checkID); err != nil { - return err - } - - tx.Defer(func() { watches.Notify() }) - tx.Commit() - return nil -} - -// deleteCheckTxn is the inner method used to call a health -// check deletion within an existing transaction. -func (s *StateStore) deleteCheckTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, node string, checkID types.CheckID) error { - // Try to retrieve the existing health check. - hc, err := tx.First("checks", "id", node, string(checkID)) - if err != nil { - return fmt.Errorf("check lookup failed: %s", err) - } - if hc == nil { - return nil - } - - // Delete the check from the DB and update the index. - if err := tx.Delete("checks", hc); err != nil { - return fmt.Errorf("failed removing check: %s", err) - } - if err := tx.Insert("index", &IndexEntry{"checks", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - - // Delete any sessions for this check. - mappings, err := tx.Get("session_checks", "node_check", node, string(checkID)) - if err != nil { - return fmt.Errorf("failed session checks lookup: %s", err) - } - var ids []string - for mapping := mappings.Next(); mapping != nil; mapping = mappings.Next() { - ids = append(ids, mapping.(*sessionCheck).Session) - } - - // Do the delete in a separate loop so we don't trash the iterator. - for _, id := range ids { - if err := s.deleteSessionTxn(tx, idx, watches, id); err != nil { - return fmt.Errorf("failed deleting session: %s", err) - } - } - - watches.Arm("checks") - return nil -} - -// CheckServiceNodes is used to query all nodes and checks for a given service -// The results are compounded into a CheckServiceNodes, and the index returned -// is the maximum index observed over any node, check, or service in the result -// set. -func (s *StateStore) CheckServiceNodes(serviceName string) (uint64, structs.CheckServiceNodes, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("CheckServiceNodes")...) - - // Query the state store for the service. - services, err := tx.Get("services", "service", serviceName) - if err != nil { - return 0, nil, fmt.Errorf("failed service lookup: %s", err) - } - - // Return the results. - var results structs.ServiceNodes - for service := services.Next(); service != nil; service = services.Next() { - results = append(results, service.(*structs.ServiceNode)) - } - return s.parseCheckServiceNodes(tx, idx, results, err) -} - -// CheckServiceTagNodes is used to query all nodes and checks for a given -// service, filtering out services that don't contain the given tag. The results -// are compounded into a CheckServiceNodes, and the index returned is the maximum -// index observed over any node, check, or service in the result set. -func (s *StateStore) CheckServiceTagNodes(serviceName, tag string) (uint64, structs.CheckServiceNodes, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("CheckServiceNodes")...) - - // Query the state store for the service. - services, err := tx.Get("services", "service", serviceName) - if err != nil { - return 0, nil, fmt.Errorf("failed service lookup: %s", err) - } - - // Return the results, filtering by tag. - var results structs.ServiceNodes - for service := services.Next(); service != nil; service = services.Next() { - svc := service.(*structs.ServiceNode) - if !serviceTagFilter(svc, tag) { - results = append(results, svc) - } - } - return s.parseCheckServiceNodes(tx, idx, results, err) -} - -// parseCheckServiceNodes is used to parse through a given set of services, -// and query for an associated node and a set of checks. This is the inner -// method used to return a rich set of results from a more simple query. -func (s *StateStore) parseCheckServiceNodes( - tx *memdb.Txn, idx uint64, services structs.ServiceNodes, - err error) (uint64, structs.CheckServiceNodes, error) { - if err != nil { - return 0, nil, err - } - - // Special-case the zero return value to nil, since this ends up in - // external APIs. - if len(services) == 0 { - return idx, nil, nil - } - - results := make(structs.CheckServiceNodes, 0, len(services)) - for _, sn := range services { - // Retrieve the node. - n, err := tx.First("nodes", "id", sn.Node) - if err != nil { - return 0, nil, fmt.Errorf("failed node lookup: %s", err) - } - if n == nil { - return 0, nil, ErrMissingNode - } - node := n.(*structs.Node) - - // We need to return the checks specific to the given service - // as well as the node itself. Unfortunately, memdb won't let - // us use the index to do the latter query so we have to pull - // them all and filter. - var checks structs.HealthChecks - iter, err := tx.Get("checks", "node", sn.Node) - if err != nil { - return 0, nil, err - } - for check := iter.Next(); check != nil; check = iter.Next() { - hc := check.(*structs.HealthCheck) - if hc.ServiceID == "" || hc.ServiceID == sn.ServiceID { - checks = append(checks, hc) - } - } - - // Append to the results. - results = append(results, structs.CheckServiceNode{ - Node: node, - Service: sn.ToNodeService(), - Checks: checks, - }) - } - - return idx, results, nil -} - -// NodeInfo is used to generate a dump of a single node. The dump includes -// all services and checks which are registered against the node. -func (s *StateStore) NodeInfo(node string) (uint64, structs.NodeDump, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("NodeInfo")...) - - // Query the node by the passed node - nodes, err := tx.Get("nodes", "id", node) - if err != nil { - return 0, nil, fmt.Errorf("failed node lookup: %s", err) - } - return s.parseNodes(tx, idx, nodes) -} - -// NodeDump is used to generate a dump of all nodes. This call is expensive -// as it has to query every node, service, and check. The response can also -// be quite large since there is currently no filtering applied. -func (s *StateStore) NodeDump() (uint64, structs.NodeDump, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("NodeDump")...) - - // Fetch all of the registered nodes - nodes, err := tx.Get("nodes", "id") - if err != nil { - return 0, nil, fmt.Errorf("failed node lookup: %s", err) - } - return s.parseNodes(tx, idx, nodes) -} - -// parseNodes takes an iterator over a set of nodes and returns a struct -// containing the nodes along with all of their associated services -// and/or health checks. -func (s *StateStore) parseNodes(tx *memdb.Txn, idx uint64, - iter memdb.ResultIterator) (uint64, structs.NodeDump, error) { - - var results structs.NodeDump - for n := iter.Next(); n != nil; n = iter.Next() { - node := n.(*structs.Node) - - // Create the wrapped node - dump := &structs.NodeInfo{ - Node: node.Node, - Address: node.Address, - TaggedAddresses: node.TaggedAddresses, - Meta: node.Meta, - } - - // Query the node services - services, err := tx.Get("services", "node", node.Node) - if err != nil { - return 0, nil, fmt.Errorf("failed services lookup: %s", err) - } - for service := services.Next(); service != nil; service = services.Next() { - ns := service.(*structs.ServiceNode).ToNodeService() - dump.Services = append(dump.Services, ns) - } - - // Query the node checks - checks, err := tx.Get("checks", "node", node.Node) - if err != nil { - return 0, nil, fmt.Errorf("failed node lookup: %s", err) - } - for check := checks.Next(); check != nil; check = checks.Next() { - hc := check.(*structs.HealthCheck) - dump.Checks = append(dump.Checks, hc) - } - - // Add the result to the slice - results = append(results, dump) - } - return idx, results, nil -} - -// SessionCreate is used to register a new session in the state store. -func (s *StateStore) SessionCreate(idx uint64, sess *structs.Session) error { - tx := s.db.Txn(true) - defer tx.Abort() - - // This code is technically able to (incorrectly) update an existing - // session but we never do that in practice. The upstream endpoint code - // always adds a unique ID when doing a create operation so we never hit - // an existing session again. It isn't worth the overhead to verify - // that here, but it's worth noting that we should never do this in the - // future. - - // Call the session creation - if err := s.sessionCreateTxn(tx, idx, sess); err != nil { - return err - } - - tx.Commit() - return nil -} - -// sessionCreateTxn is the inner method used for creating session entries in -// an open transaction. Any health checks registered with the session will be -// checked for failing status. Returns any error encountered. -func (s *StateStore) sessionCreateTxn(tx *memdb.Txn, idx uint64, sess *structs.Session) error { - // Check that we have a session ID - if sess.ID == "" { - return ErrMissingSessionID - } - - // Verify the session behavior is valid - switch sess.Behavior { - case "": - // Release by default to preserve backwards compatibility - sess.Behavior = structs.SessionKeysRelease - case structs.SessionKeysRelease: - case structs.SessionKeysDelete: - default: - return fmt.Errorf("Invalid session behavior: %s", sess.Behavior) - } - - // Assign the indexes. ModifyIndex likely will not be used but - // we set it here anyways for sanity. - sess.CreateIndex = idx - sess.ModifyIndex = idx - - // Check that the node exists - node, err := tx.First("nodes", "id", sess.Node) - if err != nil { - return fmt.Errorf("failed node lookup: %s", err) - } - if node == nil { - return ErrMissingNode - } - - // Go over the session checks and ensure they exist. - for _, checkID := range sess.Checks { - check, err := tx.First("checks", "id", sess.Node, string(checkID)) - if err != nil { - return fmt.Errorf("failed check lookup: %s", err) - } - if check == nil { - return fmt.Errorf("Missing check '%s' registration", checkID) - } - - // Check that the check is not in critical state - status := check.(*structs.HealthCheck).Status - if status == structs.HealthCritical { - return fmt.Errorf("Check '%s' is in %s state", checkID, status) - } - } - - // Insert the session - if err := tx.Insert("sessions", sess); err != nil { - return fmt.Errorf("failed inserting session: %s", err) - } - - // Insert the check mappings - for _, checkID := range sess.Checks { - mapping := &sessionCheck{ - Node: sess.Node, - CheckID: checkID, - Session: sess.ID, - } - if err := tx.Insert("session_checks", mapping); err != nil { - return fmt.Errorf("failed inserting session check mapping: %s", err) - } - } - - // Update the index - if err := tx.Insert("index", &IndexEntry{"sessions", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - - tx.Defer(func() { s.tableWatches["sessions"].Notify() }) - return nil -} - -// SessionGet is used to retrieve an active session from the state store. -func (s *StateStore) SessionGet(sessionID string) (uint64, *structs.Session, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("SessionGet")...) - - // Look up the session by its ID - session, err := tx.First("sessions", "id", sessionID) - if err != nil { - return 0, nil, fmt.Errorf("failed session lookup: %s", err) - } - if session != nil { - return idx, session.(*structs.Session), nil - } - return idx, nil, nil -} - -// SessionList returns a slice containing all of the active sessions. -func (s *StateStore) SessionList() (uint64, structs.Sessions, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("SessionList")...) - - // Query all of the active sessions. - sessions, err := tx.Get("sessions", "id") - if err != nil { - return 0, nil, fmt.Errorf("failed session lookup: %s", err) - } - - // Go over the sessions and create a slice of them. - var result structs.Sessions - for session := sessions.Next(); session != nil; session = sessions.Next() { - result = append(result, session.(*structs.Session)) - } - return idx, result, nil -} - -// NodeSessions returns a set of active sessions associated -// with the given node ID. The returned index is the highest -// index seen from the result set. -func (s *StateStore) NodeSessions(nodeID string) (uint64, structs.Sessions, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("NodeSessions")...) - - // Get all of the sessions which belong to the node - sessions, err := tx.Get("sessions", "node", nodeID) - if err != nil { - return 0, nil, fmt.Errorf("failed session lookup: %s", err) - } - - // Go over all of the sessions and return them as a slice - var result structs.Sessions - for session := sessions.Next(); session != nil; session = sessions.Next() { - result = append(result, session.(*structs.Session)) - } - return idx, result, nil -} - -// SessionDestroy is used to remove an active session. This will -// implicitly invalidate the session and invoke the specified -// session destroy behavior. -func (s *StateStore) SessionDestroy(idx uint64, sessionID string) error { - tx := s.db.Txn(true) - defer tx.Abort() - - // Call the session deletion. - watches := NewDumbWatchManager(s.tableWatches) - if err := s.deleteSessionTxn(tx, idx, watches, sessionID); err != nil { - return err - } - - tx.Defer(func() { watches.Notify() }) - tx.Commit() - return nil -} - -// deleteSessionTxn is the inner method, which is used to do the actual -// session deletion and handle session invalidation, watch triggers, etc. -func (s *StateStore) deleteSessionTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, sessionID string) error { - // Look up the session. - sess, err := tx.First("sessions", "id", sessionID) - if err != nil { - return fmt.Errorf("failed session lookup: %s", err) - } - if sess == nil { - return nil - } - - // Delete the session and write the new index. - if err := tx.Delete("sessions", sess); err != nil { - return fmt.Errorf("failed deleting session: %s", err) - } - if err := tx.Insert("index", &IndexEntry{"sessions", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - - // Enforce the max lock delay. - session := sess.(*structs.Session) - delay := session.LockDelay - if delay > structs.MaxLockDelay { - delay = structs.MaxLockDelay - } - - // Snag the current now time so that all the expirations get calculated - // the same way. - now := time.Now() - - // Get an iterator over all of the keys with the given session. - entries, err := tx.Get("kvs", "session", sessionID) - if err != nil { - return fmt.Errorf("failed kvs lookup: %s", err) - } - var kvs []interface{} - for entry := entries.Next(); entry != nil; entry = entries.Next() { - kvs = append(kvs, entry) - } - - // Invalidate any held locks. - switch session.Behavior { - case structs.SessionKeysRelease: - for _, obj := range kvs { - // Note that we clone here since we are modifying the - // returned object and want to make sure our set op - // respects the transaction we are in. - e := obj.(*structs.DirEntry).Clone() - e.Session = "" - if err := s.kvsSetTxn(tx, idx, e, true); err != nil { - return fmt.Errorf("failed kvs update: %s", err) - } - - // Apply the lock delay if present. - if delay > 0 { - s.lockDelay.SetExpiration(e.Key, now, delay) - } - } - case structs.SessionKeysDelete: - for _, obj := range kvs { - e := obj.(*structs.DirEntry) - if err := s.kvsDeleteTxn(tx, idx, e.Key); err != nil { - return fmt.Errorf("failed kvs delete: %s", err) - } - - // Apply the lock delay if present. - if delay > 0 { - s.lockDelay.SetExpiration(e.Key, now, delay) - } - } - default: - return fmt.Errorf("unknown session behavior %#v", session.Behavior) - } - - // Delete any check mappings. - mappings, err := tx.Get("session_checks", "session", sessionID) - if err != nil { - return fmt.Errorf("failed session checks lookup: %s", err) - } - { - var objs []interface{} - for mapping := mappings.Next(); mapping != nil; mapping = mappings.Next() { - objs = append(objs, mapping) - } - - // Do the delete in a separate loop so we don't trash the iterator. - for _, obj := range objs { - if err := tx.Delete("session_checks", obj); err != nil { - return fmt.Errorf("failed deleting session check: %s", err) - } - } - } - - // Delete any prepared queries. - queries, err := tx.Get("prepared-queries", "session", sessionID) - if err != nil { - return fmt.Errorf("failed prepared query lookup: %s", err) - } - { - var ids []string - for wrapped := queries.Next(); wrapped != nil; wrapped = queries.Next() { - ids = append(ids, toPreparedQuery(wrapped).ID) - } - - // Do the delete in a separate loop so we don't trash the iterator. - for _, id := range ids { - if err := s.preparedQueryDeleteTxn(tx, idx, watches, id); err != nil { - return fmt.Errorf("failed prepared query delete: %s", err) - } - } - } - - watches.Arm("sessions") - return nil -} - -// ACLSet is used to insert an ACL rule into the state store. -func (s *StateStore) ACLSet(idx uint64, acl *structs.ACL) error { - tx := s.db.Txn(true) - defer tx.Abort() - - // Call set on the ACL - if err := s.aclSetTxn(tx, idx, acl); err != nil { - return err - } - - tx.Commit() - return nil -} - -// aclSetTxn is the inner method used to insert an ACL rule with the -// proper indexes into the state store. -func (s *StateStore) aclSetTxn(tx *memdb.Txn, idx uint64, acl *structs.ACL) error { - // Check that the ID is set - if acl.ID == "" { - return ErrMissingACLID - } - - // Check for an existing ACL - existing, err := tx.First("acls", "id", acl.ID) - if err != nil { - return fmt.Errorf("failed acl lookup: %s", err) - } - - // Set the indexes - if existing != nil { - acl.CreateIndex = existing.(*structs.ACL).CreateIndex - acl.ModifyIndex = idx - } else { - acl.CreateIndex = idx - acl.ModifyIndex = idx - } - - // Insert the ACL - if err := tx.Insert("acls", acl); err != nil { - return fmt.Errorf("failed inserting acl: %s", err) - } - if err := tx.Insert("index", &IndexEntry{"acls", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - - tx.Defer(func() { s.tableWatches["acls"].Notify() }) - return nil -} - -// ACLGet is used to look up an existing ACL by ID. -func (s *StateStore) ACLGet(aclID string) (uint64, *structs.ACL, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("ACLGet")...) - - // Query for the existing ACL - acl, err := tx.First("acls", "id", aclID) - if err != nil { - return 0, nil, fmt.Errorf("failed acl lookup: %s", err) - } - if acl != nil { - return idx, acl.(*structs.ACL), nil - } - return idx, nil, nil -} - -// ACLList is used to list out all of the ACLs in the state store. -func (s *StateStore) ACLList() (uint64, structs.ACLs, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("ACLList")...) - - // Return the ACLs. - acls, err := s.aclListTxn(tx) - if err != nil { - return 0, nil, fmt.Errorf("failed acl lookup: %s", err) - } - return idx, acls, nil -} - -// aclListTxn is used to list out all of the ACLs in the state store. This is a -// function vs. a method so it can be called from the snapshotter. -func (s *StateStore) aclListTxn(tx *memdb.Txn) (structs.ACLs, error) { - // Query all of the ACLs in the state store - acls, err := tx.Get("acls", "id") - if err != nil { - return nil, fmt.Errorf("failed acl lookup: %s", err) - } - - // Go over all of the ACLs and build the response - var result structs.ACLs - for acl := acls.Next(); acl != nil; acl = acls.Next() { - a := acl.(*structs.ACL) - result = append(result, a) - } - return result, nil -} - -// ACLDelete is used to remove an existing ACL from the state store. If -// the ACL does not exist this is a no-op and no error is returned. -func (s *StateStore) ACLDelete(idx uint64, aclID string) error { - tx := s.db.Txn(true) - defer tx.Abort() - - // Call the ACL delete - if err := s.aclDeleteTxn(tx, idx, aclID); err != nil { - return err - } - - tx.Commit() - return nil -} - -// aclDeleteTxn is used to delete an ACL from the state store within -// an existing transaction. -func (s *StateStore) aclDeleteTxn(tx *memdb.Txn, idx uint64, aclID string) error { - // Look up the existing ACL - acl, err := tx.First("acls", "id", aclID) - if err != nil { - return fmt.Errorf("failed acl lookup: %s", err) - } - if acl == nil { - return nil - } - - // Delete the ACL from the state store and update indexes - if err := tx.Delete("acls", acl); err != nil { - return fmt.Errorf("failed deleting acl: %s", err) - } - if err := tx.Insert("index", &IndexEntry{"acls", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - - tx.Defer(func() { s.tableWatches["acls"].Notify() }) - return nil -} - -// CoordinateGetRaw queries for the coordinate of the given node. This is an -// unusual state store method because it just returns the raw coordinate or -// nil, none of the Raft or node information is returned. This hits the 90% -// internal-to-Consul use case for this data, and this isn't exposed via an -// endpoint, so it doesn't matter that the Raft info isn't available. -func (s *StateStore) CoordinateGetRaw(node string) (*coordinate.Coordinate, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Pull the full coordinate entry. - coord, err := tx.First("coordinates", "id", node) - if err != nil { - return nil, fmt.Errorf("failed coordinate lookup: %s", err) - } - - // Pick out just the raw coordinate. - if coord != nil { - return coord.(*structs.Coordinate).Coord, nil - } - return nil, nil -} - -// Coordinates queries for all nodes with coordinates. -func (s *StateStore) Coordinates() (uint64, structs.Coordinates, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table index. - idx := maxIndexTxn(tx, s.getWatchTables("Coordinates")...) - - // Pull all the coordinates. - coords, err := tx.Get("coordinates", "id") - if err != nil { - return 0, nil, fmt.Errorf("failed coordinate lookup: %s", err) - } - var results structs.Coordinates - for coord := coords.Next(); coord != nil; coord = coords.Next() { - results = append(results, coord.(*structs.Coordinate)) - } - return idx, results, nil -} - -// CoordinateBatchUpdate processes a batch of coordinate updates and applies -// them in a single transaction. -func (s *StateStore) CoordinateBatchUpdate(idx uint64, updates structs.Coordinates) error { - tx := s.db.Txn(true) - defer tx.Abort() - - // Upsert the coordinates. - for _, update := range updates { - // Since the cleanup of coordinates is tied to deletion of - // nodes, we silently drop any updates for nodes that we don't - // know about. This might be possible during normal operation - // if we happen to get a coordinate update for a node that - // hasn't been able to add itself to the catalog yet. Since we - // don't carefully sequence this, and since it will fix itself - // on the next coordinate update from that node, we don't return - // an error or log anything. - node, err := tx.First("nodes", "id", update.Node) - if err != nil { - return fmt.Errorf("failed node lookup: %s", err) - } - if node == nil { - continue - } - - if err := tx.Insert("coordinates", update); err != nil { - return fmt.Errorf("failed inserting coordinate: %s", err) - } - } - - // Update the index. - if err := tx.Insert("index", &IndexEntry{"coordinates", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - - tx.Defer(func() { s.tableWatches["coordinates"].Notify() }) - tx.Commit() - return nil -} diff --git a/consul/state/state_store_test.go b/consul/state/state_store_test.go index 52ecb7e57f..687cc4c582 100644 --- a/consul/state/state_store_test.go +++ b/consul/state/state_store_test.go @@ -3,17 +3,10 @@ package state import ( crand "crypto/rand" "fmt" - "math/rand" - "reflect" - "sort" - "strings" "testing" - "time" "github.com/hashicorp/consul/consul/structs" - "github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/types" - "github.com/hashicorp/serf/coordinate" ) func testUUID() string { @@ -230,3512 +223,3 @@ func TestStateStore_GetWatches(t *testing.T) { t.Fatalf("didn't get a watch") } } - -func TestStateStore_EnsureRegistration(t *testing.T) { - s := testStateStore(t) - - // Start with just a node. - req := &structs.RegisterRequest{ - Node: "node1", - Address: "1.2.3.4", - TaggedAddresses: map[string]string{ - "hello": "world", - }, - NodeMeta: map[string]string{ - "somekey": "somevalue", - }, - } - if err := s.EnsureRegistration(1, req); err != nil { - t.Fatalf("err: %s", err) - } - - // Retrieve the node and verify its contents. - verifyNode := func(created, modified uint64) { - _, out, err := s.GetNode("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - if out.Node != "node1" || out.Address != "1.2.3.4" || - len(out.TaggedAddresses) != 1 || - out.TaggedAddresses["hello"] != "world" || - out.Meta["somekey"] != "somevalue" || - out.CreateIndex != created || out.ModifyIndex != modified { - t.Fatalf("bad node returned: %#v", out) - } - } - verifyNode(1, 1) - - // Add in a service definition. - req.Service = &structs.NodeService{ - ID: "redis1", - Service: "redis", - Address: "1.1.1.1", - Port: 8080, - } - if err := s.EnsureRegistration(2, req); err != nil { - t.Fatalf("err: %s", err) - } - - // Verify that the service got registered. - verifyService := func(created, modified uint64) { - idx, out, err := s.NodeServices("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != modified { - t.Fatalf("bad index: %d", idx) - } - if len(out.Services) != 1 { - t.Fatalf("bad: %#v", out.Services) - } - r := out.Services["redis1"] - if r == nil || r.ID != "redis1" || r.Service != "redis" || - r.Address != "1.1.1.1" || r.Port != 8080 || - r.CreateIndex != created || r.ModifyIndex != modified { - t.Fatalf("bad service returned: %#v", r) - } - - idx, r, err = s.NodeService("node1", "redis1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != modified { - t.Fatalf("bad index: %d", idx) - } - if r == nil || r.ID != "redis1" || r.Service != "redis" || - r.Address != "1.1.1.1" || r.Port != 8080 || - r.CreateIndex != created || r.ModifyIndex != modified { - t.Fatalf("bad service returned: %#v", r) - } - } - verifyNode(1, 2) - verifyService(2, 2) - - // Add in a top-level check. - req.Check = &structs.HealthCheck{ - Node: "node1", - CheckID: "check1", - Name: "check", - } - if err := s.EnsureRegistration(3, req); err != nil { - t.Fatalf("err: %s", err) - } - - // Verify that the check got registered. - verifyCheck := func(created, modified uint64) { - idx, out, err := s.NodeChecks("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != modified { - t.Fatalf("bad index: %d", idx) - } - if len(out) != 1 { - t.Fatalf("bad: %#v", out) - } - c := out[0] - if c.Node != "node1" || c.CheckID != "check1" || c.Name != "check" || - c.CreateIndex != created || c.ModifyIndex != modified { - t.Fatalf("bad check returned: %#v", c) - } - - idx, c, err = s.NodeCheck("node1", "check1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != modified { - t.Fatalf("bad index: %d", idx) - } - if c.Node != "node1" || c.CheckID != "check1" || c.Name != "check" || - c.CreateIndex != created || c.ModifyIndex != modified { - t.Fatalf("bad check returned: %#v", c) - } - } - verifyNode(1, 3) - verifyService(2, 3) - verifyCheck(3, 3) - - // Add in another check via the slice. - req.Checks = structs.HealthChecks{ - &structs.HealthCheck{ - Node: "node1", - CheckID: "check2", - Name: "check", - }, - } - if err := s.EnsureRegistration(4, req); err != nil { - t.Fatalf("err: %s", err) - } - - // Verify that the additional check got registered. - verifyNode(1, 4) - verifyService(2, 4) - func() { - idx, out, err := s.NodeChecks("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 4 { - t.Fatalf("bad index: %d", idx) - } - if len(out) != 2 { - t.Fatalf("bad: %#v", out) - } - c1 := out[0] - if c1.Node != "node1" || c1.CheckID != "check1" || c1.Name != "check" || - c1.CreateIndex != 3 || c1.ModifyIndex != 4 { - t.Fatalf("bad check returned: %#v", c1) - } - - c2 := out[1] - if c2.Node != "node1" || c2.CheckID != "check2" || c2.Name != "check" || - c2.CreateIndex != 4 || c2.ModifyIndex != 4 { - t.Fatalf("bad check returned: %#v", c2) - } - }() -} - -func TestStateStore_EnsureRegistration_Restore(t *testing.T) { - s := testStateStore(t) - - // Start with just a node. - req := &structs.RegisterRequest{ - Node: "node1", - Address: "1.2.3.4", - } - restore := s.Restore() - if err := restore.Registration(1, req); err != nil { - t.Fatalf("err: %s", err) - } - restore.Commit() - - // Retrieve the node and verify its contents. - verifyNode := func(created, modified uint64) { - _, out, err := s.GetNode("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - if out.Node != "node1" || out.Address != "1.2.3.4" || - out.CreateIndex != created || out.ModifyIndex != modified { - t.Fatalf("bad node returned: %#v", out) - } - } - verifyNode(1, 1) - - // Add in a service definition. - req.Service = &structs.NodeService{ - ID: "redis1", - Service: "redis", - Address: "1.1.1.1", - Port: 8080, - } - restore = s.Restore() - if err := restore.Registration(2, req); err != nil { - t.Fatalf("err: %s", err) - } - restore.Commit() - - // Verify that the service got registered. - verifyService := func(created, modified uint64) { - idx, out, err := s.NodeServices("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != modified { - t.Fatalf("bad index: %d", idx) - } - if len(out.Services) != 1 { - t.Fatalf("bad: %#v", out.Services) - } - s := out.Services["redis1"] - if s.ID != "redis1" || s.Service != "redis" || - s.Address != "1.1.1.1" || s.Port != 8080 || - s.CreateIndex != created || s.ModifyIndex != modified { - t.Fatalf("bad service returned: %#v", s) - } - } - verifyNode(1, 2) - verifyService(2, 2) - - // Add in a top-level check. - req.Check = &structs.HealthCheck{ - Node: "node1", - CheckID: "check1", - Name: "check", - } - restore = s.Restore() - if err := restore.Registration(3, req); err != nil { - t.Fatalf("err: %s", err) - } - restore.Commit() - - // Verify that the check got registered. - verifyCheck := func(created, modified uint64) { - idx, out, err := s.NodeChecks("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != modified { - t.Fatalf("bad index: %d", idx) - } - if len(out) != 1 { - t.Fatalf("bad: %#v", out) - } - c := out[0] - if c.Node != "node1" || c.CheckID != "check1" || c.Name != "check" || - c.CreateIndex != created || c.ModifyIndex != modified { - t.Fatalf("bad check returned: %#v", c) - } - } - verifyNode(1, 3) - verifyService(2, 3) - verifyCheck(3, 3) - - // Add in another check via the slice. - req.Checks = structs.HealthChecks{ - &structs.HealthCheck{ - Node: "node1", - CheckID: "check2", - Name: "check", - }, - } - restore = s.Restore() - if err := restore.Registration(4, req); err != nil { - t.Fatalf("err: %s", err) - } - restore.Commit() - - // Verify that the additional check got registered. - verifyNode(1, 4) - verifyService(2, 4) - func() { - idx, out, err := s.NodeChecks("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 4 { - t.Fatalf("bad index: %d", idx) - } - if len(out) != 2 { - t.Fatalf("bad: %#v", out) - } - c1 := out[0] - if c1.Node != "node1" || c1.CheckID != "check1" || c1.Name != "check" || - c1.CreateIndex != 3 || c1.ModifyIndex != 4 { - t.Fatalf("bad check returned: %#v", c1) - } - - c2 := out[1] - if c2.Node != "node1" || c2.CheckID != "check2" || c2.Name != "check" || - c2.CreateIndex != 4 || c2.ModifyIndex != 4 { - t.Fatalf("bad check returned: %#v", c2) - } - }() -} - -func TestStateStore_EnsureRegistration_Watches(t *testing.T) { - s := testStateStore(t) - - req := &structs.RegisterRequest{ - Node: "node1", - Address: "1.2.3.4", - } - - // The nodes watch should fire for this one. - verifyWatch(t, s.getTableWatch("nodes"), func() { - verifyNoWatch(t, s.getTableWatch("services"), func() { - verifyNoWatch(t, s.getTableWatch("checks"), func() { - if err := s.EnsureRegistration(1, req); err != nil { - t.Fatalf("err: %s", err) - } - }) - }) - }) - // The nodes watch should fire for this one. - verifyWatch(t, s.getTableWatch("nodes"), func() { - verifyNoWatch(t, s.getTableWatch("services"), func() { - verifyNoWatch(t, s.getTableWatch("checks"), func() { - restore := s.Restore() - if err := restore.Registration(1, req); err != nil { - t.Fatalf("err: %s", err) - } - restore.Commit() - }) - }) - }) - - // With a service definition added it should fire nodes and - // services. - req.Service = &structs.NodeService{ - ID: "redis1", - Service: "redis", - Address: "1.1.1.1", - Port: 8080, - } - verifyWatch(t, s.getTableWatch("nodes"), func() { - verifyWatch(t, s.getTableWatch("services"), func() { - verifyNoWatch(t, s.getTableWatch("checks"), func() { - if err := s.EnsureRegistration(2, req); err != nil { - t.Fatalf("err: %s", err) - } - }) - }) - }) - verifyWatch(t, s.getTableWatch("nodes"), func() { - verifyWatch(t, s.getTableWatch("services"), func() { - verifyNoWatch(t, s.getTableWatch("checks"), func() { - restore := s.Restore() - if err := restore.Registration(2, req); err != nil { - t.Fatalf("err: %s", err) - } - restore.Commit() - }) - }) - }) - - // Now with a check it should hit all three. - req.Check = &structs.HealthCheck{ - Node: "node1", - CheckID: "check1", - Name: "check", - } - verifyWatch(t, s.getTableWatch("nodes"), func() { - verifyWatch(t, s.getTableWatch("services"), func() { - verifyWatch(t, s.getTableWatch("checks"), func() { - if err := s.EnsureRegistration(3, req); err != nil { - t.Fatalf("err: %s", err) - } - }) - }) - }) - verifyWatch(t, s.getTableWatch("nodes"), func() { - verifyWatch(t, s.getTableWatch("services"), func() { - verifyWatch(t, s.getTableWatch("checks"), func() { - restore := s.Restore() - if err := restore.Registration(3, req); err != nil { - t.Fatalf("err: %s", err) - } - restore.Commit() - }) - }) - }) -} - -func TestStateStore_EnsureNode(t *testing.T) { - s := testStateStore(t) - - // Fetching a non-existent node returns nil - if _, node, err := s.GetNode("node1"); node != nil || err != nil { - t.Fatalf("expected (nil, nil), got: (%#v, %#v)", node, err) - } - - // Create a node registration request - in := &structs.Node{ - Node: "node1", - Address: "1.1.1.1", - } - - // Ensure the node is registered in the db - if err := s.EnsureNode(1, in); err != nil { - t.Fatalf("err: %s", err) - } - - // Retrieve the node again - idx, out, err := s.GetNode("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - - // Correct node was returned - if out.Node != "node1" || out.Address != "1.1.1.1" { - t.Fatalf("bad node returned: %#v", out) - } - - // Indexes are set properly - if out.CreateIndex != 1 || out.ModifyIndex != 1 { - t.Fatalf("bad node index: %#v", out) - } - if idx != 1 { - t.Fatalf("bad index: %d", idx) - } - - // Update the node registration - in.Address = "1.1.1.2" - if err := s.EnsureNode(2, in); err != nil { - t.Fatalf("err: %s", err) - } - - // Retrieve the node - idx, out, err = s.GetNode("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - - // Node and indexes were updated - if out.CreateIndex != 1 || out.ModifyIndex != 2 || out.Address != "1.1.1.2" { - t.Fatalf("bad: %#v", out) - } - if idx != 2 { - t.Fatalf("bad index: %d", idx) - } - - // Node upsert preserves the create index - if err := s.EnsureNode(3, in); err != nil { - t.Fatalf("err: %s", err) - } - idx, out, err = s.GetNode("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - if out.CreateIndex != 1 || out.ModifyIndex != 3 || out.Address != "1.1.1.2" { - t.Fatalf("node was modified: %#v", out) - } - if idx != 3 { - t.Fatalf("bad index: %d", idx) - } -} - -func TestStateStore_GetNodes(t *testing.T) { - s := testStateStore(t) - - // Listing with no results returns nil - idx, res, err := s.Nodes() - if idx != 0 || res != nil || err != nil { - t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) - } - - // Create some nodes in the state store - testRegisterNode(t, s, 0, "node0") - testRegisterNode(t, s, 1, "node1") - testRegisterNode(t, s, 2, "node2") - - // Retrieve the nodes - idx, nodes, err := s.Nodes() - if err != nil { - t.Fatalf("err: %s", err) - } - - // Highest index was returned - if idx != 2 { - t.Fatalf("bad index: %d", idx) - } - - // All nodes were returned - if n := len(nodes); n != 3 { - t.Fatalf("bad node count: %d", n) - } - - // Make sure the nodes match - for i, node := range nodes { - if node.CreateIndex != uint64(i) || node.ModifyIndex != uint64(i) { - t.Fatalf("bad node index: %d, %d", node.CreateIndex, node.ModifyIndex) - } - name := fmt.Sprintf("node%d", i) - if node.Node != name { - t.Fatalf("bad: %#v", node) - } - } -} - -func BenchmarkGetNodes(b *testing.B) { - s, err := NewStateStore(nil) - if err != nil { - b.Fatalf("err: %s", err) - } - - if err := s.EnsureNode(100, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { - b.Fatalf("err: %v", err) - } - if err := s.EnsureNode(101, &structs.Node{Node: "bar", Address: "127.0.0.2"}); err != nil { - b.Fatalf("err: %v", err) - } - - for i := 0; i < b.N; i++ { - s.Nodes() - } -} - -func TestStateStore_GetNodesByMeta(t *testing.T) { - s := testStateStore(t) - - // Listing with no results returns nil - idx, res, err := s.NodesByMeta(map[string]string{"somekey": "somevalue"}) - if idx != 0 || res != nil || err != nil { - t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) - } - - // Create some nodes in the state store - node0 := &structs.Node{Node: "node0", Address: "127.0.0.1", Meta: map[string]string{"role": "client", "common": "1"}} - if err := s.EnsureNode(0, node0); err != nil { - t.Fatalf("err: %v", err) - } - node1 := &structs.Node{Node: "node1", Address: "127.0.0.1", Meta: map[string]string{"role": "server", "common": "1"}} - if err := s.EnsureNode(1, node1); err != nil { - t.Fatalf("err: %v", err) - } - - // Retrieve the node with role=client - idx, nodes, err := s.NodesByMeta(map[string]string{"role": "client"}) - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 1 { - t.Fatalf("bad index: %d", idx) - } - - // Only one node was returned - if n := len(nodes); n != 1 { - t.Fatalf("bad node count: %d", n) - } - - // Make sure the node is correct - if nodes[0].CreateIndex != 0 || nodes[0].ModifyIndex != 0 { - t.Fatalf("bad node index: %d, %d", nodes[0].CreateIndex, nodes[0].ModifyIndex) - } - if nodes[0].Node != "node0" { - t.Fatalf("bad: %#v", nodes[0]) - } - if !reflect.DeepEqual(nodes[0].Meta, node0.Meta) { - t.Fatalf("bad: %v != %v", nodes[0].Meta, node0.Meta) - } - - // Retrieve both nodes via their common meta field - idx, nodes, err = s.NodesByMeta(map[string]string{"common": "1"}) - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 1 { - t.Fatalf("bad index: %d", idx) - } - - // All nodes were returned - if n := len(nodes); n != 2 { - t.Fatalf("bad node count: %d", n) - } - - // Make sure the nodes match - for i, node := range nodes { - if node.CreateIndex != uint64(i) || node.ModifyIndex != uint64(i) { - t.Fatalf("bad node index: %d, %d", node.CreateIndex, node.ModifyIndex) - } - name := fmt.Sprintf("node%d", i) - if node.Node != name { - t.Fatalf("bad: %#v", node) - } - if v, ok := node.Meta["common"]; !ok || v != "1" { - t.Fatalf("bad: %v", node.Meta) - } - } -} - -func BenchmarkGetNodesByMeta(b *testing.B) { - s, err := NewStateStore(nil) - if err != nil { - b.Fatalf("err: %s", err) - } - - if err := s.EnsureNode(100, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { - b.Fatalf("err: %v", err) - } - if err := s.EnsureNode(101, &structs.Node{Node: "bar", Address: "127.0.0.2"}); err != nil { - b.Fatalf("err: %v", err) - } - - for i := 0; i < b.N; i++ { - s.Nodes() - } -} - -func TestStateStore_DeleteNode(t *testing.T) { - s := testStateStore(t) - - // Create a node and register a service and health check with it. - testRegisterNode(t, s, 0, "node1") - testRegisterService(t, s, 1, "node1", "service1") - testRegisterCheck(t, s, 2, "node1", "", "check1", structs.HealthPassing) - - // Delete the node - if err := s.DeleteNode(3, "node1"); err != nil { - t.Fatalf("err: %s", err) - } - - // The node was removed - if idx, n, err := s.GetNode("node1"); err != nil || n != nil || idx != 3 { - t.Fatalf("bad: %#v %d (err: %#v)", n, idx, err) - } - - // Associated service was removed. Need to query this directly out of - // the DB to make sure it is actually gone. - tx := s.db.Txn(false) - defer tx.Abort() - services, err := tx.Get("services", "id", "node1", "service1") - if err != nil { - t.Fatalf("err: %s", err) - } - if service := services.Next(); service != nil { - t.Fatalf("bad: %#v", service) - } - - // Associated health check was removed. - checks, err := tx.Get("checks", "id", "node1", "check1") - if err != nil { - t.Fatalf("err: %s", err) - } - if check := checks.Next(); check != nil { - t.Fatalf("bad: %#v", check) - } - - // Indexes were updated. - for _, tbl := range []string{"nodes", "services", "checks"} { - if idx := s.maxIndex(tbl); idx != 3 { - t.Fatalf("bad index: %d (%s)", idx, tbl) - } - } - - // Deleting a nonexistent node should be idempotent and not return - // an error - if err := s.DeleteNode(4, "node1"); err != nil { - t.Fatalf("err: %s", err) - } - if idx := s.maxIndex("nodes"); idx != 3 { - t.Fatalf("bad index: %d", idx) - } -} - -func TestStateStore_Node_Snapshot(t *testing.T) { - s := testStateStore(t) - - // Create some nodes in the state store. - testRegisterNode(t, s, 0, "node0") - testRegisterNode(t, s, 1, "node1") - testRegisterNode(t, s, 2, "node2") - - // Snapshot the nodes. - snap := s.Snapshot() - defer snap.Close() - - // Alter the real state store. - testRegisterNode(t, s, 3, "node3") - - // Verify the snapshot. - if idx := snap.LastIndex(); idx != 2 { - t.Fatalf("bad index: %d", idx) - } - nodes, err := snap.Nodes() - if err != nil { - t.Fatalf("err: %s", err) - } - for i := 0; i < 3; i++ { - node := nodes.Next().(*structs.Node) - if node == nil { - t.Fatalf("unexpected end of nodes") - } - - if node.CreateIndex != uint64(i) || node.ModifyIndex != uint64(i) { - t.Fatalf("bad node index: %d, %d", node.CreateIndex, node.ModifyIndex) - } - if node.Node != fmt.Sprintf("node%d", i) { - t.Fatalf("bad: %#v", node) - } - } - if nodes.Next() != nil { - t.Fatalf("unexpected extra nodes") - } -} - -func TestStateStore_Node_Watches(t *testing.T) { - s := testStateStore(t) - - // Call functions that update the nodes table and make sure a watch fires - // each time. - verifyWatch(t, s.getTableWatch("nodes"), func() { - req := &structs.RegisterRequest{ - Node: "node1", - } - if err := s.EnsureRegistration(1, req); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("nodes"), func() { - node := &structs.Node{Node: "node2"} - if err := s.EnsureNode(2, node); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("nodes"), func() { - if err := s.DeleteNode(3, "node2"); err != nil { - t.Fatalf("err: %s", err) - } - }) - - // Check that a delete of a node + service + check + coordinate triggers - // all tables in one shot. - testRegisterNode(t, s, 4, "node1") - testRegisterService(t, s, 5, "node1", "service1") - testRegisterCheck(t, s, 6, "node1", "service1", "check3", structs.HealthPassing) - updates := structs.Coordinates{ - &structs.Coordinate{ - Node: "node1", - Coord: generateRandomCoordinate(), - }, - } - if err := s.CoordinateBatchUpdate(7, updates); err != nil { - t.Fatalf("err: %s", err) - } - verifyWatch(t, s.getTableWatch("nodes"), func() { - verifyWatch(t, s.getTableWatch("services"), func() { - verifyWatch(t, s.getTableWatch("checks"), func() { - verifyWatch(t, s.getTableWatch("coordinates"), func() { - if err := s.DeleteNode(7, "node1"); err != nil { - t.Fatalf("err: %s", err) - } - }) - }) - }) - }) -} - -func TestStateStore_EnsureService(t *testing.T) { - s := testStateStore(t) - - // Fetching services for a node with none returns nil - idx, res, err := s.NodeServices("node1") - if err != nil || res != nil || idx != 0 { - t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) - } - - // Create the service registration - ns1 := &structs.NodeService{ - ID: "service1", - Service: "redis", - Tags: []string{"prod"}, - Address: "1.1.1.1", - Port: 1111, - } - - // Creating a service without a node returns an error - if err := s.EnsureService(1, "node1", ns1); err != ErrMissingNode { - t.Fatalf("expected %#v, got: %#v", ErrMissingNode, err) - } - - // Register the nodes - testRegisterNode(t, s, 0, "node1") - testRegisterNode(t, s, 1, "node2") - - // Service successfully registers into the state store - if err = s.EnsureService(10, "node1", ns1); err != nil { - t.Fatalf("err: %s", err) - } - - // Register a similar service against both nodes - ns2 := *ns1 - ns2.ID = "service2" - for _, n := range []string{"node1", "node2"} { - if err := s.EnsureService(20, n, &ns2); err != nil { - t.Fatalf("err: %s", err) - } - } - - // Register a different service on the bad node - ns3 := *ns1 - ns3.ID = "service3" - if err := s.EnsureService(30, "node2", &ns3); err != nil { - t.Fatalf("err: %s", err) - } - - // Retrieve the services - idx, out, err := s.NodeServices("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 30 { - t.Fatalf("bad index: %d", idx) - } - - // Only the services for the requested node are returned - if out == nil || len(out.Services) != 2 { - t.Fatalf("bad services: %#v", out) - } - - // Results match the inserted services and have the proper indexes set - expect1 := *ns1 - expect1.CreateIndex, expect1.ModifyIndex = 10, 10 - if svc := out.Services["service1"]; !reflect.DeepEqual(&expect1, svc) { - t.Fatalf("bad: %#v", svc) - } - - expect2 := ns2 - expect2.CreateIndex, expect2.ModifyIndex = 20, 20 - if svc := out.Services["service2"]; !reflect.DeepEqual(&expect2, svc) { - t.Fatalf("bad: %#v %#v", ns2, svc) - } - - // Index tables were updated - if idx := s.maxIndex("services"); idx != 30 { - t.Fatalf("bad index: %d", idx) - } - - // Update a service registration - ns1.Address = "1.1.1.2" - if err := s.EnsureService(40, "node1", ns1); err != nil { - t.Fatalf("err: %s", err) - } - - // Retrieve the service again and ensure it matches - idx, out, err = s.NodeServices("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 40 { - t.Fatalf("bad index: %d", idx) - } - if out == nil || len(out.Services) != 2 { - t.Fatalf("bad: %#v", out) - } - expect1.Address = "1.1.1.2" - expect1.ModifyIndex = 40 - if svc := out.Services["service1"]; !reflect.DeepEqual(&expect1, svc) { - t.Fatalf("bad: %#v", svc) - } - - // Index tables were updated - if idx := s.maxIndex("services"); idx != 40 { - t.Fatalf("bad index: %d", idx) - } -} - -func TestStateStore_Services(t *testing.T) { - s := testStateStore(t) - - // Register several nodes and services. - testRegisterNode(t, s, 1, "node1") - ns1 := &structs.NodeService{ - ID: "service1", - Service: "redis", - Tags: []string{"prod", "master"}, - Address: "1.1.1.1", - Port: 1111, - } - if err := s.EnsureService(2, "node1", ns1); err != nil { - t.Fatalf("err: %s", err) - } - testRegisterService(t, s, 3, "node1", "dogs") - testRegisterNode(t, s, 4, "node2") - ns2 := &structs.NodeService{ - ID: "service3", - Service: "redis", - Tags: []string{"prod", "slave"}, - Address: "1.1.1.1", - Port: 1111, - } - if err := s.EnsureService(5, "node2", ns2); err != nil { - t.Fatalf("err: %s", err) - } - - // Pull all the services. - idx, services, err := s.Services() - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 5 { - t.Fatalf("bad index: %d", idx) - } - - // Verify the result. We sort the lists since the order is - // non-deterministic (it's built using a map internally). - expected := structs.Services{ - "redis": []string{"prod", "master", "slave"}, - "dogs": []string{}, - } - sort.Strings(expected["redis"]) - for _, tags := range services { - sort.Strings(tags) - } - if !reflect.DeepEqual(expected, services) { - t.Fatalf("bad: %#v", services) - } -} - -func TestStateStore_ServicesByNodeMeta(t *testing.T) { - s := testStateStore(t) - - // Listing with no results returns nil - idx, res, err := s.ServicesByNodeMeta(map[string]string{"somekey": "somevalue"}) - if idx != 0 || len(res) != 0 || err != nil { - t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) - } - - // Create some nodes and services in the state store - node0 := &structs.Node{Node: "node0", Address: "127.0.0.1", Meta: map[string]string{"role": "client", "common": "1"}} - if err := s.EnsureNode(0, node0); err != nil { - t.Fatalf("err: %v", err) - } - node1 := &structs.Node{Node: "node1", Address: "127.0.0.1", Meta: map[string]string{"role": "server", "common": "1"}} - if err := s.EnsureNode(1, node1); err != nil { - t.Fatalf("err: %v", err) - } - ns1 := &structs.NodeService{ - ID: "service1", - Service: "redis", - Tags: []string{"prod", "master"}, - Address: "1.1.1.1", - Port: 1111, - } - if err := s.EnsureService(2, "node0", ns1); err != nil { - t.Fatalf("err: %s", err) - } - ns2 := &structs.NodeService{ - ID: "service1", - Service: "redis", - Tags: []string{"prod", "slave"}, - Address: "1.1.1.1", - Port: 1111, - } - if err := s.EnsureService(3, "node1", ns2); err != nil { - t.Fatalf("err: %s", err) - } - - // Filter the services by the first node's meta value - idx, res, err = s.ServicesByNodeMeta(map[string]string{"role": "client"}) - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 3 { - t.Fatalf("bad index: %d", idx) - } - expected := structs.Services{ - "redis": []string{"master", "prod"}, - } - sort.Strings(res["redis"]) - if !reflect.DeepEqual(res, expected) { - t.Fatalf("bad: %v %v", res, expected) - } - - // Get all services using the common meta value - idx, res, err = s.ServicesByNodeMeta(map[string]string{"common": "1"}) - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 3 { - t.Fatalf("bad index: %d", idx) - } - expected = structs.Services{ - "redis": []string{"master", "prod", "slave"}, - } - sort.Strings(res["redis"]) - if !reflect.DeepEqual(res, expected) { - t.Fatalf("bad: %v %v", res, expected) - } -} - -func TestStateStore_ServiceNodes(t *testing.T) { - s := testStateStore(t) - - if err := s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { - t.Fatalf("err: %v", err) - } - - if err := s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}); err != nil { - t.Fatalf("err: %v", err) - } - - if err := s.EnsureService(12, "foo", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}); err != nil { - t.Fatalf("err: %v", err) - } - - if err := s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}); err != nil { - t.Fatalf("err: %v", err) - } - - if err := s.EnsureService(14, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"master"}, Address: "", Port: 8000}); err != nil { - t.Fatalf("err: %v", err) - } - - if err := s.EnsureService(15, "bar", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8000}); err != nil { - t.Fatalf("err: %v", err) - } - - if err := s.EnsureService(16, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8001}); err != nil { - t.Fatalf("err: %v", err) - } - - idx, nodes, err := s.ServiceNodes("db") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 16 { - t.Fatalf("bad: %v", 16) - } - if len(nodes) != 3 { - t.Fatalf("bad: %v", nodes) - } - if nodes[0].Node != "bar" { - t.Fatalf("bad: %v", nodes) - } - if nodes[0].Address != "127.0.0.2" { - t.Fatalf("bad: %v", nodes) - } - if nodes[0].ServiceID != "db" { - t.Fatalf("bad: %v", nodes) - } - if !lib.StrContains(nodes[0].ServiceTags, "slave") { - t.Fatalf("bad: %v", nodes) - } - if nodes[0].ServicePort != 8000 { - t.Fatalf("bad: %v", nodes) - } - - if nodes[1].Node != "bar" { - t.Fatalf("bad: %v", nodes) - } - if nodes[1].Address != "127.0.0.2" { - t.Fatalf("bad: %v", nodes) - } - if nodes[1].ServiceID != "db2" { - t.Fatalf("bad: %v", nodes) - } - if !lib.StrContains(nodes[1].ServiceTags, "slave") { - t.Fatalf("bad: %v", nodes) - } - if nodes[1].ServicePort != 8001 { - t.Fatalf("bad: %v", nodes) - } - - if nodes[2].Node != "foo" { - t.Fatalf("bad: %v", nodes) - } - if nodes[2].Address != "127.0.0.1" { - t.Fatalf("bad: %v", nodes) - } - if nodes[2].ServiceID != "db" { - t.Fatalf("bad: %v", nodes) - } - if !lib.StrContains(nodes[2].ServiceTags, "master") { - t.Fatalf("bad: %v", nodes) - } - if nodes[2].ServicePort != 8000 { - t.Fatalf("bad: %v", nodes) - } -} - -func TestStateStore_ServiceTagNodes(t *testing.T) { - s := testStateStore(t) - - if err := s.EnsureNode(15, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { - t.Fatalf("err: %v", err) - } - - if err := s.EnsureNode(16, &structs.Node{Node: "bar", Address: "127.0.0.2"}); err != nil { - t.Fatalf("err: %v", err) - } - - if err := s.EnsureService(17, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"master"}, Address: "", Port: 8000}); err != nil { - t.Fatalf("err: %v", err) - } - - if err := s.EnsureService(18, "foo", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8001}); err != nil { - t.Fatalf("err: %v", err) - } - - if err := s.EnsureService(19, "bar", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8000}); err != nil { - t.Fatalf("err: %v", err) - } - - idx, nodes, err := s.ServiceTagNodes("db", "master") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 19 { - t.Fatalf("bad: %v", idx) - } - if len(nodes) != 1 { - t.Fatalf("bad: %v", nodes) - } - if nodes[0].Node != "foo" { - t.Fatalf("bad: %v", nodes) - } - if nodes[0].Address != "127.0.0.1" { - t.Fatalf("bad: %v", nodes) - } - if !lib.StrContains(nodes[0].ServiceTags, "master") { - t.Fatalf("bad: %v", nodes) - } - if nodes[0].ServicePort != 8000 { - t.Fatalf("bad: %v", nodes) - } -} - -func TestStateStore_ServiceTagNodes_MultipleTags(t *testing.T) { - s := testStateStore(t) - - if err := s.EnsureNode(15, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { - t.Fatalf("err: %v", err) - } - - if err := s.EnsureNode(16, &structs.Node{Node: "bar", Address: "127.0.0.2"}); err != nil { - t.Fatalf("err: %v", err) - } - - if err := s.EnsureService(17, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"master", "v2"}, Address: "", Port: 8000}); err != nil { - t.Fatalf("err: %v", err) - } - - if err := s.EnsureService(18, "foo", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave", "v2", "dev"}, Address: "", Port: 8001}); err != nil { - t.Fatalf("err: %v", err) - } - - if err := s.EnsureService(19, "bar", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"slave", "v2"}, Address: "", Port: 8000}); err != nil { - t.Fatalf("err: %v", err) - } - - idx, nodes, err := s.ServiceTagNodes("db", "master") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 19 { - t.Fatalf("bad: %v", idx) - } - if len(nodes) != 1 { - t.Fatalf("bad: %v", nodes) - } - if nodes[0].Node != "foo" { - t.Fatalf("bad: %v", nodes) - } - if nodes[0].Address != "127.0.0.1" { - t.Fatalf("bad: %v", nodes) - } - if !lib.StrContains(nodes[0].ServiceTags, "master") { - t.Fatalf("bad: %v", nodes) - } - if nodes[0].ServicePort != 8000 { - t.Fatalf("bad: %v", nodes) - } - - idx, nodes, err = s.ServiceTagNodes("db", "v2") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 19 { - t.Fatalf("bad: %v", idx) - } - if len(nodes) != 3 { - t.Fatalf("bad: %v", nodes) - } - - idx, nodes, err = s.ServiceTagNodes("db", "dev") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 19 { - t.Fatalf("bad: %v", idx) - } - if len(nodes) != 1 { - t.Fatalf("bad: %v", nodes) - } - if nodes[0].Node != "foo" { - t.Fatalf("bad: %v", nodes) - } - if nodes[0].Address != "127.0.0.1" { - t.Fatalf("bad: %v", nodes) - } - if !lib.StrContains(nodes[0].ServiceTags, "dev") { - t.Fatalf("bad: %v", nodes) - } - if nodes[0].ServicePort != 8001 { - t.Fatalf("bad: %v", nodes) - } -} - -func TestStateStore_DeleteService(t *testing.T) { - s := testStateStore(t) - - // Register a node with one service and a check - testRegisterNode(t, s, 1, "node1") - testRegisterService(t, s, 2, "node1", "service1") - testRegisterCheck(t, s, 3, "node1", "service1", "check1", structs.HealthPassing) - - // Delete the service - if err := s.DeleteService(4, "node1", "service1"); err != nil { - t.Fatalf("err: %s", err) - } - - // Service doesn't exist. - _, ns, err := s.NodeServices("node1") - if err != nil || ns == nil || len(ns.Services) != 0 { - t.Fatalf("bad: %#v (err: %#v)", ns, err) - } - - // Check doesn't exist. Check using the raw DB so we can test - // that it actually is removed in the state store. - tx := s.db.Txn(false) - defer tx.Abort() - check, err := tx.First("checks", "id", "node1", "check1") - if err != nil || check != nil { - t.Fatalf("bad: %#v (err: %s)", check, err) - } - - // Index tables were updated - if idx := s.maxIndex("services"); idx != 4 { - t.Fatalf("bad index: %d", idx) - } - if idx := s.maxIndex("checks"); idx != 4 { - t.Fatalf("bad index: %d", idx) - } - - // Deleting a nonexistent service should be idempotent and not return an - // error - if err := s.DeleteService(5, "node1", "service1"); err != nil { - t.Fatalf("err: %s", err) - } - if idx := s.maxIndex("services"); idx != 4 { - t.Fatalf("bad index: %d", idx) - } -} - -func TestStateStore_Service_Snapshot(t *testing.T) { - s := testStateStore(t) - - // Register a node with two services. - testRegisterNode(t, s, 0, "node1") - ns := []*structs.NodeService{ - &structs.NodeService{ - ID: "service1", - Service: "redis", - Tags: []string{"prod"}, - Address: "1.1.1.1", - Port: 1111, - }, - &structs.NodeService{ - ID: "service2", - Service: "nomad", - Tags: []string{"dev"}, - Address: "1.1.1.2", - Port: 1112, - }, - } - for i, svc := range ns { - if err := s.EnsureService(uint64(i+1), "node1", svc); err != nil { - t.Fatalf("err: %s", err) - } - } - - // Create a second node/service to make sure node filtering works. This - // will affect the index but not the dump. - testRegisterNode(t, s, 3, "node2") - testRegisterService(t, s, 4, "node2", "service2") - - // Snapshot the service. - snap := s.Snapshot() - defer snap.Close() - - // Alter the real state store. - testRegisterService(t, s, 5, "node2", "service3") - - // Verify the snapshot. - if idx := snap.LastIndex(); idx != 4 { - t.Fatalf("bad index: %d", idx) - } - services, err := snap.Services("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - for i := 0; i < len(ns); i++ { - svc := services.Next().(*structs.ServiceNode) - if svc == nil { - t.Fatalf("unexpected end of services") - } - - ns[i].CreateIndex, ns[i].ModifyIndex = uint64(i+1), uint64(i+1) - if !reflect.DeepEqual(ns[i], svc.ToNodeService()) { - t.Fatalf("bad: %#v != %#v", svc, ns[i]) - } - } - if services.Next() != nil { - t.Fatalf("unexpected extra services") - } -} - -func TestStateStore_Service_Watches(t *testing.T) { - s := testStateStore(t) - - testRegisterNode(t, s, 0, "node1") - ns := &structs.NodeService{ - ID: "service2", - Service: "nomad", - Address: "1.1.1.2", - Port: 8000, - } - - // Call functions that update the services table and make sure a watch - // fires each time. - verifyWatch(t, s.getTableWatch("services"), func() { - if err := s.EnsureService(2, "node1", ns); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("services"), func() { - if err := s.DeleteService(3, "node1", "service2"); err != nil { - t.Fatalf("err: %s", err) - } - }) - - // Check that a delete of a service + check triggers both tables in one - // shot. - testRegisterService(t, s, 4, "node1", "service1") - testRegisterCheck(t, s, 5, "node1", "service1", "check3", structs.HealthPassing) - verifyWatch(t, s.getTableWatch("services"), func() { - verifyWatch(t, s.getTableWatch("checks"), func() { - if err := s.DeleteService(6, "node1", "service1"); err != nil { - t.Fatalf("err: %s", err) - } - }) - }) -} - -func TestStateStore_EnsureCheck(t *testing.T) { - s := testStateStore(t) - - // Create a check associated with the node - check := &structs.HealthCheck{ - Node: "node1", - CheckID: "check1", - Name: "redis check", - Status: structs.HealthPassing, - Notes: "test check", - Output: "aaa", - ServiceID: "service1", - ServiceName: "redis", - } - - // Creating a check without a node returns error - if err := s.EnsureCheck(1, check); err != ErrMissingNode { - t.Fatalf("expected %#v, got: %#v", ErrMissingNode, err) - } - - // Register the node - testRegisterNode(t, s, 1, "node1") - - // Creating a check with a bad services returns error - if err := s.EnsureCheck(1, check); err != ErrMissingService { - t.Fatalf("expected: %#v, got: %#v", ErrMissingService, err) - } - - // Register the service - testRegisterService(t, s, 2, "node1", "service1") - - // Inserting the check with the prerequisites succeeds - if err := s.EnsureCheck(3, check); err != nil { - t.Fatalf("err: %s", err) - } - - // Retrieve the check and make sure it matches - idx, checks, err := s.NodeChecks("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 3 { - t.Fatalf("bad index: %d", idx) - } - if len(checks) != 1 { - t.Fatalf("wrong number of checks: %d", len(checks)) - } - if !reflect.DeepEqual(checks[0], check) { - t.Fatalf("bad: %#v", checks[0]) - } - - // Modify the health check - check.Output = "bbb" - if err := s.EnsureCheck(4, check); err != nil { - t.Fatalf("err: %s", err) - } - - // Check that we successfully updated - idx, checks, err = s.NodeChecks("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 4 { - t.Fatalf("bad index: %d", idx) - } - if len(checks) != 1 { - t.Fatalf("wrong number of checks: %d", len(checks)) - } - if checks[0].Output != "bbb" { - t.Fatalf("wrong check output: %#v", checks[0]) - } - if checks[0].CreateIndex != 3 || checks[0].ModifyIndex != 4 { - t.Fatalf("bad index: %#v", checks[0]) - } - - // Index tables were updated - if idx := s.maxIndex("checks"); idx != 4 { - t.Fatalf("bad index: %d", idx) - } -} - -func TestStateStore_EnsureCheck_defaultStatus(t *testing.T) { - s := testStateStore(t) - - // Register a node - testRegisterNode(t, s, 1, "node1") - - // Create and register a check with no health status - check := &structs.HealthCheck{ - Node: "node1", - CheckID: "check1", - Status: "", - } - if err := s.EnsureCheck(2, check); err != nil { - t.Fatalf("err: %s", err) - } - - // Get the check again - _, result, err := s.NodeChecks("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - - // Check that the status was set to the proper default - if len(result) != 1 || result[0].Status != structs.HealthCritical { - t.Fatalf("bad: %#v", result) - } -} - -func TestStateStore_NodeChecks(t *testing.T) { - s := testStateStore(t) - - // Create the first node and service with some checks - testRegisterNode(t, s, 0, "node1") - testRegisterService(t, s, 1, "node1", "service1") - testRegisterCheck(t, s, 2, "node1", "service1", "check1", structs.HealthPassing) - testRegisterCheck(t, s, 3, "node1", "service1", "check2", structs.HealthPassing) - - // Create a second node/service with a different set of checks - testRegisterNode(t, s, 4, "node2") - testRegisterService(t, s, 5, "node2", "service2") - testRegisterCheck(t, s, 6, "node2", "service2", "check3", structs.HealthPassing) - - // Try querying for all checks associated with node1 - idx, checks, err := s.NodeChecks("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 6 { - t.Fatalf("bad index: %d", idx) - } - if len(checks) != 2 || checks[0].CheckID != "check1" || checks[1].CheckID != "check2" { - t.Fatalf("bad checks: %#v", checks) - } - - // Try querying for all checks associated with node2 - idx, checks, err = s.NodeChecks("node2") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 6 { - t.Fatalf("bad index: %d", idx) - } - if len(checks) != 1 || checks[0].CheckID != "check3" { - t.Fatalf("bad checks: %#v", checks) - } -} - -func TestStateStore_ServiceChecks(t *testing.T) { - s := testStateStore(t) - - // Create the first node and service with some checks - testRegisterNode(t, s, 0, "node1") - testRegisterService(t, s, 1, "node1", "service1") - testRegisterCheck(t, s, 2, "node1", "service1", "check1", structs.HealthPassing) - testRegisterCheck(t, s, 3, "node1", "service1", "check2", structs.HealthPassing) - - // Create a second node/service with a different set of checks - testRegisterNode(t, s, 4, "node2") - testRegisterService(t, s, 5, "node2", "service2") - testRegisterCheck(t, s, 6, "node2", "service2", "check3", structs.HealthPassing) - - // Try querying for all checks associated with service1 - idx, checks, err := s.ServiceChecks("service1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 6 { - t.Fatalf("bad index: %d", idx) - } - if len(checks) != 2 || checks[0].CheckID != "check1" || checks[1].CheckID != "check2" { - t.Fatalf("bad checks: %#v", checks) - } -} - -func TestStateStore_ChecksInState(t *testing.T) { - s := testStateStore(t) - - // Querying with no results returns nil - idx, res, err := s.ChecksInState(structs.HealthPassing) - if idx != 0 || res != nil || err != nil { - t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) - } - - // Register a node with checks in varied states - testRegisterNode(t, s, 0, "node1") - testRegisterCheck(t, s, 1, "node1", "", "check1", structs.HealthPassing) - testRegisterCheck(t, s, 2, "node1", "", "check2", structs.HealthCritical) - testRegisterCheck(t, s, 3, "node1", "", "check3", structs.HealthPassing) - - // Query the state store for passing checks. - _, checks, err := s.ChecksInState(structs.HealthPassing) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Make sure we only get the checks which match the state - if n := len(checks); n != 2 { - t.Fatalf("expected 2 checks, got: %d", n) - } - if checks[0].CheckID != "check1" || checks[1].CheckID != "check3" { - t.Fatalf("bad: %#v", checks) - } - - // HealthAny just returns everything. - _, checks, err = s.ChecksInState(structs.HealthAny) - if err != nil { - t.Fatalf("err: %s", err) - } - if n := len(checks); n != 3 { - t.Fatalf("expected 3 checks, got: %d", n) - } -} - -func TestStateStore_DeleteCheck(t *testing.T) { - s := testStateStore(t) - - // Register a node and a node-level health check - testRegisterNode(t, s, 1, "node1") - testRegisterCheck(t, s, 2, "node1", "", "check1", structs.HealthPassing) - - // Delete the check - if err := s.DeleteCheck(3, "node1", "check1"); err != nil { - t.Fatalf("err: %s", err) - } - - // Check is gone - _, checks, err := s.NodeChecks("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - if len(checks) != 0 { - t.Fatalf("bad: %#v", checks) - } - - // Index tables were updated - if idx := s.maxIndex("checks"); idx != 3 { - t.Fatalf("bad index: %d", idx) - } - - // Deleting a nonexistent check should be idempotent and not return an - // error - if err := s.DeleteCheck(4, "node1", "check1"); err != nil { - t.Fatalf("err: %s", err) - } - if idx := s.maxIndex("checks"); idx != 3 { - t.Fatalf("bad index: %d", idx) - } -} - -func TestStateStore_CheckServiceNodes(t *testing.T) { - s := testStateStore(t) - - // Querying with no matches gives an empty response - idx, res, err := s.CheckServiceNodes("service1") - if idx != 0 || res != nil || err != nil { - t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) - } - - // Register some nodes - testRegisterNode(t, s, 0, "node1") - testRegisterNode(t, s, 1, "node2") - - // Register node-level checks. These should not be returned - // in the final result. - testRegisterCheck(t, s, 2, "node1", "", "check1", structs.HealthPassing) - testRegisterCheck(t, s, 3, "node2", "", "check2", structs.HealthPassing) - - // Register a service against the nodes - testRegisterService(t, s, 4, "node1", "service1") - testRegisterService(t, s, 5, "node2", "service2") - - // Register checks against the services - testRegisterCheck(t, s, 6, "node1", "service1", "check3", structs.HealthPassing) - testRegisterCheck(t, s, 7, "node2", "service2", "check4", structs.HealthPassing) - - // Query the state store for nodes and checks which - // have been registered with a specific service. - idx, results, err := s.CheckServiceNodes("service1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 7 { - t.Fatalf("bad index: %d", idx) - } - - // Make sure we get the expected result (service check + node check) - if n := len(results); n != 1 { - t.Fatalf("expected 1 result, got: %d", n) - } - csn := results[0] - if csn.Node == nil || csn.Service == nil || len(csn.Checks) != 2 { - t.Fatalf("bad output: %#v", csn) - } - - // Node updates alter the returned index - testRegisterNode(t, s, 8, "node1") - idx, results, err = s.CheckServiceNodes("service1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 8 { - t.Fatalf("bad index: %d", idx) - } - - // Service updates alter the returned index - testRegisterService(t, s, 9, "node1", "service1") - idx, results, err = s.CheckServiceNodes("service1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 9 { - t.Fatalf("bad index: %d", idx) - } - - // Check updates alter the returned index - testRegisterCheck(t, s, 10, "node1", "service1", "check1", structs.HealthCritical) - idx, results, err = s.CheckServiceNodes("service1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 10 { - t.Fatalf("bad index: %d", idx) - } -} - -func BenchmarkCheckServiceNodes(b *testing.B) { - s, err := NewStateStore(nil) - if err != nil { - b.Fatalf("err: %s", err) - } - - if err := s.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { - b.Fatalf("err: %v", err) - } - if err := s.EnsureService(2, "foo", &structs.NodeService{ID: "db1", Service: "db", Tags: []string{"master"}, Address: "", Port: 8000}); err != nil { - b.Fatalf("err: %v", err) - } - check := &structs.HealthCheck{ - Node: "foo", - CheckID: "db", - Name: "can connect", - Status: structs.HealthPassing, - ServiceID: "db1", - } - if err := s.EnsureCheck(3, check); err != nil { - b.Fatalf("err: %v", err) - } - check = &structs.HealthCheck{ - Node: "foo", - CheckID: "check1", - Name: "check1", - Status: structs.HealthPassing, - } - if err := s.EnsureCheck(4, check); err != nil { - b.Fatalf("err: %v", err) - } - - for i := 0; i < b.N; i++ { - s.CheckServiceNodes("db") - } -} - -func TestStateStore_CheckServiceTagNodes(t *testing.T) { - s := testStateStore(t) - - if err := s.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { - t.Fatalf("err: %v", err) - } - if err := s.EnsureService(2, "foo", &structs.NodeService{ID: "db1", Service: "db", Tags: []string{"master"}, Address: "", Port: 8000}); err != nil { - t.Fatalf("err: %v", err) - } - check := &structs.HealthCheck{ - Node: "foo", - CheckID: "db", - Name: "can connect", - Status: structs.HealthPassing, - ServiceID: "db1", - } - if err := s.EnsureCheck(3, check); err != nil { - t.Fatalf("err: %v", err) - } - check = &structs.HealthCheck{ - Node: "foo", - CheckID: "check1", - Name: "another check", - Status: structs.HealthPassing, - } - if err := s.EnsureCheck(4, check); err != nil { - t.Fatalf("err: %v", err) - } - - idx, nodes, err := s.CheckServiceTagNodes("db", "master") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 4 { - t.Fatalf("bad: %v", idx) - } - if len(nodes) != 1 { - t.Fatalf("Bad: %v", nodes) - } - if nodes[0].Node.Node != "foo" { - t.Fatalf("Bad: %v", nodes[0]) - } - if nodes[0].Service.ID != "db1" { - t.Fatalf("Bad: %v", nodes[0]) - } - if len(nodes[0].Checks) != 2 { - t.Fatalf("Bad: %v", nodes[0]) - } - if nodes[0].Checks[0].CheckID != "check1" { - t.Fatalf("Bad: %v", nodes[0]) - } - if nodes[0].Checks[1].CheckID != "db" { - t.Fatalf("Bad: %v", nodes[0]) - } -} - -func TestStateStore_Check_Snapshot(t *testing.T) { - s := testStateStore(t) - - // Create a node, a service, and a service check as well as a node check. - testRegisterNode(t, s, 0, "node1") - testRegisterService(t, s, 1, "node1", "service1") - checks := structs.HealthChecks{ - &structs.HealthCheck{ - Node: "node1", - CheckID: "check1", - Name: "node check", - Status: structs.HealthPassing, - }, - &structs.HealthCheck{ - Node: "node1", - CheckID: "check2", - Name: "service check", - Status: structs.HealthCritical, - ServiceID: "service1", - }, - } - for i, hc := range checks { - if err := s.EnsureCheck(uint64(i+1), hc); err != nil { - t.Fatalf("err: %s", err) - } - } - - // Create a second node/service to make sure node filtering works. This - // will affect the index but not the dump. - testRegisterNode(t, s, 3, "node2") - testRegisterService(t, s, 4, "node2", "service2") - testRegisterCheck(t, s, 5, "node2", "service2", "check3", structs.HealthPassing) - - // Snapshot the checks. - snap := s.Snapshot() - defer snap.Close() - - // Alter the real state store. - testRegisterCheck(t, s, 6, "node2", "service2", "check4", structs.HealthPassing) - - // Verify the snapshot. - if idx := snap.LastIndex(); idx != 5 { - t.Fatalf("bad index: %d", idx) - } - iter, err := snap.Checks("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - for i := 0; i < len(checks); i++ { - check := iter.Next().(*structs.HealthCheck) - if check == nil { - t.Fatalf("unexpected end of checks") - } - - checks[i].CreateIndex, checks[i].ModifyIndex = uint64(i+1), uint64(i+1) - if !reflect.DeepEqual(check, checks[i]) { - t.Fatalf("bad: %#v != %#v", check, checks[i]) - } - } - if iter.Next() != nil { - t.Fatalf("unexpected extra checks") - } -} - -func TestStateStore_Check_Watches(t *testing.T) { - s := testStateStore(t) - - testRegisterNode(t, s, 0, "node1") - hc := &structs.HealthCheck{ - Node: "node1", - CheckID: "check1", - Status: structs.HealthPassing, - } - - // Call functions that update the checks table and make sure a watch fires - // each time. - verifyWatch(t, s.getTableWatch("checks"), func() { - if err := s.EnsureCheck(1, hc); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("checks"), func() { - hc.Status = structs.HealthCritical - if err := s.EnsureCheck(2, hc); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("checks"), func() { - if err := s.DeleteCheck(3, "node1", "check1"); err != nil { - t.Fatalf("err: %s", err) - } - }) -} - -func TestStateStore_NodeInfo_NodeDump(t *testing.T) { - s := testStateStore(t) - - // Generating a node dump that matches nothing returns empty - idx, dump, err := s.NodeInfo("node1") - if idx != 0 || dump != nil || err != nil { - t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, dump, err) - } - idx, dump, err = s.NodeDump() - if idx != 0 || dump != nil || err != nil { - t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, dump, err) - } - - // Register some nodes - testRegisterNode(t, s, 0, "node1") - testRegisterNode(t, s, 1, "node2") - - // Register services against them - testRegisterService(t, s, 2, "node1", "service1") - testRegisterService(t, s, 3, "node1", "service2") - testRegisterService(t, s, 4, "node2", "service1") - testRegisterService(t, s, 5, "node2", "service2") - - // Register service-level checks - testRegisterCheck(t, s, 6, "node1", "service1", "check1", structs.HealthPassing) - testRegisterCheck(t, s, 7, "node2", "service1", "check1", structs.HealthPassing) - - // Register node-level checks - testRegisterCheck(t, s, 8, "node1", "", "check2", structs.HealthPassing) - testRegisterCheck(t, s, 9, "node2", "", "check2", structs.HealthPassing) - - // Check that our result matches what we expect. - expect := structs.NodeDump{ - &structs.NodeInfo{ - Node: "node1", - Checks: structs.HealthChecks{ - &structs.HealthCheck{ - Node: "node1", - CheckID: "check1", - ServiceID: "service1", - ServiceName: "service1", - Status: structs.HealthPassing, - RaftIndex: structs.RaftIndex{ - CreateIndex: 6, - ModifyIndex: 6, - }, - }, - &structs.HealthCheck{ - Node: "node1", - CheckID: "check2", - ServiceID: "", - ServiceName: "", - Status: structs.HealthPassing, - RaftIndex: structs.RaftIndex{ - CreateIndex: 8, - ModifyIndex: 8, - }, - }, - }, - Services: []*structs.NodeService{ - &structs.NodeService{ - ID: "service1", - Service: "service1", - Address: "1.1.1.1", - Port: 1111, - RaftIndex: structs.RaftIndex{ - CreateIndex: 2, - ModifyIndex: 2, - }, - }, - &structs.NodeService{ - ID: "service2", - Service: "service2", - Address: "1.1.1.1", - Port: 1111, - RaftIndex: structs.RaftIndex{ - CreateIndex: 3, - ModifyIndex: 3, - }, - }, - }, - }, - &structs.NodeInfo{ - Node: "node2", - Checks: structs.HealthChecks{ - &structs.HealthCheck{ - Node: "node2", - CheckID: "check1", - ServiceID: "service1", - ServiceName: "service1", - Status: structs.HealthPassing, - RaftIndex: structs.RaftIndex{ - CreateIndex: 7, - ModifyIndex: 7, - }, - }, - &structs.HealthCheck{ - Node: "node2", - CheckID: "check2", - ServiceID: "", - ServiceName: "", - Status: structs.HealthPassing, - RaftIndex: structs.RaftIndex{ - CreateIndex: 9, - ModifyIndex: 9, - }, - }, - }, - Services: []*structs.NodeService{ - &structs.NodeService{ - ID: "service1", - Service: "service1", - Address: "1.1.1.1", - Port: 1111, - RaftIndex: structs.RaftIndex{ - CreateIndex: 4, - ModifyIndex: 4, - }, - }, - &structs.NodeService{ - ID: "service2", - Service: "service2", - Address: "1.1.1.1", - Port: 1111, - RaftIndex: structs.RaftIndex{ - CreateIndex: 5, - ModifyIndex: 5, - }, - }, - }, - }, - } - - // Get a dump of just a single node - idx, dump, err = s.NodeInfo("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 9 { - t.Fatalf("bad index: %d", idx) - } - if len(dump) != 1 || !reflect.DeepEqual(dump[0], expect[0]) { - t.Fatalf("bad: %#v", dump) - } - - // Generate a dump of all the nodes - idx, dump, err = s.NodeDump() - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 9 { - t.Fatalf("bad index: %d", 9) - } - if !reflect.DeepEqual(dump, expect) { - t.Fatalf("bad: %#v", dump[0].Services[0]) - } -} - -func TestStateStore_SessionCreate_SessionGet(t *testing.T) { - s := testStateStore(t) - - // SessionGet returns nil if the session doesn't exist - idx, session, err := s.SessionGet(testUUID()) - if session != nil || err != nil { - t.Fatalf("expected (nil, nil), got: (%#v, %#v)", session, err) - } - if idx != 0 { - t.Fatalf("bad index: %d", idx) - } - - // Registering without a session ID is disallowed - err = s.SessionCreate(1, &structs.Session{}) - if err != ErrMissingSessionID { - t.Fatalf("expected %#v, got: %#v", ErrMissingSessionID, err) - } - - // Invalid session behavior throws error - sess := &structs.Session{ - ID: testUUID(), - Behavior: "nope", - } - err = s.SessionCreate(1, sess) - if err == nil || !strings.Contains(err.Error(), "session behavior") { - t.Fatalf("expected session behavior error, got: %#v", err) - } - - // Registering with an unknown node is disallowed - sess = &structs.Session{ID: testUUID()} - if err := s.SessionCreate(1, sess); err != ErrMissingNode { - t.Fatalf("expected %#v, got: %#v", ErrMissingNode, err) - } - - // None of the errored operations modified the index - if idx := s.maxIndex("sessions"); idx != 0 { - t.Fatalf("bad index: %d", idx) - } - - // Valid session is able to register - testRegisterNode(t, s, 1, "node1") - sess = &structs.Session{ - ID: testUUID(), - Node: "node1", - } - if err := s.SessionCreate(2, sess); err != nil { - t.Fatalf("err: %s", err) - } - if idx := s.maxIndex("sessions"); idx != 2 { - t.Fatalf("bad index: %s", err) - } - - // Retrieve the session again - idx, session, err = s.SessionGet(sess.ID) - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 2 { - t.Fatalf("bad index: %d", idx) - } - - // Ensure the session looks correct and was assigned the - // proper default value for session behavior. - expect := &structs.Session{ - ID: sess.ID, - Behavior: structs.SessionKeysRelease, - Node: "node1", - RaftIndex: structs.RaftIndex{ - CreateIndex: 2, - ModifyIndex: 2, - }, - } - if !reflect.DeepEqual(expect, session) { - t.Fatalf("bad session: %#v", session) - } - - // Registering with a non-existent check is disallowed - sess = &structs.Session{ - ID: testUUID(), - Node: "node1", - Checks: []types.CheckID{"check1"}, - } - err = s.SessionCreate(3, sess) - if err == nil || !strings.Contains(err.Error(), "Missing check") { - t.Fatalf("expected missing check error, got: %#v", err) - } - - // Registering with a critical check is disallowed - testRegisterCheck(t, s, 3, "node1", "", "check1", structs.HealthCritical) - err = s.SessionCreate(4, sess) - if err == nil || !strings.Contains(err.Error(), structs.HealthCritical) { - t.Fatalf("expected critical state error, got: %#v", err) - } - - // Registering with a healthy check succeeds - testRegisterCheck(t, s, 4, "node1", "", "check1", structs.HealthPassing) - if err := s.SessionCreate(5, sess); err != nil { - t.Fatalf("err: %s", err) - } - - // Register a session against two checks. - testRegisterCheck(t, s, 5, "node1", "", "check2", structs.HealthPassing) - sess2 := &structs.Session{ - ID: testUUID(), - Node: "node1", - Checks: []types.CheckID{"check1", "check2"}, - } - if err := s.SessionCreate(6, sess2); err != nil { - t.Fatalf("err: %s", err) - } - - tx := s.db.Txn(false) - defer tx.Abort() - - // Check mappings were inserted - { - check, err := tx.First("session_checks", "session", sess.ID) - if err != nil { - t.Fatalf("err: %s", err) - } - if check == nil { - t.Fatalf("missing session check") - } - expectCheck := &sessionCheck{ - Node: "node1", - CheckID: "check1", - Session: sess.ID, - } - if actual := check.(*sessionCheck); !reflect.DeepEqual(actual, expectCheck) { - t.Fatalf("expected %#v, got: %#v", expectCheck, actual) - } - } - checks, err := tx.Get("session_checks", "session", sess2.ID) - if err != nil { - t.Fatalf("err: %s", err) - } - for i, check := 0, checks.Next(); check != nil; i, check = i+1, checks.Next() { - expectCheck := &sessionCheck{ - Node: "node1", - CheckID: types.CheckID(fmt.Sprintf("check%d", i+1)), - Session: sess2.ID, - } - if actual := check.(*sessionCheck); !reflect.DeepEqual(actual, expectCheck) { - t.Fatalf("expected %#v, got: %#v", expectCheck, actual) - } - } - - // Pulling a nonexistent session gives the table index. - idx, session, err = s.SessionGet(testUUID()) - if err != nil { - t.Fatalf("err: %s", err) - } - if session != nil { - t.Fatalf("expected not to get a session: %v", session) - } - if idx != 6 { - t.Fatalf("bad index: %d", idx) - } -} - -func TegstStateStore_SessionList(t *testing.T) { - s := testStateStore(t) - - // Listing when no sessions exist returns nil - idx, res, err := s.SessionList() - if idx != 0 || res != nil || err != nil { - t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) - } - - // Register some nodes - testRegisterNode(t, s, 1, "node1") - testRegisterNode(t, s, 2, "node2") - testRegisterNode(t, s, 3, "node3") - - // Create some sessions in the state store - sessions := structs.Sessions{ - &structs.Session{ - ID: testUUID(), - Node: "node1", - Behavior: structs.SessionKeysDelete, - }, - &structs.Session{ - ID: testUUID(), - Node: "node2", - Behavior: structs.SessionKeysRelease, - }, - &structs.Session{ - ID: testUUID(), - Node: "node3", - Behavior: structs.SessionKeysDelete, - }, - } - for i, session := range sessions { - if err := s.SessionCreate(uint64(4+i), session); err != nil { - t.Fatalf("err: %s", err) - } - } - - // List out all of the sessions - idx, sessionList, err := s.SessionList() - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 6 { - t.Fatalf("bad index: %d", idx) - } - if !reflect.DeepEqual(sessionList, sessions) { - t.Fatalf("bad: %#v", sessions) - } -} - -func TestStateStore_NodeSessions(t *testing.T) { - s := testStateStore(t) - - // Listing sessions with no results returns nil - idx, res, err := s.NodeSessions("node1") - if idx != 0 || res != nil || err != nil { - t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) - } - - // Create the nodes - testRegisterNode(t, s, 1, "node1") - testRegisterNode(t, s, 2, "node2") - - // Register some sessions with the nodes - sessions1 := structs.Sessions{ - &structs.Session{ - ID: testUUID(), - Node: "node1", - }, - &structs.Session{ - ID: testUUID(), - Node: "node1", - }, - } - sessions2 := []*structs.Session{ - &structs.Session{ - ID: testUUID(), - Node: "node2", - }, - &structs.Session{ - ID: testUUID(), - Node: "node2", - }, - } - for i, sess := range append(sessions1, sessions2...) { - if err := s.SessionCreate(uint64(3+i), sess); err != nil { - t.Fatalf("err: %s", err) - } - } - - // Query all of the sessions associated with a specific - // node in the state store. - idx, res, err = s.NodeSessions("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - if len(res) != len(sessions1) { - t.Fatalf("bad: %#v", res) - } - if idx != 6 { - t.Fatalf("bad index: %d", idx) - } - - idx, res, err = s.NodeSessions("node2") - if err != nil { - t.Fatalf("err: %s", err) - } - if len(res) != len(sessions2) { - t.Fatalf("bad: %#v", res) - } - if idx != 6 { - t.Fatalf("bad index: %d", idx) - } -} - -func TestStateStore_SessionDestroy(t *testing.T) { - s := testStateStore(t) - - // Session destroy is idempotent and returns no error - // if the session doesn't exist. - if err := s.SessionDestroy(1, testUUID()); err != nil { - t.Fatalf("err: %s", err) - } - - // Ensure the index was not updated if nothing was destroyed. - if idx := s.maxIndex("sessions"); idx != 0 { - t.Fatalf("bad index: %d", idx) - } - - // Register a node. - testRegisterNode(t, s, 1, "node1") - - // Register a new session - sess := &structs.Session{ - ID: testUUID(), - Node: "node1", - } - if err := s.SessionCreate(2, sess); err != nil { - t.Fatalf("err: %s", err) - } - - // Destroy the session. - if err := s.SessionDestroy(3, sess.ID); err != nil { - t.Fatalf("err: %s", err) - } - - // Check that the index was updated - if idx := s.maxIndex("sessions"); idx != 3 { - t.Fatalf("bad index: %d", idx) - } - - // Make sure the session is really gone. - tx := s.db.Txn(false) - sessions, err := tx.Get("sessions", "id") - if err != nil || sessions.Next() != nil { - t.Fatalf("session should not exist") - } - tx.Abort() -} - -func TestStateStore_Session_Snapshot_Restore(t *testing.T) { - s := testStateStore(t) - - // Register some nodes and checks. - testRegisterNode(t, s, 1, "node1") - testRegisterNode(t, s, 2, "node2") - testRegisterNode(t, s, 3, "node3") - testRegisterCheck(t, s, 4, "node1", "", "check1", structs.HealthPassing) - - // Create some sessions in the state store. - session1 := testUUID() - sessions := structs.Sessions{ - &structs.Session{ - ID: session1, - Node: "node1", - Behavior: structs.SessionKeysDelete, - Checks: []types.CheckID{"check1"}, - }, - &structs.Session{ - ID: testUUID(), - Node: "node2", - Behavior: structs.SessionKeysRelease, - LockDelay: 10 * time.Second, - }, - &structs.Session{ - ID: testUUID(), - Node: "node3", - Behavior: structs.SessionKeysDelete, - TTL: "1.5s", - }, - } - for i, session := range sessions { - if err := s.SessionCreate(uint64(5+i), session); err != nil { - t.Fatalf("err: %s", err) - } - } - - // Snapshot the sessions. - snap := s.Snapshot() - defer snap.Close() - - // Alter the real state store. - if err := s.SessionDestroy(8, session1); err != nil { - t.Fatalf("err: %s", err) - } - - // Verify the snapshot. - if idx := snap.LastIndex(); idx != 7 { - t.Fatalf("bad index: %d", idx) - } - iter, err := snap.Sessions() - if err != nil { - t.Fatalf("err: %s", err) - } - var dump structs.Sessions - for session := iter.Next(); session != nil; session = iter.Next() { - sess := session.(*structs.Session) - dump = append(dump, sess) - - found := false - for i, _ := range sessions { - if sess.ID == sessions[i].ID { - if !reflect.DeepEqual(sess, sessions[i]) { - t.Fatalf("bad: %#v", sess) - } - found = true - } - } - if !found { - t.Fatalf("bad: %#v", sess) - } - } - - // Restore the sessions into a new state store. - func() { - s := testStateStore(t) - restore := s.Restore() - for _, session := range dump { - if err := restore.Session(session); err != nil { - t.Fatalf("err: %s", err) - } - } - restore.Commit() - - // Read the restored sessions back out and verify that they - // match. - idx, res, err := s.SessionList() - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 7 { - t.Fatalf("bad index: %d", idx) - } - for _, sess := range res { - found := false - for i, _ := range sessions { - if sess.ID == sessions[i].ID { - if !reflect.DeepEqual(sess, sessions[i]) { - t.Fatalf("bad: %#v", sess) - } - found = true - } - } - if !found { - t.Fatalf("bad: %#v", sess) - } - } - - // Check that the index was updated. - if idx := s.maxIndex("sessions"); idx != 7 { - t.Fatalf("bad index: %d", idx) - } - - // Manually verify that the session check mapping got restored. - tx := s.db.Txn(false) - defer tx.Abort() - - check, err := tx.First("session_checks", "session", session1) - if err != nil { - t.Fatalf("err: %s", err) - } - if check == nil { - t.Fatalf("missing session check") - } - expectCheck := &sessionCheck{ - Node: "node1", - CheckID: "check1", - Session: session1, - } - if actual := check.(*sessionCheck); !reflect.DeepEqual(actual, expectCheck) { - t.Fatalf("expected %#v, got: %#v", expectCheck, actual) - } - }() -} - -func TestStateStore_Session_Watches(t *testing.T) { - s := testStateStore(t) - - // Register a test node. - testRegisterNode(t, s, 1, "node1") - - // This just covers the basics. The session invalidation tests above - // cover the more nuanced multiple table watches. - session := testUUID() - verifyWatch(t, s.getTableWatch("sessions"), func() { - sess := &structs.Session{ - ID: session, - Node: "node1", - Behavior: structs.SessionKeysDelete, - } - if err := s.SessionCreate(2, sess); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("sessions"), func() { - if err := s.SessionDestroy(3, session); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("sessions"), func() { - restore := s.Restore() - sess := &structs.Session{ - ID: session, - Node: "node1", - Behavior: structs.SessionKeysDelete, - } - if err := restore.Session(sess); err != nil { - t.Fatalf("err: %s", err) - } - restore.Commit() - }) -} - -func TestStateStore_Session_Invalidate_DeleteNode(t *testing.T) { - s := testStateStore(t) - - // Set up our test environment. - if err := s.EnsureNode(3, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { - t.Fatalf("err: %v", err) - } - session := &structs.Session{ - ID: testUUID(), - Node: "foo", - } - if err := s.SessionCreate(14, session); err != nil { - t.Fatalf("err: %v", err) - } - - // Delete the node and make sure the watches fire. - verifyWatch(t, s.getTableWatch("sessions"), func() { - verifyWatch(t, s.getTableWatch("nodes"), func() { - if err := s.DeleteNode(15, "foo"); err != nil { - t.Fatalf("err: %v", err) - } - }) - }) - - // Lookup by ID, should be nil. - idx, s2, err := s.SessionGet(session.ID) - if err != nil { - t.Fatalf("err: %v", err) - } - if s2 != nil { - t.Fatalf("session should be invalidated") - } - if idx != 15 { - t.Fatalf("bad index: %d", idx) - } -} - -func TestStateStore_Session_Invalidate_DeleteService(t *testing.T) { - s := testStateStore(t) - - // Set up our test environment. - if err := s.EnsureNode(11, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { - t.Fatalf("err: %v", err) - } - if err := s.EnsureService(12, "foo", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}); err != nil { - t.Fatalf("err: %v", err) - } - check := &structs.HealthCheck{ - Node: "foo", - CheckID: "api", - Name: "Can connect", - Status: structs.HealthPassing, - ServiceID: "api", - } - if err := s.EnsureCheck(13, check); err != nil { - t.Fatalf("err: %v", err) - } - session := &structs.Session{ - ID: testUUID(), - Node: "foo", - Checks: []types.CheckID{"api"}, - } - if err := s.SessionCreate(14, session); err != nil { - t.Fatalf("err: %v", err) - } - - // Delete the service and make sure the watches fire. - verifyWatch(t, s.getTableWatch("sessions"), func() { - verifyWatch(t, s.getTableWatch("services"), func() { - verifyWatch(t, s.getTableWatch("checks"), func() { - if err := s.DeleteService(15, "foo", "api"); err != nil { - t.Fatalf("err: %v", err) - } - }) - }) - }) - - // Lookup by ID, should be nil. - idx, s2, err := s.SessionGet(session.ID) - if err != nil { - t.Fatalf("err: %v", err) - } - if s2 != nil { - t.Fatalf("session should be invalidated") - } - if idx != 15 { - t.Fatalf("bad index: %d", idx) - } -} - -func TestStateStore_Session_Invalidate_Critical_Check(t *testing.T) { - s := testStateStore(t) - - // Set up our test environment. - if err := s.EnsureNode(3, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { - t.Fatalf("err: %v", err) - } - check := &structs.HealthCheck{ - Node: "foo", - CheckID: "bar", - Status: structs.HealthPassing, - } - if err := s.EnsureCheck(13, check); err != nil { - t.Fatalf("err: %v", err) - } - session := &structs.Session{ - ID: testUUID(), - Node: "foo", - Checks: []types.CheckID{"bar"}, - } - if err := s.SessionCreate(14, session); err != nil { - t.Fatalf("err: %v", err) - } - - // Invalidate the check and make sure the watches fire. - verifyWatch(t, s.getTableWatch("sessions"), func() { - verifyWatch(t, s.getTableWatch("checks"), func() { - check.Status = structs.HealthCritical - if err := s.EnsureCheck(15, check); err != nil { - t.Fatalf("err: %v", err) - } - }) - }) - - // Lookup by ID, should be nil. - idx, s2, err := s.SessionGet(session.ID) - if err != nil { - t.Fatalf("err: %v", err) - } - if s2 != nil { - t.Fatalf("session should be invalidated") - } - if idx != 15 { - t.Fatalf("bad index: %d", idx) - } -} - -func TestStateStore_Session_Invalidate_DeleteCheck(t *testing.T) { - s := testStateStore(t) - - // Set up our test environment. - if err := s.EnsureNode(3, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { - t.Fatalf("err: %v", err) - } - check := &structs.HealthCheck{ - Node: "foo", - CheckID: "bar", - Status: structs.HealthPassing, - } - if err := s.EnsureCheck(13, check); err != nil { - t.Fatalf("err: %v", err) - } - session := &structs.Session{ - ID: testUUID(), - Node: "foo", - Checks: []types.CheckID{"bar"}, - } - if err := s.SessionCreate(14, session); err != nil { - t.Fatalf("err: %v", err) - } - - // Delete the check and make sure the watches fire. - verifyWatch(t, s.getTableWatch("sessions"), func() { - verifyWatch(t, s.getTableWatch("checks"), func() { - if err := s.DeleteCheck(15, "foo", "bar"); err != nil { - t.Fatalf("err: %v", err) - } - }) - }) - - // Lookup by ID, should be nil. - idx, s2, err := s.SessionGet(session.ID) - if err != nil { - t.Fatalf("err: %v", err) - } - if s2 != nil { - t.Fatalf("session should be invalidated") - } - if idx != 15 { - t.Fatalf("bad index: %d", idx) - } - - // Manually make sure the session checks mapping is clear. - tx := s.db.Txn(false) - mapping, err := tx.First("session_checks", "session", session.ID) - if err != nil { - t.Fatalf("err: %s", err) - } - if mapping != nil { - t.Fatalf("unexpected session check") - } - tx.Abort() -} - -func TestStateStore_Session_Invalidate_Key_Unlock_Behavior(t *testing.T) { - s := testStateStore(t) - - // Set up our test environment. - if err := s.EnsureNode(3, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { - t.Fatalf("err: %v", err) - } - session := &structs.Session{ - ID: testUUID(), - Node: "foo", - LockDelay: 50 * time.Millisecond, - } - if err := s.SessionCreate(4, session); err != nil { - t.Fatalf("err: %v", err) - } - - // Lock a key with the session. - d := &structs.DirEntry{ - Key: "/foo", - Flags: 42, - Value: []byte("test"), - Session: session.ID, - } - ok, err := s.KVSLock(5, d) - if err != nil { - t.Fatalf("err: %v", err) - } - if !ok { - t.Fatalf("unexpected fail") - } - - // Delete the node and make sure the watches fire. - verifyWatch(t, s.getTableWatch("sessions"), func() { - verifyWatch(t, s.getTableWatch("nodes"), func() { - verifyWatch(t, s.GetKVSWatch("/f"), func() { - if err := s.DeleteNode(6, "foo"); err != nil { - t.Fatalf("err: %v", err) - } - }) - }) - }) - - // Lookup by ID, should be nil. - idx, s2, err := s.SessionGet(session.ID) - if err != nil { - t.Fatalf("err: %v", err) - } - if s2 != nil { - t.Fatalf("session should be invalidated") - } - if idx != 6 { - t.Fatalf("bad index: %d", idx) - } - - // Key should be unlocked. - idx, d2, err := s.KVSGet("/foo") - if err != nil { - t.Fatalf("err: %s", err) - } - if d2.ModifyIndex != 6 { - t.Fatalf("bad index: %v", d2.ModifyIndex) - } - if d2.LockIndex != 1 { - t.Fatalf("bad: %v", *d2) - } - if d2.Session != "" { - t.Fatalf("bad: %v", *d2) - } - if idx != 6 { - t.Fatalf("bad index: %d", idx) - } - - // Key should have a lock delay. - expires := s.KVSLockDelay("/foo") - if expires.Before(time.Now().Add(30 * time.Millisecond)) { - t.Fatalf("Bad: %v", expires) - } -} - -func TestStateStore_Session_Invalidate_Key_Delete_Behavior(t *testing.T) { - s := testStateStore(t) - - // Set up our test environment. - if err := s.EnsureNode(3, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { - t.Fatalf("err: %v", err) - } - session := &structs.Session{ - ID: testUUID(), - Node: "foo", - LockDelay: 50 * time.Millisecond, - Behavior: structs.SessionKeysDelete, - } - if err := s.SessionCreate(4, session); err != nil { - t.Fatalf("err: %v", err) - } - - // Lock a key with the session. - d := &structs.DirEntry{ - Key: "/bar", - Flags: 42, - Value: []byte("test"), - Session: session.ID, - } - ok, err := s.KVSLock(5, d) - if err != nil { - t.Fatalf("err: %v", err) - } - if !ok { - t.Fatalf("unexpected fail") - } - - // Delete the node and make sure the watches fire. - verifyWatch(t, s.getTableWatch("sessions"), func() { - verifyWatch(t, s.getTableWatch("nodes"), func() { - verifyWatch(t, s.GetKVSWatch("/b"), func() { - if err := s.DeleteNode(6, "foo"); err != nil { - t.Fatalf("err: %v", err) - } - }) - }) - }) - - // Lookup by ID, should be nil. - idx, s2, err := s.SessionGet(session.ID) - if err != nil { - t.Fatalf("err: %v", err) - } - if s2 != nil { - t.Fatalf("session should be invalidated") - } - if idx != 6 { - t.Fatalf("bad index: %d", idx) - } - - // Key should be deleted. - idx, d2, err := s.KVSGet("/bar") - if err != nil { - t.Fatalf("err: %s", err) - } - if d2 != nil { - t.Fatalf("unexpected deleted key") - } - if idx != 6 { - t.Fatalf("bad index: %d", idx) - } - - // Key should have a lock delay. - expires := s.KVSLockDelay("/bar") - if expires.Before(time.Now().Add(30 * time.Millisecond)) { - t.Fatalf("Bad: %v", expires) - } -} - -func TestStateStore_Session_Invalidate_PreparedQuery_Delete(t *testing.T) { - s := testStateStore(t) - - // Set up our test environment. - testRegisterNode(t, s, 1, "foo") - testRegisterService(t, s, 2, "foo", "redis") - session := &structs.Session{ - ID: testUUID(), - Node: "foo", - } - if err := s.SessionCreate(3, session); err != nil { - t.Fatalf("err: %v", err) - } - query := &structs.PreparedQuery{ - ID: testUUID(), - Session: session.ID, - Service: structs.ServiceQuery{ - Service: "redis", - }, - } - if err := s.PreparedQuerySet(4, query); err != nil { - t.Fatalf("err: %s", err) - } - - // Invalidate the session and make sure the watches fire. - verifyWatch(t, s.getTableWatch("sessions"), func() { - verifyWatch(t, s.getTableWatch("prepared-queries"), func() { - if err := s.SessionDestroy(5, session.ID); err != nil { - t.Fatalf("err: %v", err) - } - }) - }) - - // Make sure the session is gone. - idx, s2, err := s.SessionGet(session.ID) - if err != nil { - t.Fatalf("err: %v", err) - } - if s2 != nil { - t.Fatalf("session should be invalidated") - } - if idx != 5 { - t.Fatalf("bad index: %d", idx) - } - - // Make sure the query is gone and the index is updated. - idx, q2, err := s.PreparedQueryGet(query.ID) - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 5 { - t.Fatalf("bad index: %d", idx) - } - if q2 != nil { - t.Fatalf("bad: %v", q2) - } -} - -func TestStateStore_ACLSet_ACLGet(t *testing.T) { - s := testStateStore(t) - - // Querying ACLs with no results returns nil - idx, res, err := s.ACLGet("nope") - if idx != 0 || res != nil || err != nil { - t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) - } - - // Inserting an ACL with empty ID is disallowed - if err := s.ACLSet(1, &structs.ACL{}); err == nil { - t.Fatalf("expected %#v, got: %#v", ErrMissingACLID, err) - } - - // Index is not updated if nothing is saved - if idx := s.maxIndex("acls"); idx != 0 { - t.Fatalf("bad index: %d", idx) - } - - // Inserting valid ACL works - acl := &structs.ACL{ - ID: "acl1", - Name: "First ACL", - Type: structs.ACLTypeClient, - Rules: "rules1", - } - if err := s.ACLSet(1, acl); err != nil { - t.Fatalf("err: %s", err) - } - - // Check that the index was updated - if idx := s.maxIndex("acls"); idx != 1 { - t.Fatalf("bad index: %d", idx) - } - - // Retrieve the ACL again - idx, result, err := s.ACLGet("acl1") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 1 { - t.Fatalf("bad index: %d", idx) - } - - // Check that the ACL matches the result - expect := &structs.ACL{ - ID: "acl1", - Name: "First ACL", - Type: structs.ACLTypeClient, - Rules: "rules1", - RaftIndex: structs.RaftIndex{ - CreateIndex: 1, - ModifyIndex: 1, - }, - } - if !reflect.DeepEqual(result, expect) { - t.Fatalf("bad: %#v", result) - } - - // Update the ACL - acl = &structs.ACL{ - ID: "acl1", - Name: "First ACL", - Type: structs.ACLTypeClient, - Rules: "rules2", - } - if err := s.ACLSet(2, acl); err != nil { - t.Fatalf("err: %s", err) - } - - // Index was updated - if idx := s.maxIndex("acls"); idx != 2 { - t.Fatalf("bad: %d", idx) - } - - // ACL was updated and matches expected value - expect = &structs.ACL{ - ID: "acl1", - Name: "First ACL", - Type: structs.ACLTypeClient, - Rules: "rules2", - RaftIndex: structs.RaftIndex{ - CreateIndex: 1, - ModifyIndex: 2, - }, - } - if !reflect.DeepEqual(acl, expect) { - t.Fatalf("bad: %#v", acl) - } -} - -func TestStateStore_ACLList(t *testing.T) { - s := testStateStore(t) - - // Listing when no ACLs exist returns nil - idx, res, err := s.ACLList() - if idx != 0 || res != nil || err != nil { - t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) - } - - // Insert some ACLs - acls := structs.ACLs{ - &structs.ACL{ - ID: "acl1", - Type: structs.ACLTypeClient, - Rules: "rules1", - RaftIndex: structs.RaftIndex{ - CreateIndex: 1, - ModifyIndex: 1, - }, - }, - &structs.ACL{ - ID: "acl2", - Type: structs.ACLTypeClient, - Rules: "rules2", - RaftIndex: structs.RaftIndex{ - CreateIndex: 2, - ModifyIndex: 2, - }, - }, - } - for _, acl := range acls { - if err := s.ACLSet(acl.ModifyIndex, acl); err != nil { - t.Fatalf("err: %s", err) - } - } - - // Query the ACLs - idx, res, err = s.ACLList() - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 2 { - t.Fatalf("bad index: %d", idx) - } - - // Check that the result matches - if !reflect.DeepEqual(res, acls) { - t.Fatalf("bad: %#v", res) - } -} - -func TestStateStore_ACLDelete(t *testing.T) { - s := testStateStore(t) - - // Calling delete on an ACL which doesn't exist returns nil - if err := s.ACLDelete(1, "nope"); err != nil { - t.Fatalf("err: %s", err) - } - - // Index isn't updated if nothing is deleted - if idx := s.maxIndex("acls"); idx != 0 { - t.Fatalf("bad index: %d", idx) - } - - // Insert an ACL - if err := s.ACLSet(1, &structs.ACL{ID: "acl1"}); err != nil { - t.Fatalf("err: %s", err) - } - - // Delete the ACL and check that the index was updated - if err := s.ACLDelete(2, "acl1"); err != nil { - t.Fatalf("err: %s", err) - } - if idx := s.maxIndex("acls"); idx != 2 { - t.Fatalf("bad index: %d", idx) - } - - tx := s.db.Txn(false) - defer tx.Abort() - - // Check that the ACL was really deleted - result, err := tx.First("acls", "id", "acl1") - if err != nil { - t.Fatalf("err: %s", err) - } - if result != nil { - t.Fatalf("expected nil, got: %#v", result) - } -} - -func TestStateStore_ACL_Snapshot_Restore(t *testing.T) { - s := testStateStore(t) - - // Insert some ACLs. - acls := structs.ACLs{ - &structs.ACL{ - ID: "acl1", - Type: structs.ACLTypeClient, - Rules: "rules1", - RaftIndex: structs.RaftIndex{ - CreateIndex: 1, - ModifyIndex: 1, - }, - }, - &structs.ACL{ - ID: "acl2", - Type: structs.ACLTypeClient, - Rules: "rules2", - RaftIndex: structs.RaftIndex{ - CreateIndex: 2, - ModifyIndex: 2, - }, - }, - } - for _, acl := range acls { - if err := s.ACLSet(acl.ModifyIndex, acl); err != nil { - t.Fatalf("err: %s", err) - } - } - - // Snapshot the ACLs. - snap := s.Snapshot() - defer snap.Close() - - // Alter the real state store. - if err := s.ACLDelete(3, "acl1"); err != nil { - t.Fatalf("err: %s", err) - } - - // Verify the snapshot. - if idx := snap.LastIndex(); idx != 2 { - t.Fatalf("bad index: %d", idx) - } - iter, err := snap.ACLs() - if err != nil { - t.Fatalf("err: %s", err) - } - var dump structs.ACLs - for acl := iter.Next(); acl != nil; acl = iter.Next() { - dump = append(dump, acl.(*structs.ACL)) - } - if !reflect.DeepEqual(dump, acls) { - t.Fatalf("bad: %#v", dump) - } - - // Restore the values into a new state store. - func() { - s := testStateStore(t) - restore := s.Restore() - for _, acl := range dump { - if err := restore.ACL(acl); err != nil { - t.Fatalf("err: %s", err) - } - } - restore.Commit() - - // Read the restored ACLs back out and verify that they match. - idx, res, err := s.ACLList() - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 2 { - t.Fatalf("bad index: %d", idx) - } - if !reflect.DeepEqual(res, acls) { - t.Fatalf("bad: %#v", res) - } - - // Check that the index was updated. - if idx := s.maxIndex("acls"); idx != 2 { - t.Fatalf("bad index: %d", idx) - } - }() -} - -func TestStateStore_ACL_Watches(t *testing.T) { - s := testStateStore(t) - - // Call functions that update the acls table and make sure a watch fires - // each time. - verifyWatch(t, s.getTableWatch("acls"), func() { - if err := s.ACLSet(1, &structs.ACL{ID: "acl1"}); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("acls"), func() { - if err := s.ACLDelete(2, "acl1"); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("acls"), func() { - restore := s.Restore() - if err := restore.ACL(&structs.ACL{ID: "acl1"}); err != nil { - t.Fatalf("err: %s", err) - } - restore.Commit() - }) -} - -// generateRandomCoordinate creates a random coordinate. This mucks with the -// underlying structure directly, so it's not really useful for any particular -// position in the network, but it's a good payload to send through to make -// sure things come out the other side or get stored correctly. -func generateRandomCoordinate() *coordinate.Coordinate { - config := coordinate.DefaultConfig() - coord := coordinate.NewCoordinate(config) - for i := range coord.Vec { - coord.Vec[i] = rand.NormFloat64() - } - coord.Error = rand.NormFloat64() - coord.Adjustment = rand.NormFloat64() - return coord -} - -func TestStateStore_Coordinate_Updates(t *testing.T) { - s := testStateStore(t) - - // Make sure the coordinates list starts out empty, and that a query for - // a raw coordinate for a nonexistent node doesn't do anything bad. - idx, coords, err := s.Coordinates() - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 0 { - t.Fatalf("bad index: %d", idx) - } - if coords != nil { - t.Fatalf("bad: %#v", coords) - } - coord, err := s.CoordinateGetRaw("nope") - if err != nil { - t.Fatalf("err: %s", err) - } - if coord != nil { - t.Fatalf("bad: %#v", coord) - } - - // Make an update for nodes that don't exist and make sure they get - // ignored. - updates := structs.Coordinates{ - &structs.Coordinate{ - Node: "node1", - Coord: generateRandomCoordinate(), - }, - &structs.Coordinate{ - Node: "node2", - Coord: generateRandomCoordinate(), - }, - } - if err := s.CoordinateBatchUpdate(1, updates); err != nil { - t.Fatalf("err: %s", err) - } - - // Should still be empty, though applying an empty batch does bump - // the table index. - idx, coords, err = s.Coordinates() - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 1 { - t.Fatalf("bad index: %d", idx) - } - if coords != nil { - t.Fatalf("bad: %#v", coords) - } - - // Register the nodes then do the update again. - testRegisterNode(t, s, 1, "node1") - testRegisterNode(t, s, 2, "node2") - if err := s.CoordinateBatchUpdate(3, updates); err != nil { - t.Fatalf("err: %s", err) - } - - // Should go through now. - idx, coords, err = s.Coordinates() - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 3 { - t.Fatalf("bad index: %d", idx) - } - if !reflect.DeepEqual(coords, updates) { - t.Fatalf("bad: %#v", coords) - } - - // Also verify the raw coordinate interface. - for _, update := range updates { - coord, err := s.CoordinateGetRaw(update.Node) - if err != nil { - t.Fatalf("err: %s", err) - } - if !reflect.DeepEqual(coord, update.Coord) { - t.Fatalf("bad: %#v", coord) - } - } - - // Update the coordinate for one of the nodes. - updates[1].Coord = generateRandomCoordinate() - if err := s.CoordinateBatchUpdate(4, updates); err != nil { - t.Fatalf("err: %s", err) - } - - // Verify it got applied. - idx, coords, err = s.Coordinates() - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 4 { - t.Fatalf("bad index: %d", idx) - } - if !reflect.DeepEqual(coords, updates) { - t.Fatalf("bad: %#v", coords) - } - - // And check the raw coordinate version of the same thing. - for _, update := range updates { - coord, err := s.CoordinateGetRaw(update.Node) - if err != nil { - t.Fatalf("err: %s", err) - } - if !reflect.DeepEqual(coord, update.Coord) { - t.Fatalf("bad: %#v", coord) - } - } -} - -func TestStateStore_Coordinate_Cleanup(t *testing.T) { - s := testStateStore(t) - - // Register a node and update its coordinate. - testRegisterNode(t, s, 1, "node1") - updates := structs.Coordinates{ - &structs.Coordinate{ - Node: "node1", - Coord: generateRandomCoordinate(), - }, - } - if err := s.CoordinateBatchUpdate(2, updates); err != nil { - t.Fatalf("err: %s", err) - } - - // Make sure it's in there. - coord, err := s.CoordinateGetRaw("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - if !reflect.DeepEqual(coord, updates[0].Coord) { - t.Fatalf("bad: %#v", coord) - } - - // Now delete the node. - if err := s.DeleteNode(3, "node1"); err != nil { - t.Fatalf("err: %s", err) - } - - // Make sure the coordinate is gone. - coord, err = s.CoordinateGetRaw("node1") - if err != nil { - t.Fatalf("err: %s", err) - } - if coord != nil { - t.Fatalf("bad: %#v", coord) - } - - // Make sure the index got updated. - idx, coords, err := s.Coordinates() - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 3 { - t.Fatalf("bad index: %d", idx) - } - if coords != nil { - t.Fatalf("bad: %#v", coords) - } -} - -func TestStateStore_Coordinate_Snapshot_Restore(t *testing.T) { - s := testStateStore(t) - - // Register two nodes and update their coordinates. - testRegisterNode(t, s, 1, "node1") - testRegisterNode(t, s, 2, "node2") - updates := structs.Coordinates{ - &structs.Coordinate{ - Node: "node1", - Coord: generateRandomCoordinate(), - }, - &structs.Coordinate{ - Node: "node2", - Coord: generateRandomCoordinate(), - }, - } - if err := s.CoordinateBatchUpdate(3, updates); err != nil { - t.Fatalf("err: %s", err) - } - - // Snapshot the coordinates. - snap := s.Snapshot() - defer snap.Close() - - // Alter the real state store. - trash := structs.Coordinates{ - &structs.Coordinate{ - Node: "node1", - Coord: generateRandomCoordinate(), - }, - &structs.Coordinate{ - Node: "node2", - Coord: generateRandomCoordinate(), - }, - } - if err := s.CoordinateBatchUpdate(4, trash); err != nil { - t.Fatalf("err: %s", err) - } - - // Verify the snapshot. - if idx := snap.LastIndex(); idx != 3 { - t.Fatalf("bad index: %d", idx) - } - iter, err := snap.Coordinates() - if err != nil { - t.Fatalf("err: %s", err) - } - var dump structs.Coordinates - for coord := iter.Next(); coord != nil; coord = iter.Next() { - dump = append(dump, coord.(*structs.Coordinate)) - } - if !reflect.DeepEqual(dump, updates) { - t.Fatalf("bad: %#v", dump) - } - - // Restore the values into a new state store. - func() { - s := testStateStore(t) - restore := s.Restore() - if err := restore.Coordinates(5, dump); err != nil { - t.Fatalf("err: %s", err) - } - restore.Commit() - - // Read the restored coordinates back out and verify that they match. - idx, res, err := s.Coordinates() - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 5 { - t.Fatalf("bad index: %d", idx) - } - if !reflect.DeepEqual(res, updates) { - t.Fatalf("bad: %#v", res) - } - - // Check that the index was updated (note that it got passed - // in during the restore). - if idx := s.maxIndex("coordinates"); idx != 5 { - t.Fatalf("bad index: %d", idx) - } - }() - -} - -func TestStateStore_Coordinate_Watches(t *testing.T) { - s := testStateStore(t) - - testRegisterNode(t, s, 1, "node1") - - // Call functions that update the coordinates table and make sure a watch fires - // each time. - verifyWatch(t, s.getTableWatch("coordinates"), func() { - updates := structs.Coordinates{ - &structs.Coordinate{ - Node: "node1", - Coord: generateRandomCoordinate(), - }, - } - if err := s.CoordinateBatchUpdate(2, updates); err != nil { - t.Fatalf("err: %s", err) - } - }) - verifyWatch(t, s.getTableWatch("coordinates"), func() { - if err := s.DeleteNode(3, "node1"); err != nil { - t.Fatalf("err: %s", err) - } - }) -}