Integrates new state store for ACLs.

This commit is contained in:
James Phillips 2015-09-20 01:36:39 -07:00
parent d57431e300
commit 009fd7d9f5
14 changed files with 492 additions and 392 deletions

View File

@ -51,8 +51,8 @@ type aclCacheEntry struct {
// aclFault is used to fault in the rules for an ACL if we take a miss
func (s *Server) aclFault(id string) (string, string, error) {
defer metrics.MeasureSince([]string{"consul", "acl", "fault"}, time.Now())
state := s.fsm.State()
_, acl, err := state.ACLGet(id)
state := s.fsm.StateNew()
acl, err := state.ACLGet(id)
if err != nil {
return "", "", err
}

View File

@ -60,10 +60,10 @@ func (a *ACL) Apply(args *structs.ACLRequest, reply *string) error {
// deterministic. Once the entry is in the log, the state update MUST
// be deterministic or the followers will not converge.
if args.ACL.ID == "" {
state := a.srv.fsm.State()
state := a.srv.fsm.StateNew()
for {
args.ACL.ID = generateUUID()
_, acl, err := state.ACLGet(args.ACL.ID)
acl, err := state.ACLGet(args.ACL.ID)
if err != nil {
a.srv.logger.Printf("[ERR] consul.acl: ACL lookup failed: %v", err)
return err
@ -120,14 +120,14 @@ func (a *ACL) Get(args *structs.ACLSpecificRequest,
}
// Get the local state
state := a.srv.fsm.State()
return a.srv.blockingRPC(&args.QueryOptions,
state := a.srv.fsm.StateNew()
return a.srv.blockingRPCNew(&args.QueryOptions,
&reply.QueryMeta,
state.QueryTables("ACLGet"),
state.GetWatchManager("acls"),
func() error {
index, acl, err := state.ACLGet(args.ACL)
reply.Index = index
acl, err := state.ACLGet(args.ACL)
if acl != nil {
reply.Index = acl.ModifyIndex
reply.ACLs = structs.ACLs{acl}
} else {
reply.ACLs = nil
@ -191,10 +191,10 @@ func (a *ACL) List(args *structs.DCSpecificRequest,
}
// Get the local state
state := a.srv.fsm.State()
return a.srv.blockingRPC(&args.QueryOptions,
state := a.srv.fsm.StateNew()
return a.srv.blockingRPCNew(&args.QueryOptions,
&reply.QueryMeta,
state.QueryTables("ACLList"),
state.GetWatchManager("acls"),
func() error {
var err error
reply.Index, reply.ACLs, err = state.ACLList()

View File

@ -39,8 +39,8 @@ func TestACLEndpoint_Apply(t *testing.T) {
id := out
// Verify
state := s1.fsm.State()
_, s, err := state.ACLGet(out)
state := s1.fsm.StateNew()
s, err := state.ACLGet(out)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -62,7 +62,7 @@ func TestACLEndpoint_Apply(t *testing.T) {
}
// Verify
_, s, err = state.ACLGet(id)
s, err = state.ACLGet(id)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -180,8 +180,8 @@ func TestACLEndpoint_Apply_CustomID(t *testing.T) {
}
// Verify
state := s1.fsm.State()
_, s, err := state.ACLGet(out)
state := s1.fsm.StateNew()
s, err := state.ACLGet(out)
if err != nil {
t.Fatalf("err: %v", err)
}

View File

@ -9,6 +9,7 @@ import (
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/consul/state"
"github.com/hashicorp/consul/consul/structs"
"github.com/hashicorp/go-msgpack/codec"
"github.com/hashicorp/raft"
@ -24,6 +25,7 @@ type consulFSM struct {
logOutput io.Writer
logger *log.Logger
path string
stateNew *state.StateStore
state *StateStore
gc *TombstoneGC
}
@ -33,6 +35,7 @@ type consulFSM struct {
// that may modify the live state.
type consulSnapshot struct {
state *StateSnapshot
stateNew *state.StateSnapshot
}
// snapshotHeader is the first entry in our snapshot
@ -44,6 +47,12 @@ type snapshotHeader struct {
// NewFSMPath is used to construct a new FSM with a blank state
func NewFSM(gc *TombstoneGC, path string, logOutput io.Writer) (*consulFSM, error) {
// Create the state store.
stateNew, err := state.NewStateStore(logOutput)
if err != nil {
return nil, err
}
// Create a temporary path for the state store
tmpPath, err := ioutil.TempDir(path, "state")
if err != nil {
@ -60,6 +69,7 @@ func NewFSM(gc *TombstoneGC, path string, logOutput io.Writer) (*consulFSM, erro
logOutput: logOutput,
logger: log.New(logOutput, "", log.LstdFlags),
path: path,
stateNew: stateNew,
state: state,
gc: gc,
}
@ -71,6 +81,11 @@ func (c *consulFSM) Close() error {
return c.state.Close()
}
// TODO(slackpad)
func (c *consulFSM) StateNew() *state.StateStore {
return c.stateNew
}
// State is used to return a handle to the current state
func (c *consulFSM) State() *StateStore {
return c.state
@ -234,13 +249,13 @@ func (c *consulFSM) applyACLOperation(buf []byte, index uint64) interface{} {
defer metrics.MeasureSince([]string{"consul", "fsm", "acl", string(req.Op)}, time.Now())
switch req.Op {
case structs.ACLForceSet, structs.ACLSet:
if err := c.state.ACLSet(index, &req.ACL); err != nil {
if err := c.stateNew.ACLSet(index, &req.ACL); err != nil {
return err
} else {
return req.ACL.ID
}
case structs.ACLDelete:
return c.state.ACLDelete(index, req.ACL.ID)
return c.stateNew.ACLDelete(index, req.ACL.ID)
default:
c.logger.Printf("[WARN] consul.fsm: Invalid ACL operation '%s'", req.Op)
return fmt.Errorf("Invalid ACL operation '%s'", req.Op)
@ -272,7 +287,7 @@ func (c *consulFSM) Snapshot() (raft.FSMSnapshot, error) {
if err != nil {
return nil, err
}
return &consulSnapshot{snap}, nil
return &consulSnapshot{snap, c.stateNew.Snapshot()}, nil
}
func (c *consulFSM) Restore(old io.ReadCloser) error {
@ -344,7 +359,7 @@ func (c *consulFSM) Restore(old io.ReadCloser) error {
if err := dec.Decode(&req); err != nil {
return err
}
if err := c.state.ACLRestore(&req); err != nil {
if err := c.stateNew.ACLRestore(&req); err != nil {
return err
}
@ -467,7 +482,7 @@ func (s *consulSnapshot) persistSessions(sink raft.SnapshotSink,
func (s *consulSnapshot) persistACLs(sink raft.SnapshotSink,
encoder *codec.Encoder) error {
acls, err := s.state.ACLList()
acls, err := s.stateNew.ACLList()
if err != nil {
return err
}

View File

@ -361,7 +361,7 @@ func TestFSM_SnapshotRestore(t *testing.T) {
session := &structs.Session{ID: generateUUID(), Node: "foo"}
fsm.state.SessionCreate(9, session)
acl := &structs.ACL{ID: generateUUID(), Name: "User Token"}
fsm.state.ACLSet(10, acl)
fsm.stateNew.ACLSet(10, acl)
fsm.state.KVSSet(11, &structs.DirEntry{
Key: "/remove",
@ -448,14 +448,14 @@ func TestFSM_SnapshotRestore(t *testing.T) {
}
// Verify ACL is restored
idx, a, err := fsm2.state.ACLGet(acl.ID)
a, err := fsm2.stateNew.ACLGet(acl.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
if a.Name != "User Token" {
t.Fatalf("bad: %v", a)
}
if idx <= 1 {
if a.ModifyIndex <= 1 {
t.Fatalf("bad index: %d", idx)
}
@ -971,7 +971,7 @@ func TestFSM_ACL_Set_Delete(t *testing.T) {
// Get the ACL
id := resp.(string)
_, acl, err := fsm.state.ACLGet(id)
acl, err := fsm.stateNew.ACLGet(id)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -1007,7 +1007,7 @@ func TestFSM_ACL_Set_Delete(t *testing.T) {
t.Fatalf("resp: %v", resp)
}
_, acl, err = fsm.state.ACLGet(id)
acl, err = fsm.stateNew.ACLGet(id)
if err != nil {
t.Fatalf("err: %v", err)
}

View File

@ -182,8 +182,8 @@ func (s *Server) initializeACL() error {
s.aclAuthCache.Purge()
// Look for the anonymous token
state := s.fsm.State()
_, acl, err := state.ACLGet(anonymousToken)
state := s.fsm.StateNew()
acl, err := state.ACLGet(anonymousToken)
if err != nil {
return fmt.Errorf("failed to get anonymous token: %v", err)
}
@ -212,7 +212,7 @@ func (s *Server) initializeACL() error {
}
// Look for the master token
_, acl, err = state.ACLGet(master)
acl, err = state.ACLGet(master)
if err != nil {
return fmt.Errorf("failed to get master token: %v", err)
}

View File

@ -10,6 +10,7 @@ import (
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/consul/state"
"github.com/hashicorp/consul/consul/structs"
"github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/yamux"
@ -397,6 +398,75 @@ RUN_QUERY:
return err
}
// TODO(slackpad)
func (s *Server) blockingRPCNew(queryOpts *structs.QueryOptions, queryMeta *structs.QueryMeta,
watch state.WatchManager, run func() error) error {
var timeout *time.Timer
var notifyCh chan struct{}
// Fast path right to the non-blocking query.
if queryOpts.MinQueryIndex == 0 {
goto RUN_QUERY
}
// Make sure a watch manager was given if we were asked to block.
if watch == nil {
panic("no watch manager given for blocking query")
}
// Restrict the max query time, and ensure there is always one.
if queryOpts.MaxQueryTime > maxQueryTime {
queryOpts.MaxQueryTime = maxQueryTime
} else if queryOpts.MaxQueryTime <= 0 {
queryOpts.MaxQueryTime = defaultQueryTime
}
// Apply a small amount of jitter to the request.
queryOpts.MaxQueryTime += randomStagger(queryOpts.MaxQueryTime / jitterFraction)
// Setup a query timeout.
timeout = time.NewTimer(queryOpts.MaxQueryTime)
// Setup the notify channel.
notifyCh = make(chan struct{}, 1)
// Ensure we tear down any watches on return.
defer func() {
timeout.Stop()
watch.Stop(notifyCh)
}()
REGISTER_NOTIFY:
// Register the notification channel. This may be done multiple times if
// we haven't reached the target wait index.
watch.Start(notifyCh)
RUN_QUERY:
// Update the query metadata.
s.setQueryMeta(queryMeta)
// If the read must be consistent we verify that we are still the leader.
if queryOpts.RequireConsistent {
if err := s.consistentRead(); err != nil {
return err
}
}
// Run the query.
metrics.IncrCounter([]string{"consul", "rpc", "query"}, 1)
err := run()
// Check for minimum query time.
if err == nil && queryMeta.Index > 0 && queryMeta.Index <= queryOpts.MinQueryIndex {
select {
case <-notifyCh:
goto REGISTER_NOTIFY
case <-timeout.C:
}
}
return err
}
// setQueryMeta is used to populate the QueryMeta data for an RPC call
func (s *Server) setQueryMeta(m *structs.QueryMeta) {
if s.IsLeader() {

55
consul/state/notify.go Normal file
View File

@ -0,0 +1,55 @@
package state
import (
"sync"
)
// NotifyGroup is used to allow a simple notification mechanism.
// Channels can be marked as waiting, and when notify is invoked,
// all the waiting channels get a message and are cleared from the
// notify list.
type NotifyGroup struct {
l sync.Mutex
notify map[chan struct{}]struct{}
}
// Notify will do a non-blocking send to all waiting channels, and
// clear the notify list
func (n *NotifyGroup) Notify() {
n.l.Lock()
defer n.l.Unlock()
for ch, _ := range n.notify {
select {
case ch <- struct{}{}:
default:
}
}
n.notify = nil
}
// Wait adds a channel to the notify group
func (n *NotifyGroup) Wait(ch chan struct{}) {
n.l.Lock()
defer n.l.Unlock()
if n.notify == nil {
n.notify = make(map[chan struct{}]struct{})
}
n.notify[ch] = struct{}{}
}
// Clear removes a channel from the notify group
func (n *NotifyGroup) Clear(ch chan struct{}) {
n.l.Lock()
defer n.l.Unlock()
if n.notify == nil {
return
}
delete(n.notify, ch)
}
// WaitCh allocates a channel that is subscribed to notifications
func (n *NotifyGroup) WaitCh() chan struct{} {
ch := make(chan struct{}, 1)
n.Wait(ch)
return ch
}

View File

@ -0,0 +1,72 @@
package state
import (
"testing"
)
func TestNotifyGroup(t *testing.T) {
grp := &NotifyGroup{}
ch1 := grp.WaitCh()
ch2 := grp.WaitCh()
select {
case <-ch1:
t.Fatalf("should block")
default:
}
select {
case <-ch2:
t.Fatalf("should block")
default:
}
grp.Notify()
select {
case <-ch1:
default:
t.Fatalf("should not block")
}
select {
case <-ch2:
default:
t.Fatalf("should not block")
}
// Should be unregistered
ch3 := grp.WaitCh()
grp.Notify()
select {
case <-ch1:
t.Fatalf("should block")
default:
}
select {
case <-ch2:
t.Fatalf("should block")
default:
}
select {
case <-ch3:
default:
t.Fatalf("should not block")
}
}
func TestNotifyGroup_Clear(t *testing.T) {
grp := &NotifyGroup{}
ch1 := grp.WaitCh()
grp.Clear(ch1)
grp.Notify()
// Should not get anything
select {
case <-ch1:
t.Fatalf("should not get message")
default:
}
}

View File

@ -35,7 +35,16 @@ var (
// from the Raft log through the FSM.
type StateStore struct {
logger *log.Logger // TODO(slackpad) - Delete if unused!
schema *memdb.DBSchema
db *memdb.MemDB
watches map[string]WatchManager
}
// StateSnapshot is used to provide a point-in-time snapshot. It
// works by starting a read transaction against the whole state store.
type StateSnapshot struct {
tx *memdb.Txn
lastIndex uint64
}
// IndexEntry keeps a record of the last index per-table.
@ -56,26 +65,69 @@ type sessionCheck struct {
// NewStateStore creates a new in-memory state storage layer.
func NewStateStore(logOutput io.Writer) (*StateStore, error) {
// Create the in-memory DB
db, err := memdb.NewMemDB(stateStoreSchema())
// Create the in-memory DB.
schema := stateStoreSchema()
db, err := memdb.NewMemDB(schema)
if err != nil {
return nil, fmt.Errorf("Failed setting up state store: %s", err)
}
// Create and return the state store
// Build up the watch managers.
watches, err := newWatchManagers(schema)
if err != nil {
return nil, fmt.Errorf("Failed to build watch managers: %s", err)
}
// Create and return the state store.
s := &StateStore{
logger: log.New(logOutput, "", log.LstdFlags),
schema: schema,
db: db,
watches: watches,
}
return s, nil
}
// Snapshot is used to create a point-in-time snapshot of the entire db.
func (s *StateStore) Snapshot() *StateSnapshot {
tx := s.db.Txn(false)
var tables []string
for table, _ := range s.schema.Tables {
tables = append(tables, table)
}
idx := maxIndexTxn(tx, tables...)
return &StateSnapshot{tx, idx}
}
// LastIndex returns that last index that affects the snapshotted data.
func (s *StateSnapshot) LastIndex() uint64 {
return s.lastIndex
}
// Close performs cleanup of a state snapshot.
func (s *StateSnapshot) Close() {
s.tx.Abort()
}
// ACLList is used to pull all the ACLs from the snapshot.
func (s *StateSnapshot) ACLList() ([]*structs.ACL, error) {
_, ret, err := aclListTxn(s.tx)
return ret, err
}
// maxIndex is a helper used to retrieve the highest known index
// amongst a set of tables in the db.
func (s *StateStore) maxIndex(tables ...string) uint64 {
tx := s.db.Txn(false)
defer tx.Abort()
return maxIndexTxn(tx, tables...)
}
// maxIndexTxn is a helper used to retrieve the highest known index
// amongst a set of tables in the db.
func maxIndexTxn(tx *memdb.Txn, tables ...string) uint64 {
var lindex uint64
for _, table := range tables {
ti, err := tx.First("index", "id", table)
@ -89,13 +141,51 @@ func (s *StateStore) maxIndex(tables ...string) uint64 {
return lindex
}
// indexUpdateMaxTxn is used when restoring entries and sets the table's index to
// the given idx only if it's greater than the current index.
func indexUpdateMaxTxn(tx *memdb.Txn, idx uint64, table string) error {
raw, err := tx.First("index", "id", table)
if err != nil {
return fmt.Errorf("failed to retrieve existing index: %s", err)
}
if raw == nil {
return fmt.Errorf("missing index for table %s", table)
}
entry, ok := raw.(*IndexEntry)
if !ok {
return fmt.Errorf("unexpected index type for table %s", table)
}
if idx > entry.Value {
if err := tx.Insert("index", &IndexEntry{table, idx}); err != nil {
return fmt.Errorf("failed updating index %s", err)
}
}
return nil
}
// getWatchManager returns a watch manager for the given set of tables. The
// order of the tables is not important.
func (s *StateStore) GetWatchManager(tables ...string) WatchManager {
if len(tables) == 1 {
if manager, ok := s.watches[tables[0]]; ok {
return manager
}
}
panic(fmt.Sprintf("Unknown watch manager(s): %v", tables))
}
// EnsureNode is used to upsert node registration or modification.
func (s *StateStore) EnsureNode(idx uint64, node *structs.Node) error {
tx := s.db.Txn(true)
defer tx.Abort()
// Call the node upsert
if err := s.ensureNodeTxn(idx, node, tx); err != nil {
if err := ensureNodeTxn(tx, idx, node); err != nil {
return err
}
@ -106,7 +196,7 @@ func (s *StateStore) EnsureNode(idx uint64, node *structs.Node) error {
// ensureNodeTxn is the inner function called to actually create a node
// registration or modify an existing one in the state store. It allows
// passing in a memdb transaction so it may be part of a larger txn.
func (s *StateStore) ensureNodeTxn(idx uint64, node *structs.Node, tx *memdb.Txn) error {
func ensureNodeTxn(tx *memdb.Txn, idx uint64, node *structs.Node) error {
// Check for an existing node
existing, err := tx.First("nodes", "id", node.Node)
if err != nil {
@ -179,7 +269,7 @@ func (s *StateStore) DeleteNode(idx uint64, nodeID string) error {
defer tx.Abort()
// Call the node deletion.
if err := s.deleteNodeTxn(idx, nodeID, tx); err != nil {
if err := deleteNodeTxn(tx, idx, nodeID); err != nil {
return err
}
@ -189,7 +279,7 @@ func (s *StateStore) DeleteNode(idx uint64, nodeID string) error {
// deleteNodeTxn is the inner method used for removing a node from
// the store within a given transaction.
func (s *StateStore) deleteNodeTxn(idx uint64, nodeID string, tx *memdb.Txn) error {
func deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeID string) error {
// Look up the node
node, err := tx.First("nodes", "id", nodeID)
if err != nil {
@ -206,7 +296,7 @@ func (s *StateStore) deleteNodeTxn(idx uint64, nodeID string, tx *memdb.Txn) err
}
for service := services.Next(); service != nil; service = services.Next() {
svc := service.(*structs.ServiceNode)
if err := s.deleteServiceTxn(idx, nodeID, svc.ServiceID, tx); err != nil {
if err := deleteServiceTxn(tx, idx, nodeID, svc.ServiceID); err != nil {
return err
}
}
@ -218,7 +308,7 @@ func (s *StateStore) deleteNodeTxn(idx uint64, nodeID string, tx *memdb.Txn) err
}
for check := checks.Next(); check != nil; check = checks.Next() {
chk := check.(*structs.HealthCheck)
if err := s.deleteCheckTxn(idx, nodeID, chk.CheckID, tx); err != nil {
if err := deleteCheckTxn(tx, idx, nodeID, chk.CheckID); err != nil {
return err
}
}
@ -242,7 +332,7 @@ func (s *StateStore) EnsureService(idx uint64, node string, svc *structs.NodeSer
defer tx.Abort()
// Call the service registration upsert
if err := s.ensureServiceTxn(idx, node, svc, tx); err != nil {
if err := ensureServiceTxn(tx, idx, node, svc); err != nil {
return err
}
@ -252,7 +342,7 @@ func (s *StateStore) EnsureService(idx uint64, node string, svc *structs.NodeSer
// ensureServiceTxn is used to upsert a service registration within an
// existing memdb transaction.
func (s *StateStore) ensureServiceTxn(idx uint64, node string, svc *structs.NodeService, tx *memdb.Txn) error {
func ensureServiceTxn(tx *memdb.Txn, idx uint64, node string, svc *structs.NodeService) error {
// Check for existing service
existing, err := tx.First("services", "id", node, svc.Service)
if err != nil {
@ -358,7 +448,7 @@ func (s *StateStore) DeleteService(idx uint64, nodeID, serviceID string) error {
defer tx.Abort()
// Call the service deletion
if err := s.deleteServiceTxn(idx, nodeID, serviceID, tx); err != nil {
if err := deleteServiceTxn(tx, idx, nodeID, serviceID); err != nil {
return err
}
@ -368,7 +458,7 @@ func (s *StateStore) DeleteService(idx uint64, nodeID, serviceID string) error {
// deleteServiceTxn is the inner method called to remove a service
// registration within an existing transaction.
func (s *StateStore) deleteServiceTxn(idx uint64, nodeID, serviceID string, tx *memdb.Txn) error {
func deleteServiceTxn(tx *memdb.Txn, idx uint64, nodeID, serviceID string) error {
// Look up the service
service, err := tx.First("services", "id", nodeID, serviceID)
if err != nil {
@ -411,7 +501,7 @@ func (s *StateStore) EnsureCheck(idx uint64, hc *structs.HealthCheck) error {
defer tx.Abort()
// Call the check registration
if err := s.ensureCheckTxn(idx, hc, tx); err != nil {
if err := ensureCheckTxn(tx, idx, hc); err != nil {
return err
}
@ -422,7 +512,7 @@ func (s *StateStore) EnsureCheck(idx uint64, hc *structs.HealthCheck) error {
// ensureCheckTransaction is used as the inner method to handle inserting
// a health check into the state store. It ensures safety against inserting
// checks with no matching node or service.
func (s *StateStore) ensureCheckTxn(idx uint64, hc *structs.HealthCheck, tx *memdb.Txn) error {
func ensureCheckTxn(tx *memdb.Txn, idx uint64, hc *structs.HealthCheck) error {
// Check if we have an existing health check
existing, err := tx.First("checks", "id", hc.Node, hc.CheckID)
if err != nil {
@ -541,7 +631,7 @@ func (s *StateStore) DeleteCheck(idx uint64, node, id string) error {
defer tx.Abort()
// Call the check deletion
if err := s.deleteCheckTxn(idx, node, id, tx); err != nil {
if err := deleteCheckTxn(tx, idx, node, id); err != nil {
return err
}
@ -551,7 +641,7 @@ func (s *StateStore) DeleteCheck(idx uint64, node, id string) error {
// deleteCheckTxn is the inner method used to call a health
// check deletion within an existing transaction.
func (s *StateStore) deleteCheckTxn(idx uint64, node, id string, tx *memdb.Txn) error {
func deleteCheckTxn(tx *memdb.Txn, idx uint64, node, id string) error {
// Try to retrieve the existing health check
check, err := tx.First("checks", "id", node, id)
if err != nil {
@ -743,14 +833,12 @@ func (s *StateStore) parseNodes(
func (s *StateStore) KVSSet(idx uint64, entry *structs.DirEntry) error {
tx := s.db.Txn(true)
defer tx.Abort()
return s.kvsSetTxn(idx, entry, tx)
return kvsSetTxn(tx, idx, entry)
}
// kvsSetTxn is used to insert or update a key/value pair in the state
// store. It is the inner method used and handles only the actual storage.
func (s *StateStore) kvsSetTxn(
idx uint64, entry *structs.DirEntry,
tx *memdb.Txn) error {
func kvsSetTxn(tx *memdb.Txn, idx uint64, entry *structs.DirEntry) error {
// Retrieve an existing KV pair
existing, err := tx.First("kvs", "id", entry.Key)
@ -878,7 +966,7 @@ func (s *StateStore) KVSDelete(idx uint64, key string) error {
defer tx.Abort()
// Perform the actual delete
if err := s.kvsDeleteTxn(idx, key, tx); err != nil {
if err := kvsDeleteTxn(tx, idx, key); err != nil {
return err
}
@ -888,7 +976,7 @@ func (s *StateStore) KVSDelete(idx uint64, key string) error {
// kvsDeleteTxn is the inner method used to perform the actual deletion
// of a key/value pair within an existing transaction.
func (s *StateStore) kvsDeleteTxn(idx uint64, key string, tx *memdb.Txn) error {
func kvsDeleteTxn(tx *memdb.Txn, idx uint64, key string) error {
// Look up the entry in the state store
entry, err := tx.First("kvs", "id", key)
if err != nil {
@ -931,7 +1019,7 @@ func (s *StateStore) KVSDeleteCAS(idx, cidx uint64, key string) (bool, error) {
}
// Call the actual deletion if the above passed
if err := s.kvsDeleteTxn(idx, key, tx); err != nil {
if err := kvsDeleteTxn(tx, idx, key); err != nil {
return false, err
}
@ -967,7 +1055,7 @@ func (s *StateStore) KVSSetCAS(idx uint64, entry *structs.DirEntry) (bool, error
}
// If we made it this far, we should perform the set.
return true, s.kvsSetTxn(idx, entry, tx)
return true, kvsSetTxn(tx, idx, entry)
}
// KVSDeleteTree is used to do a recursive delete on a key prefix
@ -1011,7 +1099,7 @@ func (s *StateStore) SessionCreate(idx uint64, sess *structs.Session) error {
defer tx.Abort()
// Call the session creation
if err := s.sessionCreateTxn(idx, sess, tx); err != nil {
if err := sessionCreateTxn(tx, idx, sess); err != nil {
return err
}
@ -1022,7 +1110,7 @@ func (s *StateStore) SessionCreate(idx uint64, sess *structs.Session) error {
// sessionCreateTxn is the inner method used for creating session entries in
// an open transaction. Any health checks registered with the session will be
// checked for failing status. Returns any error encountered.
func (s *StateStore) sessionCreateTxn(idx uint64, sess *structs.Session, tx *memdb.Txn) error {
func sessionCreateTxn(tx *memdb.Txn, idx uint64, sess *structs.Session) error {
// Check that we have a session ID
if sess.ID == "" {
return ErrMissingSessionID
@ -1172,7 +1260,7 @@ func (s *StateStore) SessionDestroy(idx uint64, sessionID string) error {
defer tx.Abort()
// Call the session deletion
if err := s.sessionDestroyTxn(idx, sessionID, tx); err != nil {
if err := sessionDestroyTxn(tx, idx, sessionID); err != nil {
return err
}
@ -1182,7 +1270,7 @@ func (s *StateStore) SessionDestroy(idx uint64, sessionID string) error {
// sessionDestroyTxn is the inner method, which is used to do the actual
// session deletion and handle session invalidation, watch triggers, etc.
func (s *StateStore) sessionDestroyTxn(idx uint64, sessionID string, tx *memdb.Txn) error {
func sessionDestroyTxn(tx *memdb.Txn, idx uint64, sessionID string) error {
// Look up the session
sess, err := tx.First("sessions", "id", sessionID)
if err != nil {
@ -1211,17 +1299,18 @@ func (s *StateStore) ACLSet(idx uint64, acl *structs.ACL) error {
defer tx.Abort()
// Call set on the ACL
if err := s.aclSetTxn(idx, acl, tx); err != nil {
if err := aclSetTxn(tx, idx, acl); err != nil {
return err
}
tx.Defer(func() { s.GetWatchManager("acls").Notify() })
tx.Commit()
return nil
}
// aclSetTxn is the inner method used to insert an ACL rule with the
// proper indexes into the state store.
func (s *StateStore) aclSetTxn(idx uint64, acl *structs.ACL, tx *memdb.Txn) error {
func aclSetTxn(tx *memdb.Txn, idx uint64, acl *structs.ACL) error {
// Check that the ID is set
if acl.ID == "" {
return ErrMissingACLID
@ -1272,7 +1361,11 @@ func (s *StateStore) ACLGet(aclID string) (*structs.ACL, error) {
func (s *StateStore) ACLList() (uint64, []*structs.ACL, error) {
tx := s.db.Txn(false)
defer tx.Abort()
return aclListTxn(tx)
}
// aclListTxn is used to list out all of the ACLs in the state store.
func aclListTxn(tx *memdb.Txn) (uint64, []*structs.ACL, error) {
// Query all of the ACLs in the state store
acls, err := tx.Get("acls", "id")
if err != nil {
@ -1301,17 +1394,18 @@ func (s *StateStore) ACLDelete(idx uint64, aclID string) error {
defer tx.Abort()
// Call the ACL delete
if err := s.aclDeleteTxn(idx, aclID, tx); err != nil {
if err := aclDeleteTxn(tx, idx, aclID); err != nil {
return err
}
tx.Defer(func() { s.GetWatchManager("acls").Notify() })
tx.Commit()
return nil
}
// aclDeleteTxn is used to delete an ACL from the state store within
// an existing transaction.
func (s *StateStore) aclDeleteTxn(idx uint64, aclID string, tx *memdb.Txn) error {
func aclDeleteTxn(tx *memdb.Txn, idx uint64, aclID string) error {
// Look up the existing ACL
acl, err := tx.First("acls", "id", aclID)
if err != nil {
@ -1330,3 +1424,22 @@ func (s *StateStore) aclDeleteTxn(idx uint64, aclID string, tx *memdb.Txn) error
}
return nil
}
// ACLRestore is used when restoring from a snapshot. For general inserts, use
// ACLSet.
func (s *StateStore) ACLRestore(acl *structs.ACL) error {
tx := s.db.Txn(true)
defer tx.Abort()
if err := tx.Insert("acls", acl); err != nil {
return fmt.Errorf("failed restoring acl: %s", err)
}
if err := indexUpdateMaxTxn(tx, acl.ModifyIndex, "acls"); err != nil {
return err
}
tx.Defer(func() { s.GetWatchManager("acls").Notify() })
tx.Commit()
return nil
}

View File

@ -6,6 +6,7 @@ import (
"reflect"
"strings"
"testing"
"time"
"github.com/hashicorp/consul/consul/structs"
)
@ -31,7 +32,7 @@ func testRegisterNode(t *testing.T, s *StateStore, idx uint64, nodeID string) {
defer tx.Abort()
n, err := tx.First("nodes", "id", nodeID)
if err != nil {
t.Fatalf("err: %s", err, n)
t.Fatalf("err: %s", err)
}
if result, ok := n.(*structs.Node); !ok || result.Node != nodeID {
t.Fatalf("bad node: %#v", result)
@ -107,14 +108,33 @@ func testSetKey(t *testing.T, s *StateStore, idx uint64, key, value string) {
func TestStateStore_maxIndex(t *testing.T) {
s := testStateStore(t)
testRegisterNode(t, s, 0, "foo")
testRegisterNode(t, s, 1, "bar")
testRegisterService(t, s, 2, "foo", "consul")
if max := s.maxIndex("nodes", "services"); max != 2 {
t.Fatalf("bad max: %d", max)
}
}
func TestStateStore_indexUpdateMaxTxn(t *testing.T) {
s := testStateStore(t)
testRegisterNode(t, s, 0, "foo")
testRegisterNode(t, s, 1, "bar")
tx := s.db.Txn(true)
if err := indexUpdateMaxTxn(tx, 3, "nodes"); err != nil {
t.Fatalf("err: %s", err)
}
tx.Commit()
if max := s.maxIndex("nodes"); max != 3 {
t.Fatalf("bad max: %d", max)
}
}
func TestStateStore_EnsureNode(t *testing.T) {
s := testStateStore(t)
@ -1415,7 +1435,7 @@ func TestStateStore_SessionCreate_GetSession(t *testing.T) {
t.Fatalf("err: %s", err)
}
if idx := s.maxIndex("sessions"); idx != 2 {
t.Fatalf("bad index: %d", err)
t.Fatalf("bad index: %s", err)
}
// Retrieve the session again
@ -1814,3 +1834,44 @@ func TestStateStore_ACLDelete(t *testing.T) {
t.Fatalf("expected nil, got: %#v", result)
}
}
func TestStateStore_ACL_Watches(t *testing.T) {
s := testStateStore(t)
ch := make(chan struct{})
s.GetWatchManager("acls").Start(ch)
go func() {
if err := s.ACLSet(1, &structs.ACL{ID: "acl1"}); err != nil {
t.Fatalf("err: %s", err)
}
}()
select {
case <-ch:
case <-time.After(1 * time.Second):
t.Fatalf("watch was not notified")
}
s.GetWatchManager("acls").Start(ch)
go func() {
if err := s.ACLDelete(2, "acl1"); err != nil {
t.Fatalf("err: %s", err)
}
}()
select {
case <-ch:
case <-time.After(1 * time.Second):
t.Fatalf("watch was not notified")
}
s.GetWatchManager("acls").Start(ch)
go func() {
if err := s.ACLRestore(&structs.ACL{ID: "acl1"}); err != nil {
t.Fatalf("err: %s", err)
}
}()
select {
case <-ch:
case <-time.After(1 * time.Second):
t.Fatalf("watch was not notified")
}
}

35
consul/state/watch.go Normal file
View File

@ -0,0 +1,35 @@
package state
import (
"github.com/hashicorp/go-memdb"
)
type WatchManager interface {
Start(notifyCh chan struct{})
Stop(notifyCh chan struct{})
Notify()
}
type FullTableWatch struct {
notify NotifyGroup
}
func (w *FullTableWatch) Start(notifyCh chan struct{}) {
w.notify.Wait(notifyCh)
}
func (w *FullTableWatch) Stop(notifyCh chan struct{}) {
w.notify.Clear(notifyCh)
}
func (w *FullTableWatch) Notify() {
w.notify.Notify()
}
func newWatchManagers(schema *memdb.DBSchema) (map[string]WatchManager, error) {
watches := make(map[string]WatchManager)
for table, _ := range schema.Tables {
watches[table] = &FullTableWatch{}
}
return watches, nil
}

View File

@ -24,7 +24,6 @@ const (
dbTombstone = "tombstones"
dbSessions = "sessions"
dbSessionChecks = "sessionChecks"
dbACLs = "acls"
dbMaxMapSize32bit uint64 = 128 * 1024 * 1024 // 128MB maximum size
dbMaxMapSize64bit uint64 = 32 * 1024 * 1024 * 1024 // 32GB maximum size
dbMaxReaders uint = 4096 // 4K, default is 126
@ -59,7 +58,6 @@ type StateStore struct {
tombstoneTable *MDBTable
sessionTable *MDBTable
sessionCheckTable *MDBTable
aclTable *MDBTable
tables MDBTables
watch map[*MDBTable]*NotifyGroup
queryTables map[string]MDBTables
@ -361,27 +359,9 @@ func (s *StateStore) initialize() error {
},
}
s.aclTable = &MDBTable{
Name: dbACLs,
Indexes: map[string]*MDBIndex{
"id": &MDBIndex{
Unique: true,
Fields: []string{"ID"},
},
},
Decoder: func(buf []byte) interface{} {
out := new(structs.ACL)
if err := structs.Decode(buf, out); err != nil {
panic(err)
}
return out
},
}
// Store the set of tables
s.tables = []*MDBTable{s.nodeTable, s.serviceTable, s.checkTable,
s.kvsTable, s.tombstoneTable, s.sessionTable, s.sessionCheckTable,
s.aclTable}
s.kvsTable, s.tombstoneTable, s.sessionTable, s.sessionCheckTable}
for _, table := range s.tables {
table.Env = s.env
table.Encoder = encoder
@ -408,8 +388,6 @@ func (s *StateStore) initialize() error {
"SessionGet": MDBTables{s.sessionTable},
"SessionList": MDBTables{s.sessionTable},
"NodeSessions": MDBTables{s.sessionTable},
"ACLGet": MDBTables{s.aclTable},
"ACLList": MDBTables{s.aclTable},
}
return nil
}
@ -1945,109 +1923,6 @@ func (s *StateStore) deleteLocks(index uint64, tx *MDBTxn,
return nil
}
// ACLSet is used to create or update an ACL entry
func (s *StateStore) ACLSet(index uint64, acl *structs.ACL) error {
// Check for an ID
if acl.ID == "" {
return fmt.Errorf("Missing ACL ID")
}
// Start a new txn
tx, err := s.tables.StartTxn(false)
if err != nil {
return err
}
defer tx.Abort()
// Look for the existing node
res, err := s.aclTable.GetTxn(tx, "id", acl.ID)
if err != nil {
return err
}
switch len(res) {
case 0:
acl.CreateIndex = index
acl.ModifyIndex = index
case 1:
exist := res[0].(*structs.ACL)
acl.CreateIndex = exist.CreateIndex
acl.ModifyIndex = index
default:
panic(fmt.Errorf("Duplicate ACL definition. Internal error"))
}
// Insert the ACL
if err := s.aclTable.InsertTxn(tx, acl); err != nil {
return err
}
// Trigger the update notifications
if err := s.aclTable.SetLastIndexTxn(tx, index); err != nil {
return err
}
tx.Defer(func() { s.watch[s.aclTable].Notify() })
return tx.Commit()
}
// ACLRestore is used to restore an ACL. It should only be used when
// doing a restore, otherwise ACLSet should be used.
func (s *StateStore) ACLRestore(acl *structs.ACL) error {
// Start a new txn
tx, err := s.aclTable.StartTxn(false, nil)
if err != nil {
return err
}
defer tx.Abort()
if err := s.aclTable.InsertTxn(tx, acl); err != nil {
return err
}
if err := s.aclTable.SetMaxLastIndexTxn(tx, acl.ModifyIndex); err != nil {
return err
}
return tx.Commit()
}
// ACLGet is used to get an ACL by ID
func (s *StateStore) ACLGet(id string) (uint64, *structs.ACL, error) {
idx, res, err := s.aclTable.Get("id", id)
var d *structs.ACL
if len(res) > 0 {
d = res[0].(*structs.ACL)
}
return idx, d, err
}
// ACLList is used to list all the acls
func (s *StateStore) ACLList() (uint64, []*structs.ACL, error) {
idx, res, err := s.aclTable.Get("id")
out := make([]*structs.ACL, len(res))
for i, raw := range res {
out[i] = raw.(*structs.ACL)
}
return idx, out, err
}
// ACLDelete is used to remove an ACL
func (s *StateStore) ACLDelete(index uint64, id string) error {
tx, err := s.tables.StartTxn(false)
if err != nil {
panic(fmt.Errorf("Failed to start txn: %v", err))
}
defer tx.Abort()
if n, err := s.aclTable.DeleteTxn(tx, "id", id); err != nil {
return err
} else if n > 0 {
if err := s.aclTable.SetLastIndexTxn(tx, index); err != nil {
return err
}
tx.Defer(func() { s.watch[s.aclTable].Notify() })
}
return tx.Commit()
}
// Snapshot is used to create a point in time snapshot
func (s *StateStore) Snapshot() (*StateSnapshot, error) {
// Begin a new txn on all tables
@ -2128,13 +2003,3 @@ func (s *StateSnapshot) SessionList() ([]*structs.Session, error) {
}
return out, err
}
// ACLList is used to list all of the ACLs
func (s *StateSnapshot) ACLList() ([]*structs.ACL, error) {
res, err := s.store.aclTable.GetTxn(s.tx, "id")
out := make([]*structs.ACL, len(res))
for i, raw := range res {
out[i] = raw.(*structs.ACL)
}
return out, err
}

View File

@ -762,24 +762,6 @@ func TestStoreSnapshot(t *testing.T) {
t.Fatalf("err: %v", err)
}
a1 := &structs.ACL{
ID: generateUUID(),
Name: "User token",
Type: structs.ACLTypeClient,
}
if err := store.ACLSet(21, a1); err != nil {
t.Fatalf("err: %v", err)
}
a2 := &structs.ACL{
ID: generateUUID(),
Name: "User token",
Type: structs.ACLTypeClient,
}
if err := store.ACLSet(22, a2); err != nil {
t.Fatalf("err: %v", err)
}
// Take a snapshot
snap, err := store.Snapshot()
if err != nil {
@ -884,15 +866,6 @@ func TestStoreSnapshot(t *testing.T) {
t.Fatalf("Wrong number of sessions with TTL")
}
// Check for an acl
acls, err := snap.ACLList()
if err != nil {
t.Fatalf("err: %v", err)
}
if len(acls) != 2 {
t.Fatalf("missing acls")
}
// Make some changes!
if err := store.EnsureService(23, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8000}); err != nil {
t.Fatalf("err: %v", err)
@ -918,11 +891,6 @@ func TestStoreSnapshot(t *testing.T) {
t.Fatalf("err: %v", err)
}
// Nuke an ACL
if err := store.ACLDelete(29, a1.ID); err != nil {
t.Fatalf("err: %v", err)
}
// Check snapshot has old values
nodes = snap.Nodes()
if len(nodes) != 2 {
@ -1003,15 +971,6 @@ func TestStoreSnapshot(t *testing.T) {
if len(sessions) != 3 {
t.Fatalf("missing sessions")
}
// Check for an acl
acls, err = snap.ACLList()
if err != nil {
t.Fatalf("err: %v", err)
}
if len(acls) != 2 {
t.Fatalf("missing acls")
}
}
func TestEnsureCheck(t *testing.T) {
@ -2880,148 +2839,3 @@ func TestSessionInvalidate_KeyDelete(t *testing.T) {
t.Fatalf("Bad: %v", expires)
}
}
func TestACLSet_Get(t *testing.T) {
store, err := testStateStore()
if err != nil {
t.Fatalf("err: %v", err)
}
defer store.Close()
idx, out, err := store.ACLGet("1234")
if err != nil {
t.Fatalf("err: %v", err)
}
if idx != 0 {
t.Fatalf("bad: %v", idx)
}
if out != nil {
t.Fatalf("bad: %v", out)
}
a := &structs.ACL{
ID: generateUUID(),
Name: "User token",
Type: structs.ACLTypeClient,
Rules: "",
}
if err := store.ACLSet(50, a); err != nil {
t.Fatalf("err: %v", err)
}
if a.CreateIndex != 50 {
t.Fatalf("Bad: %v", a)
}
if a.ModifyIndex != 50 {
t.Fatalf("Bad: %v", a)
}
if a.ID == "" {
t.Fatalf("Bad: %v", a)
}
idx, out, err = store.ACLGet(a.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
if idx != 50 {
t.Fatalf("bad: %v", idx)
}
if !reflect.DeepEqual(out, a) {
t.Fatalf("bad: %v", out)
}
// Update
a.Rules = "foo bar baz"
if err := store.ACLSet(52, a); err != nil {
t.Fatalf("err: %v", err)
}
if a.CreateIndex != 50 {
t.Fatalf("Bad: %v", a)
}
if a.ModifyIndex != 52 {
t.Fatalf("Bad: %v", a)
}
idx, out, err = store.ACLGet(a.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
if idx != 52 {
t.Fatalf("bad: %v", idx)
}
if !reflect.DeepEqual(out, a) {
t.Fatalf("bad: %v", out)
}
}
func TestACLDelete(t *testing.T) {
store, err := testStateStore()
if err != nil {
t.Fatalf("err: %v", err)
}
defer store.Close()
a := &structs.ACL{
ID: generateUUID(),
Name: "User token",
Type: structs.ACLTypeClient,
Rules: "",
}
if err := store.ACLSet(50, a); err != nil {
t.Fatalf("err: %v", err)
}
if err := store.ACLDelete(52, a.ID); err != nil {
t.Fatalf("err: %v", err)
}
if err := store.ACLDelete(53, a.ID); err != nil {
t.Fatalf("err: %v", err)
}
idx, out, err := store.ACLGet(a.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
if idx != 52 {
t.Fatalf("bad: %v", idx)
}
if out != nil {
t.Fatalf("bad: %v", out)
}
}
func TestACLList(t *testing.T) {
store, err := testStateStore()
if err != nil {
t.Fatalf("err: %v", err)
}
defer store.Close()
a1 := &structs.ACL{
ID: generateUUID(),
Name: "User token",
Type: structs.ACLTypeClient,
}
if err := store.ACLSet(50, a1); err != nil {
t.Fatalf("err: %v", err)
}
a2 := &structs.ACL{
ID: generateUUID(),
Name: "User token",
Type: structs.ACLTypeClient,
}
if err := store.ACLSet(51, a2); err != nil {
t.Fatalf("err: %v", err)
}
idx, out, err := store.ACLList()
if err != nil {
t.Fatalf("err: %v", err)
}
if idx != 51 {
t.Fatalf("bad: %v", idx)
}
if len(out) != 2 {
t.Fatalf("bad: %v", out)
}
}