From 009fd7d9f57d968c5cbbd5f40c967df24f83d096 Mon Sep 17 00:00:00 2001 From: James Phillips Date: Sun, 20 Sep 2015 01:36:39 -0700 Subject: [PATCH] Integrates new state store for ACLs. --- consul/acl.go | 4 +- consul/acl_endpoint.go | 20 ++-- consul/acl_endpoint_test.go | 10 +- consul/fsm.go | 27 ++++- consul/fsm_test.go | 10 +- consul/leader.go | 6 +- consul/rpc.go | 70 ++++++++++++ consul/state/notify.go | 55 +++++++++ consul/state/notify_test.go | 72 ++++++++++++ consul/state/state_store.go | 187 +++++++++++++++++++++++++------ consul/state/state_store_test.go | 65 ++++++++++- consul/state/watch.go | 35 ++++++ consul/state_store.go | 137 +--------------------- consul/state_store_test.go | 186 ------------------------------ 14 files changed, 492 insertions(+), 392 deletions(-) create mode 100644 consul/state/notify.go create mode 100644 consul/state/notify_test.go create mode 100644 consul/state/watch.go diff --git a/consul/acl.go b/consul/acl.go index 9d2c1d94b0..f669f3f867 100644 --- a/consul/acl.go +++ b/consul/acl.go @@ -51,8 +51,8 @@ type aclCacheEntry struct { // aclFault is used to fault in the rules for an ACL if we take a miss func (s *Server) aclFault(id string) (string, string, error) { defer metrics.MeasureSince([]string{"consul", "acl", "fault"}, time.Now()) - state := s.fsm.State() - _, acl, err := state.ACLGet(id) + state := s.fsm.StateNew() + acl, err := state.ACLGet(id) if err != nil { return "", "", err } diff --git a/consul/acl_endpoint.go b/consul/acl_endpoint.go index f3c162b989..51fd1272c0 100644 --- a/consul/acl_endpoint.go +++ b/consul/acl_endpoint.go @@ -60,10 +60,10 @@ func (a *ACL) Apply(args *structs.ACLRequest, reply *string) error { // deterministic. Once the entry is in the log, the state update MUST // be deterministic or the followers will not converge. if args.ACL.ID == "" { - state := a.srv.fsm.State() + state := a.srv.fsm.StateNew() for { args.ACL.ID = generateUUID() - _, acl, err := state.ACLGet(args.ACL.ID) + acl, err := state.ACLGet(args.ACL.ID) if err != nil { a.srv.logger.Printf("[ERR] consul.acl: ACL lookup failed: %v", err) return err @@ -120,14 +120,14 @@ func (a *ACL) Get(args *structs.ACLSpecificRequest, } // Get the local state - state := a.srv.fsm.State() - return a.srv.blockingRPC(&args.QueryOptions, + state := a.srv.fsm.StateNew() + return a.srv.blockingRPCNew(&args.QueryOptions, &reply.QueryMeta, - state.QueryTables("ACLGet"), + state.GetWatchManager("acls"), func() error { - index, acl, err := state.ACLGet(args.ACL) - reply.Index = index + acl, err := state.ACLGet(args.ACL) if acl != nil { + reply.Index = acl.ModifyIndex reply.ACLs = structs.ACLs{acl} } else { reply.ACLs = nil @@ -191,10 +191,10 @@ func (a *ACL) List(args *structs.DCSpecificRequest, } // Get the local state - state := a.srv.fsm.State() - return a.srv.blockingRPC(&args.QueryOptions, + state := a.srv.fsm.StateNew() + return a.srv.blockingRPCNew(&args.QueryOptions, &reply.QueryMeta, - state.QueryTables("ACLList"), + state.GetWatchManager("acls"), func() error { var err error reply.Index, reply.ACLs, err = state.ACLList() diff --git a/consul/acl_endpoint_test.go b/consul/acl_endpoint_test.go index f162c90c0c..3ba0078e29 100644 --- a/consul/acl_endpoint_test.go +++ b/consul/acl_endpoint_test.go @@ -39,8 +39,8 @@ func TestACLEndpoint_Apply(t *testing.T) { id := out // Verify - state := s1.fsm.State() - _, s, err := state.ACLGet(out) + state := s1.fsm.StateNew() + s, err := state.ACLGet(out) if err != nil { t.Fatalf("err: %v", err) } @@ -62,7 +62,7 @@ func TestACLEndpoint_Apply(t *testing.T) { } // Verify - _, s, err = state.ACLGet(id) + s, err = state.ACLGet(id) if err != nil { t.Fatalf("err: %v", err) } @@ -180,8 +180,8 @@ func TestACLEndpoint_Apply_CustomID(t *testing.T) { } // Verify - state := s1.fsm.State() - _, s, err := state.ACLGet(out) + state := s1.fsm.StateNew() + s, err := state.ACLGet(out) if err != nil { t.Fatalf("err: %v", err) } diff --git a/consul/fsm.go b/consul/fsm.go index 58f07d4552..1e0fb73f8d 100644 --- a/consul/fsm.go +++ b/consul/fsm.go @@ -9,6 +9,7 @@ import ( "time" "github.com/armon/go-metrics" + "github.com/hashicorp/consul/consul/state" "github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/go-msgpack/codec" "github.com/hashicorp/raft" @@ -24,6 +25,7 @@ type consulFSM struct { logOutput io.Writer logger *log.Logger path string + stateNew *state.StateStore state *StateStore gc *TombstoneGC } @@ -32,7 +34,8 @@ type consulFSM struct { // state in a way that can be accessed concurrently with operations // that may modify the live state. type consulSnapshot struct { - state *StateSnapshot + state *StateSnapshot + stateNew *state.StateSnapshot } // snapshotHeader is the first entry in our snapshot @@ -44,6 +47,12 @@ type snapshotHeader struct { // NewFSMPath is used to construct a new FSM with a blank state func NewFSM(gc *TombstoneGC, path string, logOutput io.Writer) (*consulFSM, error) { + // Create the state store. + stateNew, err := state.NewStateStore(logOutput) + if err != nil { + return nil, err + } + // Create a temporary path for the state store tmpPath, err := ioutil.TempDir(path, "state") if err != nil { @@ -60,6 +69,7 @@ func NewFSM(gc *TombstoneGC, path string, logOutput io.Writer) (*consulFSM, erro logOutput: logOutput, logger: log.New(logOutput, "", log.LstdFlags), path: path, + stateNew: stateNew, state: state, gc: gc, } @@ -71,6 +81,11 @@ func (c *consulFSM) Close() error { return c.state.Close() } +// TODO(slackpad) +func (c *consulFSM) StateNew() *state.StateStore { + return c.stateNew +} + // State is used to return a handle to the current state func (c *consulFSM) State() *StateStore { return c.state @@ -234,13 +249,13 @@ func (c *consulFSM) applyACLOperation(buf []byte, index uint64) interface{} { defer metrics.MeasureSince([]string{"consul", "fsm", "acl", string(req.Op)}, time.Now()) switch req.Op { case structs.ACLForceSet, structs.ACLSet: - if err := c.state.ACLSet(index, &req.ACL); err != nil { + if err := c.stateNew.ACLSet(index, &req.ACL); err != nil { return err } else { return req.ACL.ID } case structs.ACLDelete: - return c.state.ACLDelete(index, req.ACL.ID) + return c.stateNew.ACLDelete(index, req.ACL.ID) default: c.logger.Printf("[WARN] consul.fsm: Invalid ACL operation '%s'", req.Op) return fmt.Errorf("Invalid ACL operation '%s'", req.Op) @@ -272,7 +287,7 @@ func (c *consulFSM) Snapshot() (raft.FSMSnapshot, error) { if err != nil { return nil, err } - return &consulSnapshot{snap}, nil + return &consulSnapshot{snap, c.stateNew.Snapshot()}, nil } func (c *consulFSM) Restore(old io.ReadCloser) error { @@ -344,7 +359,7 @@ func (c *consulFSM) Restore(old io.ReadCloser) error { if err := dec.Decode(&req); err != nil { return err } - if err := c.state.ACLRestore(&req); err != nil { + if err := c.stateNew.ACLRestore(&req); err != nil { return err } @@ -467,7 +482,7 @@ func (s *consulSnapshot) persistSessions(sink raft.SnapshotSink, func (s *consulSnapshot) persistACLs(sink raft.SnapshotSink, encoder *codec.Encoder) error { - acls, err := s.state.ACLList() + acls, err := s.stateNew.ACLList() if err != nil { return err } diff --git a/consul/fsm_test.go b/consul/fsm_test.go index 7795a9191a..d6d9925082 100644 --- a/consul/fsm_test.go +++ b/consul/fsm_test.go @@ -361,7 +361,7 @@ func TestFSM_SnapshotRestore(t *testing.T) { session := &structs.Session{ID: generateUUID(), Node: "foo"} fsm.state.SessionCreate(9, session) acl := &structs.ACL{ID: generateUUID(), Name: "User Token"} - fsm.state.ACLSet(10, acl) + fsm.stateNew.ACLSet(10, acl) fsm.state.KVSSet(11, &structs.DirEntry{ Key: "/remove", @@ -448,14 +448,14 @@ func TestFSM_SnapshotRestore(t *testing.T) { } // Verify ACL is restored - idx, a, err := fsm2.state.ACLGet(acl.ID) + a, err := fsm2.stateNew.ACLGet(acl.ID) if err != nil { t.Fatalf("err: %v", err) } if a.Name != "User Token" { t.Fatalf("bad: %v", a) } - if idx <= 1 { + if a.ModifyIndex <= 1 { t.Fatalf("bad index: %d", idx) } @@ -971,7 +971,7 @@ func TestFSM_ACL_Set_Delete(t *testing.T) { // Get the ACL id := resp.(string) - _, acl, err := fsm.state.ACLGet(id) + acl, err := fsm.stateNew.ACLGet(id) if err != nil { t.Fatalf("err: %v", err) } @@ -1007,7 +1007,7 @@ func TestFSM_ACL_Set_Delete(t *testing.T) { t.Fatalf("resp: %v", resp) } - _, acl, err = fsm.state.ACLGet(id) + acl, err = fsm.stateNew.ACLGet(id) if err != nil { t.Fatalf("err: %v", err) } diff --git a/consul/leader.go b/consul/leader.go index 67be5bb591..f8bb07cd9e 100644 --- a/consul/leader.go +++ b/consul/leader.go @@ -182,8 +182,8 @@ func (s *Server) initializeACL() error { s.aclAuthCache.Purge() // Look for the anonymous token - state := s.fsm.State() - _, acl, err := state.ACLGet(anonymousToken) + state := s.fsm.StateNew() + acl, err := state.ACLGet(anonymousToken) if err != nil { return fmt.Errorf("failed to get anonymous token: %v", err) } @@ -212,7 +212,7 @@ func (s *Server) initializeACL() error { } // Look for the master token - _, acl, err = state.ACLGet(master) + acl, err = state.ACLGet(master) if err != nil { return fmt.Errorf("failed to get master token: %v", err) } diff --git a/consul/rpc.go b/consul/rpc.go index 292f71949e..6860ddaeab 100644 --- a/consul/rpc.go +++ b/consul/rpc.go @@ -10,6 +10,7 @@ import ( "time" "github.com/armon/go-metrics" + "github.com/hashicorp/consul/consul/state" "github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/yamux" @@ -397,6 +398,75 @@ RUN_QUERY: return err } +// TODO(slackpad) +func (s *Server) blockingRPCNew(queryOpts *structs.QueryOptions, queryMeta *structs.QueryMeta, + watch state.WatchManager, run func() error) error { + var timeout *time.Timer + var notifyCh chan struct{} + + // Fast path right to the non-blocking query. + if queryOpts.MinQueryIndex == 0 { + goto RUN_QUERY + } + + // Make sure a watch manager was given if we were asked to block. + if watch == nil { + panic("no watch manager given for blocking query") + } + + // Restrict the max query time, and ensure there is always one. + if queryOpts.MaxQueryTime > maxQueryTime { + queryOpts.MaxQueryTime = maxQueryTime + } else if queryOpts.MaxQueryTime <= 0 { + queryOpts.MaxQueryTime = defaultQueryTime + } + + // Apply a small amount of jitter to the request. + queryOpts.MaxQueryTime += randomStagger(queryOpts.MaxQueryTime / jitterFraction) + + // Setup a query timeout. + timeout = time.NewTimer(queryOpts.MaxQueryTime) + + // Setup the notify channel. + notifyCh = make(chan struct{}, 1) + + // Ensure we tear down any watches on return. + defer func() { + timeout.Stop() + watch.Stop(notifyCh) + }() + +REGISTER_NOTIFY: + // Register the notification channel. This may be done multiple times if + // we haven't reached the target wait index. + watch.Start(notifyCh) + +RUN_QUERY: + // Update the query metadata. + s.setQueryMeta(queryMeta) + + // If the read must be consistent we verify that we are still the leader. + if queryOpts.RequireConsistent { + if err := s.consistentRead(); err != nil { + return err + } + } + + // Run the query. + metrics.IncrCounter([]string{"consul", "rpc", "query"}, 1) + err := run() + + // Check for minimum query time. + if err == nil && queryMeta.Index > 0 && queryMeta.Index <= queryOpts.MinQueryIndex { + select { + case <-notifyCh: + goto REGISTER_NOTIFY + case <-timeout.C: + } + } + return err +} + // setQueryMeta is used to populate the QueryMeta data for an RPC call func (s *Server) setQueryMeta(m *structs.QueryMeta) { if s.IsLeader() { diff --git a/consul/state/notify.go b/consul/state/notify.go new file mode 100644 index 0000000000..3b991a656a --- /dev/null +++ b/consul/state/notify.go @@ -0,0 +1,55 @@ +package state + +import ( + "sync" +) + +// NotifyGroup is used to allow a simple notification mechanism. +// Channels can be marked as waiting, and when notify is invoked, +// all the waiting channels get a message and are cleared from the +// notify list. +type NotifyGroup struct { + l sync.Mutex + notify map[chan struct{}]struct{} +} + +// Notify will do a non-blocking send to all waiting channels, and +// clear the notify list +func (n *NotifyGroup) Notify() { + n.l.Lock() + defer n.l.Unlock() + for ch, _ := range n.notify { + select { + case ch <- struct{}{}: + default: + } + } + n.notify = nil +} + +// Wait adds a channel to the notify group +func (n *NotifyGroup) Wait(ch chan struct{}) { + n.l.Lock() + defer n.l.Unlock() + if n.notify == nil { + n.notify = make(map[chan struct{}]struct{}) + } + n.notify[ch] = struct{}{} +} + +// Clear removes a channel from the notify group +func (n *NotifyGroup) Clear(ch chan struct{}) { + n.l.Lock() + defer n.l.Unlock() + if n.notify == nil { + return + } + delete(n.notify, ch) +} + +// WaitCh allocates a channel that is subscribed to notifications +func (n *NotifyGroup) WaitCh() chan struct{} { + ch := make(chan struct{}, 1) + n.Wait(ch) + return ch +} diff --git a/consul/state/notify_test.go b/consul/state/notify_test.go new file mode 100644 index 0000000000..34c14f46db --- /dev/null +++ b/consul/state/notify_test.go @@ -0,0 +1,72 @@ +package state + +import ( + "testing" +) + +func TestNotifyGroup(t *testing.T) { + grp := &NotifyGroup{} + + ch1 := grp.WaitCh() + ch2 := grp.WaitCh() + + select { + case <-ch1: + t.Fatalf("should block") + default: + } + select { + case <-ch2: + t.Fatalf("should block") + default: + } + + grp.Notify() + + select { + case <-ch1: + default: + t.Fatalf("should not block") + } + select { + case <-ch2: + default: + t.Fatalf("should not block") + } + + // Should be unregistered + ch3 := grp.WaitCh() + grp.Notify() + + select { + case <-ch1: + t.Fatalf("should block") + default: + } + select { + case <-ch2: + t.Fatalf("should block") + default: + } + select { + case <-ch3: + default: + t.Fatalf("should not block") + } +} + +func TestNotifyGroup_Clear(t *testing.T) { + grp := &NotifyGroup{} + + ch1 := grp.WaitCh() + grp.Clear(ch1) + + grp.Notify() + + // Should not get anything + select { + case <-ch1: + t.Fatalf("should not get message") + default: + } +} diff --git a/consul/state/state_store.go b/consul/state/state_store.go index 8731f051ef..7de24c13fa 100644 --- a/consul/state/state_store.go +++ b/consul/state/state_store.go @@ -34,8 +34,17 @@ var ( // pairs and more. The DB is entirely in-memory and is constructed // from the Raft log through the FSM. type StateStore struct { - logger *log.Logger // TODO(slackpad) - Delete if unused! - db *memdb.MemDB + logger *log.Logger // TODO(slackpad) - Delete if unused! + schema *memdb.DBSchema + db *memdb.MemDB + watches map[string]WatchManager +} + +// StateSnapshot is used to provide a point-in-time snapshot. It +// works by starting a read transaction against the whole state store. +type StateSnapshot struct { + tx *memdb.Txn + lastIndex uint64 } // IndexEntry keeps a record of the last index per-table. @@ -56,26 +65,69 @@ type sessionCheck struct { // NewStateStore creates a new in-memory state storage layer. func NewStateStore(logOutput io.Writer) (*StateStore, error) { - // Create the in-memory DB - db, err := memdb.NewMemDB(stateStoreSchema()) + // Create the in-memory DB. + schema := stateStoreSchema() + db, err := memdb.NewMemDB(schema) if err != nil { return nil, fmt.Errorf("Failed setting up state store: %s", err) } - // Create and return the state store + // Build up the watch managers. + watches, err := newWatchManagers(schema) + if err != nil { + return nil, fmt.Errorf("Failed to build watch managers: %s", err) + } + + // Create and return the state store. s := &StateStore{ - logger: log.New(logOutput, "", log.LstdFlags), - db: db, + logger: log.New(logOutput, "", log.LstdFlags), + schema: schema, + db: db, + watches: watches, } return s, nil } +// Snapshot is used to create a point-in-time snapshot of the entire db. +func (s *StateStore) Snapshot() *StateSnapshot { + tx := s.db.Txn(false) + + var tables []string + for table, _ := range s.schema.Tables { + tables = append(tables, table) + } + idx := maxIndexTxn(tx, tables...) + + return &StateSnapshot{tx, idx} +} + +// LastIndex returns that last index that affects the snapshotted data. +func (s *StateSnapshot) LastIndex() uint64 { + return s.lastIndex +} + +// Close performs cleanup of a state snapshot. +func (s *StateSnapshot) Close() { + s.tx.Abort() +} + +// ACLList is used to pull all the ACLs from the snapshot. +func (s *StateSnapshot) ACLList() ([]*structs.ACL, error) { + _, ret, err := aclListTxn(s.tx) + return ret, err +} + // maxIndex is a helper used to retrieve the highest known index // amongst a set of tables in the db. func (s *StateStore) maxIndex(tables ...string) uint64 { tx := s.db.Txn(false) defer tx.Abort() + return maxIndexTxn(tx, tables...) +} +// maxIndexTxn is a helper used to retrieve the highest known index +// amongst a set of tables in the db. +func maxIndexTxn(tx *memdb.Txn, tables ...string) uint64 { var lindex uint64 for _, table := range tables { ti, err := tx.First("index", "id", table) @@ -89,13 +141,51 @@ func (s *StateStore) maxIndex(tables ...string) uint64 { return lindex } +// indexUpdateMaxTxn is used when restoring entries and sets the table's index to +// the given idx only if it's greater than the current index. +func indexUpdateMaxTxn(tx *memdb.Txn, idx uint64, table string) error { + raw, err := tx.First("index", "id", table) + if err != nil { + return fmt.Errorf("failed to retrieve existing index: %s", err) + } + + if raw == nil { + return fmt.Errorf("missing index for table %s", table) + } + + entry, ok := raw.(*IndexEntry) + if !ok { + return fmt.Errorf("unexpected index type for table %s", table) + } + + if idx > entry.Value { + if err := tx.Insert("index", &IndexEntry{table, idx}); err != nil { + return fmt.Errorf("failed updating index %s", err) + } + } + + return nil +} + +// getWatchManager returns a watch manager for the given set of tables. The +// order of the tables is not important. +func (s *StateStore) GetWatchManager(tables ...string) WatchManager { + if len(tables) == 1 { + if manager, ok := s.watches[tables[0]]; ok { + return manager + } + } + + panic(fmt.Sprintf("Unknown watch manager(s): %v", tables)) +} + // EnsureNode is used to upsert node registration or modification. func (s *StateStore) EnsureNode(idx uint64, node *structs.Node) error { tx := s.db.Txn(true) defer tx.Abort() // Call the node upsert - if err := s.ensureNodeTxn(idx, node, tx); err != nil { + if err := ensureNodeTxn(tx, idx, node); err != nil { return err } @@ -106,7 +196,7 @@ func (s *StateStore) EnsureNode(idx uint64, node *structs.Node) error { // ensureNodeTxn is the inner function called to actually create a node // registration or modify an existing one in the state store. It allows // passing in a memdb transaction so it may be part of a larger txn. -func (s *StateStore) ensureNodeTxn(idx uint64, node *structs.Node, tx *memdb.Txn) error { +func ensureNodeTxn(tx *memdb.Txn, idx uint64, node *structs.Node) error { // Check for an existing node existing, err := tx.First("nodes", "id", node.Node) if err != nil { @@ -179,7 +269,7 @@ func (s *StateStore) DeleteNode(idx uint64, nodeID string) error { defer tx.Abort() // Call the node deletion. - if err := s.deleteNodeTxn(idx, nodeID, tx); err != nil { + if err := deleteNodeTxn(tx, idx, nodeID); err != nil { return err } @@ -189,7 +279,7 @@ func (s *StateStore) DeleteNode(idx uint64, nodeID string) error { // deleteNodeTxn is the inner method used for removing a node from // the store within a given transaction. -func (s *StateStore) deleteNodeTxn(idx uint64, nodeID string, tx *memdb.Txn) error { +func deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeID string) error { // Look up the node node, err := tx.First("nodes", "id", nodeID) if err != nil { @@ -206,7 +296,7 @@ func (s *StateStore) deleteNodeTxn(idx uint64, nodeID string, tx *memdb.Txn) err } for service := services.Next(); service != nil; service = services.Next() { svc := service.(*structs.ServiceNode) - if err := s.deleteServiceTxn(idx, nodeID, svc.ServiceID, tx); err != nil { + if err := deleteServiceTxn(tx, idx, nodeID, svc.ServiceID); err != nil { return err } } @@ -218,7 +308,7 @@ func (s *StateStore) deleteNodeTxn(idx uint64, nodeID string, tx *memdb.Txn) err } for check := checks.Next(); check != nil; check = checks.Next() { chk := check.(*structs.HealthCheck) - if err := s.deleteCheckTxn(idx, nodeID, chk.CheckID, tx); err != nil { + if err := deleteCheckTxn(tx, idx, nodeID, chk.CheckID); err != nil { return err } } @@ -242,7 +332,7 @@ func (s *StateStore) EnsureService(idx uint64, node string, svc *structs.NodeSer defer tx.Abort() // Call the service registration upsert - if err := s.ensureServiceTxn(idx, node, svc, tx); err != nil { + if err := ensureServiceTxn(tx, idx, node, svc); err != nil { return err } @@ -252,7 +342,7 @@ func (s *StateStore) EnsureService(idx uint64, node string, svc *structs.NodeSer // ensureServiceTxn is used to upsert a service registration within an // existing memdb transaction. -func (s *StateStore) ensureServiceTxn(idx uint64, node string, svc *structs.NodeService, tx *memdb.Txn) error { +func ensureServiceTxn(tx *memdb.Txn, idx uint64, node string, svc *structs.NodeService) error { // Check for existing service existing, err := tx.First("services", "id", node, svc.Service) if err != nil { @@ -358,7 +448,7 @@ func (s *StateStore) DeleteService(idx uint64, nodeID, serviceID string) error { defer tx.Abort() // Call the service deletion - if err := s.deleteServiceTxn(idx, nodeID, serviceID, tx); err != nil { + if err := deleteServiceTxn(tx, idx, nodeID, serviceID); err != nil { return err } @@ -368,7 +458,7 @@ func (s *StateStore) DeleteService(idx uint64, nodeID, serviceID string) error { // deleteServiceTxn is the inner method called to remove a service // registration within an existing transaction. -func (s *StateStore) deleteServiceTxn(idx uint64, nodeID, serviceID string, tx *memdb.Txn) error { +func deleteServiceTxn(tx *memdb.Txn, idx uint64, nodeID, serviceID string) error { // Look up the service service, err := tx.First("services", "id", nodeID, serviceID) if err != nil { @@ -411,7 +501,7 @@ func (s *StateStore) EnsureCheck(idx uint64, hc *structs.HealthCheck) error { defer tx.Abort() // Call the check registration - if err := s.ensureCheckTxn(idx, hc, tx); err != nil { + if err := ensureCheckTxn(tx, idx, hc); err != nil { return err } @@ -422,7 +512,7 @@ func (s *StateStore) EnsureCheck(idx uint64, hc *structs.HealthCheck) error { // ensureCheckTransaction is used as the inner method to handle inserting // a health check into the state store. It ensures safety against inserting // checks with no matching node or service. -func (s *StateStore) ensureCheckTxn(idx uint64, hc *structs.HealthCheck, tx *memdb.Txn) error { +func ensureCheckTxn(tx *memdb.Txn, idx uint64, hc *structs.HealthCheck) error { // Check if we have an existing health check existing, err := tx.First("checks", "id", hc.Node, hc.CheckID) if err != nil { @@ -541,7 +631,7 @@ func (s *StateStore) DeleteCheck(idx uint64, node, id string) error { defer tx.Abort() // Call the check deletion - if err := s.deleteCheckTxn(idx, node, id, tx); err != nil { + if err := deleteCheckTxn(tx, idx, node, id); err != nil { return err } @@ -551,7 +641,7 @@ func (s *StateStore) DeleteCheck(idx uint64, node, id string) error { // deleteCheckTxn is the inner method used to call a health // check deletion within an existing transaction. -func (s *StateStore) deleteCheckTxn(idx uint64, node, id string, tx *memdb.Txn) error { +func deleteCheckTxn(tx *memdb.Txn, idx uint64, node, id string) error { // Try to retrieve the existing health check check, err := tx.First("checks", "id", node, id) if err != nil { @@ -743,14 +833,12 @@ func (s *StateStore) parseNodes( func (s *StateStore) KVSSet(idx uint64, entry *structs.DirEntry) error { tx := s.db.Txn(true) defer tx.Abort() - return s.kvsSetTxn(idx, entry, tx) + return kvsSetTxn(tx, idx, entry) } // kvsSetTxn is used to insert or update a key/value pair in the state // store. It is the inner method used and handles only the actual storage. -func (s *StateStore) kvsSetTxn( - idx uint64, entry *structs.DirEntry, - tx *memdb.Txn) error { +func kvsSetTxn(tx *memdb.Txn, idx uint64, entry *structs.DirEntry) error { // Retrieve an existing KV pair existing, err := tx.First("kvs", "id", entry.Key) @@ -878,7 +966,7 @@ func (s *StateStore) KVSDelete(idx uint64, key string) error { defer tx.Abort() // Perform the actual delete - if err := s.kvsDeleteTxn(idx, key, tx); err != nil { + if err := kvsDeleteTxn(tx, idx, key); err != nil { return err } @@ -888,7 +976,7 @@ func (s *StateStore) KVSDelete(idx uint64, key string) error { // kvsDeleteTxn is the inner method used to perform the actual deletion // of a key/value pair within an existing transaction. -func (s *StateStore) kvsDeleteTxn(idx uint64, key string, tx *memdb.Txn) error { +func kvsDeleteTxn(tx *memdb.Txn, idx uint64, key string) error { // Look up the entry in the state store entry, err := tx.First("kvs", "id", key) if err != nil { @@ -931,7 +1019,7 @@ func (s *StateStore) KVSDeleteCAS(idx, cidx uint64, key string) (bool, error) { } // Call the actual deletion if the above passed - if err := s.kvsDeleteTxn(idx, key, tx); err != nil { + if err := kvsDeleteTxn(tx, idx, key); err != nil { return false, err } @@ -967,7 +1055,7 @@ func (s *StateStore) KVSSetCAS(idx uint64, entry *structs.DirEntry) (bool, error } // If we made it this far, we should perform the set. - return true, s.kvsSetTxn(idx, entry, tx) + return true, kvsSetTxn(tx, idx, entry) } // KVSDeleteTree is used to do a recursive delete on a key prefix @@ -1011,7 +1099,7 @@ func (s *StateStore) SessionCreate(idx uint64, sess *structs.Session) error { defer tx.Abort() // Call the session creation - if err := s.sessionCreateTxn(idx, sess, tx); err != nil { + if err := sessionCreateTxn(tx, idx, sess); err != nil { return err } @@ -1022,7 +1110,7 @@ func (s *StateStore) SessionCreate(idx uint64, sess *structs.Session) error { // sessionCreateTxn is the inner method used for creating session entries in // an open transaction. Any health checks registered with the session will be // checked for failing status. Returns any error encountered. -func (s *StateStore) sessionCreateTxn(idx uint64, sess *structs.Session, tx *memdb.Txn) error { +func sessionCreateTxn(tx *memdb.Txn, idx uint64, sess *structs.Session) error { // Check that we have a session ID if sess.ID == "" { return ErrMissingSessionID @@ -1172,7 +1260,7 @@ func (s *StateStore) SessionDestroy(idx uint64, sessionID string) error { defer tx.Abort() // Call the session deletion - if err := s.sessionDestroyTxn(idx, sessionID, tx); err != nil { + if err := sessionDestroyTxn(tx, idx, sessionID); err != nil { return err } @@ -1182,7 +1270,7 @@ func (s *StateStore) SessionDestroy(idx uint64, sessionID string) error { // sessionDestroyTxn is the inner method, which is used to do the actual // session deletion and handle session invalidation, watch triggers, etc. -func (s *StateStore) sessionDestroyTxn(idx uint64, sessionID string, tx *memdb.Txn) error { +func sessionDestroyTxn(tx *memdb.Txn, idx uint64, sessionID string) error { // Look up the session sess, err := tx.First("sessions", "id", sessionID) if err != nil { @@ -1211,17 +1299,18 @@ func (s *StateStore) ACLSet(idx uint64, acl *structs.ACL) error { defer tx.Abort() // Call set on the ACL - if err := s.aclSetTxn(idx, acl, tx); err != nil { + if err := aclSetTxn(tx, idx, acl); err != nil { return err } + tx.Defer(func() { s.GetWatchManager("acls").Notify() }) tx.Commit() return nil } // aclSetTxn is the inner method used to insert an ACL rule with the // proper indexes into the state store. -func (s *StateStore) aclSetTxn(idx uint64, acl *structs.ACL, tx *memdb.Txn) error { +func aclSetTxn(tx *memdb.Txn, idx uint64, acl *structs.ACL) error { // Check that the ID is set if acl.ID == "" { return ErrMissingACLID @@ -1272,7 +1361,11 @@ func (s *StateStore) ACLGet(aclID string) (*structs.ACL, error) { func (s *StateStore) ACLList() (uint64, []*structs.ACL, error) { tx := s.db.Txn(false) defer tx.Abort() + return aclListTxn(tx) +} +// aclListTxn is used to list out all of the ACLs in the state store. +func aclListTxn(tx *memdb.Txn) (uint64, []*structs.ACL, error) { // Query all of the ACLs in the state store acls, err := tx.Get("acls", "id") if err != nil { @@ -1301,17 +1394,18 @@ func (s *StateStore) ACLDelete(idx uint64, aclID string) error { defer tx.Abort() // Call the ACL delete - if err := s.aclDeleteTxn(idx, aclID, tx); err != nil { + if err := aclDeleteTxn(tx, idx, aclID); err != nil { return err } + tx.Defer(func() { s.GetWatchManager("acls").Notify() }) tx.Commit() return nil } // aclDeleteTxn is used to delete an ACL from the state store within // an existing transaction. -func (s *StateStore) aclDeleteTxn(idx uint64, aclID string, tx *memdb.Txn) error { +func aclDeleteTxn(tx *memdb.Txn, idx uint64, aclID string) error { // Look up the existing ACL acl, err := tx.First("acls", "id", aclID) if err != nil { @@ -1330,3 +1424,22 @@ func (s *StateStore) aclDeleteTxn(idx uint64, aclID string, tx *memdb.Txn) error } return nil } + +// ACLRestore is used when restoring from a snapshot. For general inserts, use +// ACLSet. +func (s *StateStore) ACLRestore(acl *structs.ACL) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := tx.Insert("acls", acl); err != nil { + return fmt.Errorf("failed restoring acl: %s", err) + } + + if err := indexUpdateMaxTxn(tx, acl.ModifyIndex, "acls"); err != nil { + return err + } + + tx.Defer(func() { s.GetWatchManager("acls").Notify() }) + tx.Commit() + return nil +} diff --git a/consul/state/state_store_test.go b/consul/state/state_store_test.go index 64d611dabe..2eea4fb7a7 100644 --- a/consul/state/state_store_test.go +++ b/consul/state/state_store_test.go @@ -6,6 +6,7 @@ import ( "reflect" "strings" "testing" + "time" "github.com/hashicorp/consul/consul/structs" ) @@ -31,7 +32,7 @@ func testRegisterNode(t *testing.T, s *StateStore, idx uint64, nodeID string) { defer tx.Abort() n, err := tx.First("nodes", "id", nodeID) if err != nil { - t.Fatalf("err: %s", err, n) + t.Fatalf("err: %s", err) } if result, ok := n.(*structs.Node); !ok || result.Node != nodeID { t.Fatalf("bad node: %#v", result) @@ -107,14 +108,33 @@ func testSetKey(t *testing.T, s *StateStore, idx uint64, key, value string) { func TestStateStore_maxIndex(t *testing.T) { s := testStateStore(t) + testRegisterNode(t, s, 0, "foo") testRegisterNode(t, s, 1, "bar") testRegisterService(t, s, 2, "foo", "consul") + if max := s.maxIndex("nodes", "services"); max != 2 { t.Fatalf("bad max: %d", max) } } +func TestStateStore_indexUpdateMaxTxn(t *testing.T) { + s := testStateStore(t) + + testRegisterNode(t, s, 0, "foo") + testRegisterNode(t, s, 1, "bar") + + tx := s.db.Txn(true) + if err := indexUpdateMaxTxn(tx, 3, "nodes"); err != nil { + t.Fatalf("err: %s", err) + } + tx.Commit() + + if max := s.maxIndex("nodes"); max != 3 { + t.Fatalf("bad max: %d", max) + } +} + func TestStateStore_EnsureNode(t *testing.T) { s := testStateStore(t) @@ -1415,7 +1435,7 @@ func TestStateStore_SessionCreate_GetSession(t *testing.T) { t.Fatalf("err: %s", err) } if idx := s.maxIndex("sessions"); idx != 2 { - t.Fatalf("bad index: %d", err) + t.Fatalf("bad index: %s", err) } // Retrieve the session again @@ -1814,3 +1834,44 @@ func TestStateStore_ACLDelete(t *testing.T) { t.Fatalf("expected nil, got: %#v", result) } } + +func TestStateStore_ACL_Watches(t *testing.T) { + s := testStateStore(t) + ch := make(chan struct{}) + + s.GetWatchManager("acls").Start(ch) + go func() { + if err := s.ACLSet(1, &structs.ACL{ID: "acl1"}); err != nil { + t.Fatalf("err: %s", err) + } + }() + select { + case <-ch: + case <-time.After(1 * time.Second): + t.Fatalf("watch was not notified") + } + + s.GetWatchManager("acls").Start(ch) + go func() { + if err := s.ACLDelete(2, "acl1"); err != nil { + t.Fatalf("err: %s", err) + } + }() + select { + case <-ch: + case <-time.After(1 * time.Second): + t.Fatalf("watch was not notified") + } + + s.GetWatchManager("acls").Start(ch) + go func() { + if err := s.ACLRestore(&structs.ACL{ID: "acl1"}); err != nil { + t.Fatalf("err: %s", err) + } + }() + select { + case <-ch: + case <-time.After(1 * time.Second): + t.Fatalf("watch was not notified") + } +} diff --git a/consul/state/watch.go b/consul/state/watch.go new file mode 100644 index 0000000000..304c42ea28 --- /dev/null +++ b/consul/state/watch.go @@ -0,0 +1,35 @@ +package state + +import ( + "github.com/hashicorp/go-memdb" +) + +type WatchManager interface { + Start(notifyCh chan struct{}) + Stop(notifyCh chan struct{}) + Notify() +} + +type FullTableWatch struct { + notify NotifyGroup +} + +func (w *FullTableWatch) Start(notifyCh chan struct{}) { + w.notify.Wait(notifyCh) +} + +func (w *FullTableWatch) Stop(notifyCh chan struct{}) { + w.notify.Clear(notifyCh) +} + +func (w *FullTableWatch) Notify() { + w.notify.Notify() +} + +func newWatchManagers(schema *memdb.DBSchema) (map[string]WatchManager, error) { + watches := make(map[string]WatchManager) + for table, _ := range schema.Tables { + watches[table] = &FullTableWatch{} + } + return watches, nil +} diff --git a/consul/state_store.go b/consul/state_store.go index 5a0855aa80..8284885c5e 100644 --- a/consul/state_store.go +++ b/consul/state_store.go @@ -24,7 +24,6 @@ const ( dbTombstone = "tombstones" dbSessions = "sessions" dbSessionChecks = "sessionChecks" - dbACLs = "acls" dbMaxMapSize32bit uint64 = 128 * 1024 * 1024 // 128MB maximum size dbMaxMapSize64bit uint64 = 32 * 1024 * 1024 * 1024 // 32GB maximum size dbMaxReaders uint = 4096 // 4K, default is 126 @@ -59,7 +58,6 @@ type StateStore struct { tombstoneTable *MDBTable sessionTable *MDBTable sessionCheckTable *MDBTable - aclTable *MDBTable tables MDBTables watch map[*MDBTable]*NotifyGroup queryTables map[string]MDBTables @@ -361,27 +359,9 @@ func (s *StateStore) initialize() error { }, } - s.aclTable = &MDBTable{ - Name: dbACLs, - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"ID"}, - }, - }, - Decoder: func(buf []byte) interface{} { - out := new(structs.ACL) - if err := structs.Decode(buf, out); err != nil { - panic(err) - } - return out - }, - } - // Store the set of tables s.tables = []*MDBTable{s.nodeTable, s.serviceTable, s.checkTable, - s.kvsTable, s.tombstoneTable, s.sessionTable, s.sessionCheckTable, - s.aclTable} + s.kvsTable, s.tombstoneTable, s.sessionTable, s.sessionCheckTable} for _, table := range s.tables { table.Env = s.env table.Encoder = encoder @@ -408,8 +388,6 @@ func (s *StateStore) initialize() error { "SessionGet": MDBTables{s.sessionTable}, "SessionList": MDBTables{s.sessionTable}, "NodeSessions": MDBTables{s.sessionTable}, - "ACLGet": MDBTables{s.aclTable}, - "ACLList": MDBTables{s.aclTable}, } return nil } @@ -1945,109 +1923,6 @@ func (s *StateStore) deleteLocks(index uint64, tx *MDBTxn, return nil } -// ACLSet is used to create or update an ACL entry -func (s *StateStore) ACLSet(index uint64, acl *structs.ACL) error { - // Check for an ID - if acl.ID == "" { - return fmt.Errorf("Missing ACL ID") - } - - // Start a new txn - tx, err := s.tables.StartTxn(false) - if err != nil { - return err - } - defer tx.Abort() - - // Look for the existing node - res, err := s.aclTable.GetTxn(tx, "id", acl.ID) - if err != nil { - return err - } - - switch len(res) { - case 0: - acl.CreateIndex = index - acl.ModifyIndex = index - case 1: - exist := res[0].(*structs.ACL) - acl.CreateIndex = exist.CreateIndex - acl.ModifyIndex = index - default: - panic(fmt.Errorf("Duplicate ACL definition. Internal error")) - } - - // Insert the ACL - if err := s.aclTable.InsertTxn(tx, acl); err != nil { - return err - } - - // Trigger the update notifications - if err := s.aclTable.SetLastIndexTxn(tx, index); err != nil { - return err - } - tx.Defer(func() { s.watch[s.aclTable].Notify() }) - return tx.Commit() -} - -// ACLRestore is used to restore an ACL. It should only be used when -// doing a restore, otherwise ACLSet should be used. -func (s *StateStore) ACLRestore(acl *structs.ACL) error { - // Start a new txn - tx, err := s.aclTable.StartTxn(false, nil) - if err != nil { - return err - } - defer tx.Abort() - - if err := s.aclTable.InsertTxn(tx, acl); err != nil { - return err - } - if err := s.aclTable.SetMaxLastIndexTxn(tx, acl.ModifyIndex); err != nil { - return err - } - return tx.Commit() -} - -// ACLGet is used to get an ACL by ID -func (s *StateStore) ACLGet(id string) (uint64, *structs.ACL, error) { - idx, res, err := s.aclTable.Get("id", id) - var d *structs.ACL - if len(res) > 0 { - d = res[0].(*structs.ACL) - } - return idx, d, err -} - -// ACLList is used to list all the acls -func (s *StateStore) ACLList() (uint64, []*structs.ACL, error) { - idx, res, err := s.aclTable.Get("id") - out := make([]*structs.ACL, len(res)) - for i, raw := range res { - out[i] = raw.(*structs.ACL) - } - return idx, out, err -} - -// ACLDelete is used to remove an ACL -func (s *StateStore) ACLDelete(index uint64, id string) error { - tx, err := s.tables.StartTxn(false) - if err != nil { - panic(fmt.Errorf("Failed to start txn: %v", err)) - } - defer tx.Abort() - - if n, err := s.aclTable.DeleteTxn(tx, "id", id); err != nil { - return err - } else if n > 0 { - if err := s.aclTable.SetLastIndexTxn(tx, index); err != nil { - return err - } - tx.Defer(func() { s.watch[s.aclTable].Notify() }) - } - return tx.Commit() -} - // Snapshot is used to create a point in time snapshot func (s *StateStore) Snapshot() (*StateSnapshot, error) { // Begin a new txn on all tables @@ -2128,13 +2003,3 @@ func (s *StateSnapshot) SessionList() ([]*structs.Session, error) { } return out, err } - -// ACLList is used to list all of the ACLs -func (s *StateSnapshot) ACLList() ([]*structs.ACL, error) { - res, err := s.store.aclTable.GetTxn(s.tx, "id") - out := make([]*structs.ACL, len(res)) - for i, raw := range res { - out[i] = raw.(*structs.ACL) - } - return out, err -} diff --git a/consul/state_store_test.go b/consul/state_store_test.go index c1501887e7..dcd2f3d19c 100644 --- a/consul/state_store_test.go +++ b/consul/state_store_test.go @@ -762,24 +762,6 @@ func TestStoreSnapshot(t *testing.T) { t.Fatalf("err: %v", err) } - a1 := &structs.ACL{ - ID: generateUUID(), - Name: "User token", - Type: structs.ACLTypeClient, - } - if err := store.ACLSet(21, a1); err != nil { - t.Fatalf("err: %v", err) - } - - a2 := &structs.ACL{ - ID: generateUUID(), - Name: "User token", - Type: structs.ACLTypeClient, - } - if err := store.ACLSet(22, a2); err != nil { - t.Fatalf("err: %v", err) - } - // Take a snapshot snap, err := store.Snapshot() if err != nil { @@ -884,15 +866,6 @@ func TestStoreSnapshot(t *testing.T) { t.Fatalf("Wrong number of sessions with TTL") } - // Check for an acl - acls, err := snap.ACLList() - if err != nil { - t.Fatalf("err: %v", err) - } - if len(acls) != 2 { - t.Fatalf("missing acls") - } - // Make some changes! if err := store.EnsureService(23, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8000}); err != nil { t.Fatalf("err: %v", err) @@ -918,11 +891,6 @@ func TestStoreSnapshot(t *testing.T) { t.Fatalf("err: %v", err) } - // Nuke an ACL - if err := store.ACLDelete(29, a1.ID); err != nil { - t.Fatalf("err: %v", err) - } - // Check snapshot has old values nodes = snap.Nodes() if len(nodes) != 2 { @@ -1003,15 +971,6 @@ func TestStoreSnapshot(t *testing.T) { if len(sessions) != 3 { t.Fatalf("missing sessions") } - - // Check for an acl - acls, err = snap.ACLList() - if err != nil { - t.Fatalf("err: %v", err) - } - if len(acls) != 2 { - t.Fatalf("missing acls") - } } func TestEnsureCheck(t *testing.T) { @@ -2880,148 +2839,3 @@ func TestSessionInvalidate_KeyDelete(t *testing.T) { t.Fatalf("Bad: %v", expires) } } - -func TestACLSet_Get(t *testing.T) { - store, err := testStateStore() - if err != nil { - t.Fatalf("err: %v", err) - } - defer store.Close() - - idx, out, err := store.ACLGet("1234") - if err != nil { - t.Fatalf("err: %v", err) - } - if idx != 0 { - t.Fatalf("bad: %v", idx) - } - if out != nil { - t.Fatalf("bad: %v", out) - } - - a := &structs.ACL{ - ID: generateUUID(), - Name: "User token", - Type: structs.ACLTypeClient, - Rules: "", - } - if err := store.ACLSet(50, a); err != nil { - t.Fatalf("err: %v", err) - } - if a.CreateIndex != 50 { - t.Fatalf("Bad: %v", a) - } - if a.ModifyIndex != 50 { - t.Fatalf("Bad: %v", a) - } - if a.ID == "" { - t.Fatalf("Bad: %v", a) - } - - idx, out, err = store.ACLGet(a.ID) - if err != nil { - t.Fatalf("err: %v", err) - } - if idx != 50 { - t.Fatalf("bad: %v", idx) - } - if !reflect.DeepEqual(out, a) { - t.Fatalf("bad: %v", out) - } - - // Update - a.Rules = "foo bar baz" - if err := store.ACLSet(52, a); err != nil { - t.Fatalf("err: %v", err) - } - if a.CreateIndex != 50 { - t.Fatalf("Bad: %v", a) - } - if a.ModifyIndex != 52 { - t.Fatalf("Bad: %v", a) - } - - idx, out, err = store.ACLGet(a.ID) - if err != nil { - t.Fatalf("err: %v", err) - } - if idx != 52 { - t.Fatalf("bad: %v", idx) - } - if !reflect.DeepEqual(out, a) { - t.Fatalf("bad: %v", out) - } -} - -func TestACLDelete(t *testing.T) { - store, err := testStateStore() - if err != nil { - t.Fatalf("err: %v", err) - } - defer store.Close() - - a := &structs.ACL{ - ID: generateUUID(), - Name: "User token", - Type: structs.ACLTypeClient, - Rules: "", - } - if err := store.ACLSet(50, a); err != nil { - t.Fatalf("err: %v", err) - } - - if err := store.ACLDelete(52, a.ID); err != nil { - t.Fatalf("err: %v", err) - } - if err := store.ACLDelete(53, a.ID); err != nil { - t.Fatalf("err: %v", err) - } - - idx, out, err := store.ACLGet(a.ID) - if err != nil { - t.Fatalf("err: %v", err) - } - if idx != 52 { - t.Fatalf("bad: %v", idx) - } - if out != nil { - t.Fatalf("bad: %v", out) - } -} - -func TestACLList(t *testing.T) { - store, err := testStateStore() - if err != nil { - t.Fatalf("err: %v", err) - } - defer store.Close() - - a1 := &structs.ACL{ - ID: generateUUID(), - Name: "User token", - Type: structs.ACLTypeClient, - } - if err := store.ACLSet(50, a1); err != nil { - t.Fatalf("err: %v", err) - } - - a2 := &structs.ACL{ - ID: generateUUID(), - Name: "User token", - Type: structs.ACLTypeClient, - } - if err := store.ACLSet(51, a2); err != nil { - t.Fatalf("err: %v", err) - } - - idx, out, err := store.ACLList() - if err != nil { - t.Fatalf("err: %v", err) - } - if idx != 51 { - t.Fatalf("bad: %v", idx) - } - if len(out) != 2 { - t.Fatalf("bad: %v", out) - } -}