mirror of
https://github.com/status-im/consul.git
synced 2025-01-10 13:55:55 +00:00
Integrates new state store for ACLs.
This commit is contained in:
parent
d57431e300
commit
009fd7d9f5
@ -51,8 +51,8 @@ type aclCacheEntry struct {
|
||||
// aclFault is used to fault in the rules for an ACL if we take a miss
|
||||
func (s *Server) aclFault(id string) (string, string, error) {
|
||||
defer metrics.MeasureSince([]string{"consul", "acl", "fault"}, time.Now())
|
||||
state := s.fsm.State()
|
||||
_, acl, err := state.ACLGet(id)
|
||||
state := s.fsm.StateNew()
|
||||
acl, err := state.ACLGet(id)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -60,10 +60,10 @@ func (a *ACL) Apply(args *structs.ACLRequest, reply *string) error {
|
||||
// deterministic. Once the entry is in the log, the state update MUST
|
||||
// be deterministic or the followers will not converge.
|
||||
if args.ACL.ID == "" {
|
||||
state := a.srv.fsm.State()
|
||||
state := a.srv.fsm.StateNew()
|
||||
for {
|
||||
args.ACL.ID = generateUUID()
|
||||
_, acl, err := state.ACLGet(args.ACL.ID)
|
||||
acl, err := state.ACLGet(args.ACL.ID)
|
||||
if err != nil {
|
||||
a.srv.logger.Printf("[ERR] consul.acl: ACL lookup failed: %v", err)
|
||||
return err
|
||||
@ -120,14 +120,14 @@ func (a *ACL) Get(args *structs.ACLSpecificRequest,
|
||||
}
|
||||
|
||||
// Get the local state
|
||||
state := a.srv.fsm.State()
|
||||
return a.srv.blockingRPC(&args.QueryOptions,
|
||||
state := a.srv.fsm.StateNew()
|
||||
return a.srv.blockingRPCNew(&args.QueryOptions,
|
||||
&reply.QueryMeta,
|
||||
state.QueryTables("ACLGet"),
|
||||
state.GetWatchManager("acls"),
|
||||
func() error {
|
||||
index, acl, err := state.ACLGet(args.ACL)
|
||||
reply.Index = index
|
||||
acl, err := state.ACLGet(args.ACL)
|
||||
if acl != nil {
|
||||
reply.Index = acl.ModifyIndex
|
||||
reply.ACLs = structs.ACLs{acl}
|
||||
} else {
|
||||
reply.ACLs = nil
|
||||
@ -191,10 +191,10 @@ func (a *ACL) List(args *structs.DCSpecificRequest,
|
||||
}
|
||||
|
||||
// Get the local state
|
||||
state := a.srv.fsm.State()
|
||||
return a.srv.blockingRPC(&args.QueryOptions,
|
||||
state := a.srv.fsm.StateNew()
|
||||
return a.srv.blockingRPCNew(&args.QueryOptions,
|
||||
&reply.QueryMeta,
|
||||
state.QueryTables("ACLList"),
|
||||
state.GetWatchManager("acls"),
|
||||
func() error {
|
||||
var err error
|
||||
reply.Index, reply.ACLs, err = state.ACLList()
|
||||
|
@ -39,8 +39,8 @@ func TestACLEndpoint_Apply(t *testing.T) {
|
||||
id := out
|
||||
|
||||
// Verify
|
||||
state := s1.fsm.State()
|
||||
_, s, err := state.ACLGet(out)
|
||||
state := s1.fsm.StateNew()
|
||||
s, err := state.ACLGet(out)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
@ -62,7 +62,7 @@ func TestACLEndpoint_Apply(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify
|
||||
_, s, err = state.ACLGet(id)
|
||||
s, err = state.ACLGet(id)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
@ -180,8 +180,8 @@ func TestACLEndpoint_Apply_CustomID(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify
|
||||
state := s1.fsm.State()
|
||||
_, s, err := state.ACLGet(out)
|
||||
state := s1.fsm.StateNew()
|
||||
s, err := state.ACLGet(out)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/consul/state"
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/go-msgpack/codec"
|
||||
"github.com/hashicorp/raft"
|
||||
@ -24,6 +25,7 @@ type consulFSM struct {
|
||||
logOutput io.Writer
|
||||
logger *log.Logger
|
||||
path string
|
||||
stateNew *state.StateStore
|
||||
state *StateStore
|
||||
gc *TombstoneGC
|
||||
}
|
||||
@ -33,6 +35,7 @@ type consulFSM struct {
|
||||
// that may modify the live state.
|
||||
type consulSnapshot struct {
|
||||
state *StateSnapshot
|
||||
stateNew *state.StateSnapshot
|
||||
}
|
||||
|
||||
// snapshotHeader is the first entry in our snapshot
|
||||
@ -44,6 +47,12 @@ type snapshotHeader struct {
|
||||
|
||||
// NewFSMPath is used to construct a new FSM with a blank state
|
||||
func NewFSM(gc *TombstoneGC, path string, logOutput io.Writer) (*consulFSM, error) {
|
||||
// Create the state store.
|
||||
stateNew, err := state.NewStateStore(logOutput)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create a temporary path for the state store
|
||||
tmpPath, err := ioutil.TempDir(path, "state")
|
||||
if err != nil {
|
||||
@ -60,6 +69,7 @@ func NewFSM(gc *TombstoneGC, path string, logOutput io.Writer) (*consulFSM, erro
|
||||
logOutput: logOutput,
|
||||
logger: log.New(logOutput, "", log.LstdFlags),
|
||||
path: path,
|
||||
stateNew: stateNew,
|
||||
state: state,
|
||||
gc: gc,
|
||||
}
|
||||
@ -71,6 +81,11 @@ func (c *consulFSM) Close() error {
|
||||
return c.state.Close()
|
||||
}
|
||||
|
||||
// TODO(slackpad)
|
||||
func (c *consulFSM) StateNew() *state.StateStore {
|
||||
return c.stateNew
|
||||
}
|
||||
|
||||
// State is used to return a handle to the current state
|
||||
func (c *consulFSM) State() *StateStore {
|
||||
return c.state
|
||||
@ -234,13 +249,13 @@ func (c *consulFSM) applyACLOperation(buf []byte, index uint64) interface{} {
|
||||
defer metrics.MeasureSince([]string{"consul", "fsm", "acl", string(req.Op)}, time.Now())
|
||||
switch req.Op {
|
||||
case structs.ACLForceSet, structs.ACLSet:
|
||||
if err := c.state.ACLSet(index, &req.ACL); err != nil {
|
||||
if err := c.stateNew.ACLSet(index, &req.ACL); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return req.ACL.ID
|
||||
}
|
||||
case structs.ACLDelete:
|
||||
return c.state.ACLDelete(index, req.ACL.ID)
|
||||
return c.stateNew.ACLDelete(index, req.ACL.ID)
|
||||
default:
|
||||
c.logger.Printf("[WARN] consul.fsm: Invalid ACL operation '%s'", req.Op)
|
||||
return fmt.Errorf("Invalid ACL operation '%s'", req.Op)
|
||||
@ -272,7 +287,7 @@ func (c *consulFSM) Snapshot() (raft.FSMSnapshot, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &consulSnapshot{snap}, nil
|
||||
return &consulSnapshot{snap, c.stateNew.Snapshot()}, nil
|
||||
}
|
||||
|
||||
func (c *consulFSM) Restore(old io.ReadCloser) error {
|
||||
@ -344,7 +359,7 @@ func (c *consulFSM) Restore(old io.ReadCloser) error {
|
||||
if err := dec.Decode(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.state.ACLRestore(&req); err != nil {
|
||||
if err := c.stateNew.ACLRestore(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -467,7 +482,7 @@ func (s *consulSnapshot) persistSessions(sink raft.SnapshotSink,
|
||||
|
||||
func (s *consulSnapshot) persistACLs(sink raft.SnapshotSink,
|
||||
encoder *codec.Encoder) error {
|
||||
acls, err := s.state.ACLList()
|
||||
acls, err := s.stateNew.ACLList()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -361,7 +361,7 @@ func TestFSM_SnapshotRestore(t *testing.T) {
|
||||
session := &structs.Session{ID: generateUUID(), Node: "foo"}
|
||||
fsm.state.SessionCreate(9, session)
|
||||
acl := &structs.ACL{ID: generateUUID(), Name: "User Token"}
|
||||
fsm.state.ACLSet(10, acl)
|
||||
fsm.stateNew.ACLSet(10, acl)
|
||||
|
||||
fsm.state.KVSSet(11, &structs.DirEntry{
|
||||
Key: "/remove",
|
||||
@ -448,14 +448,14 @@ func TestFSM_SnapshotRestore(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify ACL is restored
|
||||
idx, a, err := fsm2.state.ACLGet(acl.ID)
|
||||
a, err := fsm2.stateNew.ACLGet(acl.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if a.Name != "User Token" {
|
||||
t.Fatalf("bad: %v", a)
|
||||
}
|
||||
if idx <= 1 {
|
||||
if a.ModifyIndex <= 1 {
|
||||
t.Fatalf("bad index: %d", idx)
|
||||
}
|
||||
|
||||
@ -971,7 +971,7 @@ func TestFSM_ACL_Set_Delete(t *testing.T) {
|
||||
|
||||
// Get the ACL
|
||||
id := resp.(string)
|
||||
_, acl, err := fsm.state.ACLGet(id)
|
||||
acl, err := fsm.stateNew.ACLGet(id)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
@ -1007,7 +1007,7 @@ func TestFSM_ACL_Set_Delete(t *testing.T) {
|
||||
t.Fatalf("resp: %v", resp)
|
||||
}
|
||||
|
||||
_, acl, err = fsm.state.ACLGet(id)
|
||||
acl, err = fsm.stateNew.ACLGet(id)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -182,8 +182,8 @@ func (s *Server) initializeACL() error {
|
||||
s.aclAuthCache.Purge()
|
||||
|
||||
// Look for the anonymous token
|
||||
state := s.fsm.State()
|
||||
_, acl, err := state.ACLGet(anonymousToken)
|
||||
state := s.fsm.StateNew()
|
||||
acl, err := state.ACLGet(anonymousToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get anonymous token: %v", err)
|
||||
}
|
||||
@ -212,7 +212,7 @@ func (s *Server) initializeACL() error {
|
||||
}
|
||||
|
||||
// Look for the master token
|
||||
_, acl, err = state.ACLGet(master)
|
||||
acl, err = state.ACLGet(master)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get master token: %v", err)
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/consul/state"
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
"github.com/hashicorp/yamux"
|
||||
@ -397,6 +398,75 @@ RUN_QUERY:
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(slackpad)
|
||||
func (s *Server) blockingRPCNew(queryOpts *structs.QueryOptions, queryMeta *structs.QueryMeta,
|
||||
watch state.WatchManager, run func() error) error {
|
||||
var timeout *time.Timer
|
||||
var notifyCh chan struct{}
|
||||
|
||||
// Fast path right to the non-blocking query.
|
||||
if queryOpts.MinQueryIndex == 0 {
|
||||
goto RUN_QUERY
|
||||
}
|
||||
|
||||
// Make sure a watch manager was given if we were asked to block.
|
||||
if watch == nil {
|
||||
panic("no watch manager given for blocking query")
|
||||
}
|
||||
|
||||
// Restrict the max query time, and ensure there is always one.
|
||||
if queryOpts.MaxQueryTime > maxQueryTime {
|
||||
queryOpts.MaxQueryTime = maxQueryTime
|
||||
} else if queryOpts.MaxQueryTime <= 0 {
|
||||
queryOpts.MaxQueryTime = defaultQueryTime
|
||||
}
|
||||
|
||||
// Apply a small amount of jitter to the request.
|
||||
queryOpts.MaxQueryTime += randomStagger(queryOpts.MaxQueryTime / jitterFraction)
|
||||
|
||||
// Setup a query timeout.
|
||||
timeout = time.NewTimer(queryOpts.MaxQueryTime)
|
||||
|
||||
// Setup the notify channel.
|
||||
notifyCh = make(chan struct{}, 1)
|
||||
|
||||
// Ensure we tear down any watches on return.
|
||||
defer func() {
|
||||
timeout.Stop()
|
||||
watch.Stop(notifyCh)
|
||||
}()
|
||||
|
||||
REGISTER_NOTIFY:
|
||||
// Register the notification channel. This may be done multiple times if
|
||||
// we haven't reached the target wait index.
|
||||
watch.Start(notifyCh)
|
||||
|
||||
RUN_QUERY:
|
||||
// Update the query metadata.
|
||||
s.setQueryMeta(queryMeta)
|
||||
|
||||
// If the read must be consistent we verify that we are still the leader.
|
||||
if queryOpts.RequireConsistent {
|
||||
if err := s.consistentRead(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Run the query.
|
||||
metrics.IncrCounter([]string{"consul", "rpc", "query"}, 1)
|
||||
err := run()
|
||||
|
||||
// Check for minimum query time.
|
||||
if err == nil && queryMeta.Index > 0 && queryMeta.Index <= queryOpts.MinQueryIndex {
|
||||
select {
|
||||
case <-notifyCh:
|
||||
goto REGISTER_NOTIFY
|
||||
case <-timeout.C:
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// setQueryMeta is used to populate the QueryMeta data for an RPC call
|
||||
func (s *Server) setQueryMeta(m *structs.QueryMeta) {
|
||||
if s.IsLeader() {
|
||||
|
55
consul/state/notify.go
Normal file
55
consul/state/notify.go
Normal file
@ -0,0 +1,55 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// NotifyGroup is used to allow a simple notification mechanism.
|
||||
// Channels can be marked as waiting, and when notify is invoked,
|
||||
// all the waiting channels get a message and are cleared from the
|
||||
// notify list.
|
||||
type NotifyGroup struct {
|
||||
l sync.Mutex
|
||||
notify map[chan struct{}]struct{}
|
||||
}
|
||||
|
||||
// Notify will do a non-blocking send to all waiting channels, and
|
||||
// clear the notify list
|
||||
func (n *NotifyGroup) Notify() {
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
for ch, _ := range n.notify {
|
||||
select {
|
||||
case ch <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
n.notify = nil
|
||||
}
|
||||
|
||||
// Wait adds a channel to the notify group
|
||||
func (n *NotifyGroup) Wait(ch chan struct{}) {
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
if n.notify == nil {
|
||||
n.notify = make(map[chan struct{}]struct{})
|
||||
}
|
||||
n.notify[ch] = struct{}{}
|
||||
}
|
||||
|
||||
// Clear removes a channel from the notify group
|
||||
func (n *NotifyGroup) Clear(ch chan struct{}) {
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
if n.notify == nil {
|
||||
return
|
||||
}
|
||||
delete(n.notify, ch)
|
||||
}
|
||||
|
||||
// WaitCh allocates a channel that is subscribed to notifications
|
||||
func (n *NotifyGroup) WaitCh() chan struct{} {
|
||||
ch := make(chan struct{}, 1)
|
||||
n.Wait(ch)
|
||||
return ch
|
||||
}
|
72
consul/state/notify_test.go
Normal file
72
consul/state/notify_test.go
Normal file
@ -0,0 +1,72 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNotifyGroup(t *testing.T) {
|
||||
grp := &NotifyGroup{}
|
||||
|
||||
ch1 := grp.WaitCh()
|
||||
ch2 := grp.WaitCh()
|
||||
|
||||
select {
|
||||
case <-ch1:
|
||||
t.Fatalf("should block")
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-ch2:
|
||||
t.Fatalf("should block")
|
||||
default:
|
||||
}
|
||||
|
||||
grp.Notify()
|
||||
|
||||
select {
|
||||
case <-ch1:
|
||||
default:
|
||||
t.Fatalf("should not block")
|
||||
}
|
||||
select {
|
||||
case <-ch2:
|
||||
default:
|
||||
t.Fatalf("should not block")
|
||||
}
|
||||
|
||||
// Should be unregistered
|
||||
ch3 := grp.WaitCh()
|
||||
grp.Notify()
|
||||
|
||||
select {
|
||||
case <-ch1:
|
||||
t.Fatalf("should block")
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-ch2:
|
||||
t.Fatalf("should block")
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-ch3:
|
||||
default:
|
||||
t.Fatalf("should not block")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotifyGroup_Clear(t *testing.T) {
|
||||
grp := &NotifyGroup{}
|
||||
|
||||
ch1 := grp.WaitCh()
|
||||
grp.Clear(ch1)
|
||||
|
||||
grp.Notify()
|
||||
|
||||
// Should not get anything
|
||||
select {
|
||||
case <-ch1:
|
||||
t.Fatalf("should not get message")
|
||||
default:
|
||||
}
|
||||
}
|
@ -35,7 +35,16 @@ var (
|
||||
// from the Raft log through the FSM.
|
||||
type StateStore struct {
|
||||
logger *log.Logger // TODO(slackpad) - Delete if unused!
|
||||
schema *memdb.DBSchema
|
||||
db *memdb.MemDB
|
||||
watches map[string]WatchManager
|
||||
}
|
||||
|
||||
// StateSnapshot is used to provide a point-in-time snapshot. It
|
||||
// works by starting a read transaction against the whole state store.
|
||||
type StateSnapshot struct {
|
||||
tx *memdb.Txn
|
||||
lastIndex uint64
|
||||
}
|
||||
|
||||
// IndexEntry keeps a record of the last index per-table.
|
||||
@ -56,26 +65,69 @@ type sessionCheck struct {
|
||||
|
||||
// NewStateStore creates a new in-memory state storage layer.
|
||||
func NewStateStore(logOutput io.Writer) (*StateStore, error) {
|
||||
// Create the in-memory DB
|
||||
db, err := memdb.NewMemDB(stateStoreSchema())
|
||||
// Create the in-memory DB.
|
||||
schema := stateStoreSchema()
|
||||
db, err := memdb.NewMemDB(schema)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed setting up state store: %s", err)
|
||||
}
|
||||
|
||||
// Create and return the state store
|
||||
// Build up the watch managers.
|
||||
watches, err := newWatchManagers(schema)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to build watch managers: %s", err)
|
||||
}
|
||||
|
||||
// Create and return the state store.
|
||||
s := &StateStore{
|
||||
logger: log.New(logOutput, "", log.LstdFlags),
|
||||
schema: schema,
|
||||
db: db,
|
||||
watches: watches,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Snapshot is used to create a point-in-time snapshot of the entire db.
|
||||
func (s *StateStore) Snapshot() *StateSnapshot {
|
||||
tx := s.db.Txn(false)
|
||||
|
||||
var tables []string
|
||||
for table, _ := range s.schema.Tables {
|
||||
tables = append(tables, table)
|
||||
}
|
||||
idx := maxIndexTxn(tx, tables...)
|
||||
|
||||
return &StateSnapshot{tx, idx}
|
||||
}
|
||||
|
||||
// LastIndex returns that last index that affects the snapshotted data.
|
||||
func (s *StateSnapshot) LastIndex() uint64 {
|
||||
return s.lastIndex
|
||||
}
|
||||
|
||||
// Close performs cleanup of a state snapshot.
|
||||
func (s *StateSnapshot) Close() {
|
||||
s.tx.Abort()
|
||||
}
|
||||
|
||||
// ACLList is used to pull all the ACLs from the snapshot.
|
||||
func (s *StateSnapshot) ACLList() ([]*structs.ACL, error) {
|
||||
_, ret, err := aclListTxn(s.tx)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// maxIndex is a helper used to retrieve the highest known index
|
||||
// amongst a set of tables in the db.
|
||||
func (s *StateStore) maxIndex(tables ...string) uint64 {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
return maxIndexTxn(tx, tables...)
|
||||
}
|
||||
|
||||
// maxIndexTxn is a helper used to retrieve the highest known index
|
||||
// amongst a set of tables in the db.
|
||||
func maxIndexTxn(tx *memdb.Txn, tables ...string) uint64 {
|
||||
var lindex uint64
|
||||
for _, table := range tables {
|
||||
ti, err := tx.First("index", "id", table)
|
||||
@ -89,13 +141,51 @@ func (s *StateStore) maxIndex(tables ...string) uint64 {
|
||||
return lindex
|
||||
}
|
||||
|
||||
// indexUpdateMaxTxn is used when restoring entries and sets the table's index to
|
||||
// the given idx only if it's greater than the current index.
|
||||
func indexUpdateMaxTxn(tx *memdb.Txn, idx uint64, table string) error {
|
||||
raw, err := tx.First("index", "id", table)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to retrieve existing index: %s", err)
|
||||
}
|
||||
|
||||
if raw == nil {
|
||||
return fmt.Errorf("missing index for table %s", table)
|
||||
}
|
||||
|
||||
entry, ok := raw.(*IndexEntry)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected index type for table %s", table)
|
||||
}
|
||||
|
||||
if idx > entry.Value {
|
||||
if err := tx.Insert("index", &IndexEntry{table, idx}); err != nil {
|
||||
return fmt.Errorf("failed updating index %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getWatchManager returns a watch manager for the given set of tables. The
|
||||
// order of the tables is not important.
|
||||
func (s *StateStore) GetWatchManager(tables ...string) WatchManager {
|
||||
if len(tables) == 1 {
|
||||
if manager, ok := s.watches[tables[0]]; ok {
|
||||
return manager
|
||||
}
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("Unknown watch manager(s): %v", tables))
|
||||
}
|
||||
|
||||
// EnsureNode is used to upsert node registration or modification.
|
||||
func (s *StateStore) EnsureNode(idx uint64, node *structs.Node) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
// Call the node upsert
|
||||
if err := s.ensureNodeTxn(idx, node, tx); err != nil {
|
||||
if err := ensureNodeTxn(tx, idx, node); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -106,7 +196,7 @@ func (s *StateStore) EnsureNode(idx uint64, node *structs.Node) error {
|
||||
// ensureNodeTxn is the inner function called to actually create a node
|
||||
// registration or modify an existing one in the state store. It allows
|
||||
// passing in a memdb transaction so it may be part of a larger txn.
|
||||
func (s *StateStore) ensureNodeTxn(idx uint64, node *structs.Node, tx *memdb.Txn) error {
|
||||
func ensureNodeTxn(tx *memdb.Txn, idx uint64, node *structs.Node) error {
|
||||
// Check for an existing node
|
||||
existing, err := tx.First("nodes", "id", node.Node)
|
||||
if err != nil {
|
||||
@ -179,7 +269,7 @@ func (s *StateStore) DeleteNode(idx uint64, nodeID string) error {
|
||||
defer tx.Abort()
|
||||
|
||||
// Call the node deletion.
|
||||
if err := s.deleteNodeTxn(idx, nodeID, tx); err != nil {
|
||||
if err := deleteNodeTxn(tx, idx, nodeID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -189,7 +279,7 @@ func (s *StateStore) DeleteNode(idx uint64, nodeID string) error {
|
||||
|
||||
// deleteNodeTxn is the inner method used for removing a node from
|
||||
// the store within a given transaction.
|
||||
func (s *StateStore) deleteNodeTxn(idx uint64, nodeID string, tx *memdb.Txn) error {
|
||||
func deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeID string) error {
|
||||
// Look up the node
|
||||
node, err := tx.First("nodes", "id", nodeID)
|
||||
if err != nil {
|
||||
@ -206,7 +296,7 @@ func (s *StateStore) deleteNodeTxn(idx uint64, nodeID string, tx *memdb.Txn) err
|
||||
}
|
||||
for service := services.Next(); service != nil; service = services.Next() {
|
||||
svc := service.(*structs.ServiceNode)
|
||||
if err := s.deleteServiceTxn(idx, nodeID, svc.ServiceID, tx); err != nil {
|
||||
if err := deleteServiceTxn(tx, idx, nodeID, svc.ServiceID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -218,7 +308,7 @@ func (s *StateStore) deleteNodeTxn(idx uint64, nodeID string, tx *memdb.Txn) err
|
||||
}
|
||||
for check := checks.Next(); check != nil; check = checks.Next() {
|
||||
chk := check.(*structs.HealthCheck)
|
||||
if err := s.deleteCheckTxn(idx, nodeID, chk.CheckID, tx); err != nil {
|
||||
if err := deleteCheckTxn(tx, idx, nodeID, chk.CheckID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -242,7 +332,7 @@ func (s *StateStore) EnsureService(idx uint64, node string, svc *structs.NodeSer
|
||||
defer tx.Abort()
|
||||
|
||||
// Call the service registration upsert
|
||||
if err := s.ensureServiceTxn(idx, node, svc, tx); err != nil {
|
||||
if err := ensureServiceTxn(tx, idx, node, svc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -252,7 +342,7 @@ func (s *StateStore) EnsureService(idx uint64, node string, svc *structs.NodeSer
|
||||
|
||||
// ensureServiceTxn is used to upsert a service registration within an
|
||||
// existing memdb transaction.
|
||||
func (s *StateStore) ensureServiceTxn(idx uint64, node string, svc *structs.NodeService, tx *memdb.Txn) error {
|
||||
func ensureServiceTxn(tx *memdb.Txn, idx uint64, node string, svc *structs.NodeService) error {
|
||||
// Check for existing service
|
||||
existing, err := tx.First("services", "id", node, svc.Service)
|
||||
if err != nil {
|
||||
@ -358,7 +448,7 @@ func (s *StateStore) DeleteService(idx uint64, nodeID, serviceID string) error {
|
||||
defer tx.Abort()
|
||||
|
||||
// Call the service deletion
|
||||
if err := s.deleteServiceTxn(idx, nodeID, serviceID, tx); err != nil {
|
||||
if err := deleteServiceTxn(tx, idx, nodeID, serviceID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -368,7 +458,7 @@ func (s *StateStore) DeleteService(idx uint64, nodeID, serviceID string) error {
|
||||
|
||||
// deleteServiceTxn is the inner method called to remove a service
|
||||
// registration within an existing transaction.
|
||||
func (s *StateStore) deleteServiceTxn(idx uint64, nodeID, serviceID string, tx *memdb.Txn) error {
|
||||
func deleteServiceTxn(tx *memdb.Txn, idx uint64, nodeID, serviceID string) error {
|
||||
// Look up the service
|
||||
service, err := tx.First("services", "id", nodeID, serviceID)
|
||||
if err != nil {
|
||||
@ -411,7 +501,7 @@ func (s *StateStore) EnsureCheck(idx uint64, hc *structs.HealthCheck) error {
|
||||
defer tx.Abort()
|
||||
|
||||
// Call the check registration
|
||||
if err := s.ensureCheckTxn(idx, hc, tx); err != nil {
|
||||
if err := ensureCheckTxn(tx, idx, hc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -422,7 +512,7 @@ func (s *StateStore) EnsureCheck(idx uint64, hc *structs.HealthCheck) error {
|
||||
// ensureCheckTransaction is used as the inner method to handle inserting
|
||||
// a health check into the state store. It ensures safety against inserting
|
||||
// checks with no matching node or service.
|
||||
func (s *StateStore) ensureCheckTxn(idx uint64, hc *structs.HealthCheck, tx *memdb.Txn) error {
|
||||
func ensureCheckTxn(tx *memdb.Txn, idx uint64, hc *structs.HealthCheck) error {
|
||||
// Check if we have an existing health check
|
||||
existing, err := tx.First("checks", "id", hc.Node, hc.CheckID)
|
||||
if err != nil {
|
||||
@ -541,7 +631,7 @@ func (s *StateStore) DeleteCheck(idx uint64, node, id string) error {
|
||||
defer tx.Abort()
|
||||
|
||||
// Call the check deletion
|
||||
if err := s.deleteCheckTxn(idx, node, id, tx); err != nil {
|
||||
if err := deleteCheckTxn(tx, idx, node, id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -551,7 +641,7 @@ func (s *StateStore) DeleteCheck(idx uint64, node, id string) error {
|
||||
|
||||
// deleteCheckTxn is the inner method used to call a health
|
||||
// check deletion within an existing transaction.
|
||||
func (s *StateStore) deleteCheckTxn(idx uint64, node, id string, tx *memdb.Txn) error {
|
||||
func deleteCheckTxn(tx *memdb.Txn, idx uint64, node, id string) error {
|
||||
// Try to retrieve the existing health check
|
||||
check, err := tx.First("checks", "id", node, id)
|
||||
if err != nil {
|
||||
@ -743,14 +833,12 @@ func (s *StateStore) parseNodes(
|
||||
func (s *StateStore) KVSSet(idx uint64, entry *structs.DirEntry) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
return s.kvsSetTxn(idx, entry, tx)
|
||||
return kvsSetTxn(tx, idx, entry)
|
||||
}
|
||||
|
||||
// kvsSetTxn is used to insert or update a key/value pair in the state
|
||||
// store. It is the inner method used and handles only the actual storage.
|
||||
func (s *StateStore) kvsSetTxn(
|
||||
idx uint64, entry *structs.DirEntry,
|
||||
tx *memdb.Txn) error {
|
||||
func kvsSetTxn(tx *memdb.Txn, idx uint64, entry *structs.DirEntry) error {
|
||||
|
||||
// Retrieve an existing KV pair
|
||||
existing, err := tx.First("kvs", "id", entry.Key)
|
||||
@ -878,7 +966,7 @@ func (s *StateStore) KVSDelete(idx uint64, key string) error {
|
||||
defer tx.Abort()
|
||||
|
||||
// Perform the actual delete
|
||||
if err := s.kvsDeleteTxn(idx, key, tx); err != nil {
|
||||
if err := kvsDeleteTxn(tx, idx, key); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -888,7 +976,7 @@ func (s *StateStore) KVSDelete(idx uint64, key string) error {
|
||||
|
||||
// kvsDeleteTxn is the inner method used to perform the actual deletion
|
||||
// of a key/value pair within an existing transaction.
|
||||
func (s *StateStore) kvsDeleteTxn(idx uint64, key string, tx *memdb.Txn) error {
|
||||
func kvsDeleteTxn(tx *memdb.Txn, idx uint64, key string) error {
|
||||
// Look up the entry in the state store
|
||||
entry, err := tx.First("kvs", "id", key)
|
||||
if err != nil {
|
||||
@ -931,7 +1019,7 @@ func (s *StateStore) KVSDeleteCAS(idx, cidx uint64, key string) (bool, error) {
|
||||
}
|
||||
|
||||
// Call the actual deletion if the above passed
|
||||
if err := s.kvsDeleteTxn(idx, key, tx); err != nil {
|
||||
if err := kvsDeleteTxn(tx, idx, key); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
@ -967,7 +1055,7 @@ func (s *StateStore) KVSSetCAS(idx uint64, entry *structs.DirEntry) (bool, error
|
||||
}
|
||||
|
||||
// If we made it this far, we should perform the set.
|
||||
return true, s.kvsSetTxn(idx, entry, tx)
|
||||
return true, kvsSetTxn(tx, idx, entry)
|
||||
}
|
||||
|
||||
// KVSDeleteTree is used to do a recursive delete on a key prefix
|
||||
@ -1011,7 +1099,7 @@ func (s *StateStore) SessionCreate(idx uint64, sess *structs.Session) error {
|
||||
defer tx.Abort()
|
||||
|
||||
// Call the session creation
|
||||
if err := s.sessionCreateTxn(idx, sess, tx); err != nil {
|
||||
if err := sessionCreateTxn(tx, idx, sess); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -1022,7 +1110,7 @@ func (s *StateStore) SessionCreate(idx uint64, sess *structs.Session) error {
|
||||
// sessionCreateTxn is the inner method used for creating session entries in
|
||||
// an open transaction. Any health checks registered with the session will be
|
||||
// checked for failing status. Returns any error encountered.
|
||||
func (s *StateStore) sessionCreateTxn(idx uint64, sess *structs.Session, tx *memdb.Txn) error {
|
||||
func sessionCreateTxn(tx *memdb.Txn, idx uint64, sess *structs.Session) error {
|
||||
// Check that we have a session ID
|
||||
if sess.ID == "" {
|
||||
return ErrMissingSessionID
|
||||
@ -1172,7 +1260,7 @@ func (s *StateStore) SessionDestroy(idx uint64, sessionID string) error {
|
||||
defer tx.Abort()
|
||||
|
||||
// Call the session deletion
|
||||
if err := s.sessionDestroyTxn(idx, sessionID, tx); err != nil {
|
||||
if err := sessionDestroyTxn(tx, idx, sessionID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -1182,7 +1270,7 @@ func (s *StateStore) SessionDestroy(idx uint64, sessionID string) error {
|
||||
|
||||
// sessionDestroyTxn is the inner method, which is used to do the actual
|
||||
// session deletion and handle session invalidation, watch triggers, etc.
|
||||
func (s *StateStore) sessionDestroyTxn(idx uint64, sessionID string, tx *memdb.Txn) error {
|
||||
func sessionDestroyTxn(tx *memdb.Txn, idx uint64, sessionID string) error {
|
||||
// Look up the session
|
||||
sess, err := tx.First("sessions", "id", sessionID)
|
||||
if err != nil {
|
||||
@ -1211,17 +1299,18 @@ func (s *StateStore) ACLSet(idx uint64, acl *structs.ACL) error {
|
||||
defer tx.Abort()
|
||||
|
||||
// Call set on the ACL
|
||||
if err := s.aclSetTxn(idx, acl, tx); err != nil {
|
||||
if err := aclSetTxn(tx, idx, acl); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx.Defer(func() { s.GetWatchManager("acls").Notify() })
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
// aclSetTxn is the inner method used to insert an ACL rule with the
|
||||
// proper indexes into the state store.
|
||||
func (s *StateStore) aclSetTxn(idx uint64, acl *structs.ACL, tx *memdb.Txn) error {
|
||||
func aclSetTxn(tx *memdb.Txn, idx uint64, acl *structs.ACL) error {
|
||||
// Check that the ID is set
|
||||
if acl.ID == "" {
|
||||
return ErrMissingACLID
|
||||
@ -1272,7 +1361,11 @@ func (s *StateStore) ACLGet(aclID string) (*structs.ACL, error) {
|
||||
func (s *StateStore) ACLList() (uint64, []*structs.ACL, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
return aclListTxn(tx)
|
||||
}
|
||||
|
||||
// aclListTxn is used to list out all of the ACLs in the state store.
|
||||
func aclListTxn(tx *memdb.Txn) (uint64, []*structs.ACL, error) {
|
||||
// Query all of the ACLs in the state store
|
||||
acls, err := tx.Get("acls", "id")
|
||||
if err != nil {
|
||||
@ -1301,17 +1394,18 @@ func (s *StateStore) ACLDelete(idx uint64, aclID string) error {
|
||||
defer tx.Abort()
|
||||
|
||||
// Call the ACL delete
|
||||
if err := s.aclDeleteTxn(idx, aclID, tx); err != nil {
|
||||
if err := aclDeleteTxn(tx, idx, aclID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx.Defer(func() { s.GetWatchManager("acls").Notify() })
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
// aclDeleteTxn is used to delete an ACL from the state store within
|
||||
// an existing transaction.
|
||||
func (s *StateStore) aclDeleteTxn(idx uint64, aclID string, tx *memdb.Txn) error {
|
||||
func aclDeleteTxn(tx *memdb.Txn, idx uint64, aclID string) error {
|
||||
// Look up the existing ACL
|
||||
acl, err := tx.First("acls", "id", aclID)
|
||||
if err != nil {
|
||||
@ -1330,3 +1424,22 @@ func (s *StateStore) aclDeleteTxn(idx uint64, aclID string, tx *memdb.Txn) error
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ACLRestore is used when restoring from a snapshot. For general inserts, use
|
||||
// ACLSet.
|
||||
func (s *StateStore) ACLRestore(acl *structs.ACL) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
if err := tx.Insert("acls", acl); err != nil {
|
||||
return fmt.Errorf("failed restoring acl: %s", err)
|
||||
}
|
||||
|
||||
if err := indexUpdateMaxTxn(tx, acl.ModifyIndex, "acls"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx.Defer(func() { s.GetWatchManager("acls").Notify() })
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
)
|
||||
@ -31,7 +32,7 @@ func testRegisterNode(t *testing.T, s *StateStore, idx uint64, nodeID string) {
|
||||
defer tx.Abort()
|
||||
n, err := tx.First("nodes", "id", nodeID)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err, n)
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if result, ok := n.(*structs.Node); !ok || result.Node != nodeID {
|
||||
t.Fatalf("bad node: %#v", result)
|
||||
@ -107,14 +108,33 @@ func testSetKey(t *testing.T, s *StateStore, idx uint64, key, value string) {
|
||||
|
||||
func TestStateStore_maxIndex(t *testing.T) {
|
||||
s := testStateStore(t)
|
||||
|
||||
testRegisterNode(t, s, 0, "foo")
|
||||
testRegisterNode(t, s, 1, "bar")
|
||||
testRegisterService(t, s, 2, "foo", "consul")
|
||||
|
||||
if max := s.maxIndex("nodes", "services"); max != 2 {
|
||||
t.Fatalf("bad max: %d", max)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateStore_indexUpdateMaxTxn(t *testing.T) {
|
||||
s := testStateStore(t)
|
||||
|
||||
testRegisterNode(t, s, 0, "foo")
|
||||
testRegisterNode(t, s, 1, "bar")
|
||||
|
||||
tx := s.db.Txn(true)
|
||||
if err := indexUpdateMaxTxn(tx, 3, "nodes"); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
if max := s.maxIndex("nodes"); max != 3 {
|
||||
t.Fatalf("bad max: %d", max)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateStore_EnsureNode(t *testing.T) {
|
||||
s := testStateStore(t)
|
||||
|
||||
@ -1415,7 +1435,7 @@ func TestStateStore_SessionCreate_GetSession(t *testing.T) {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if idx := s.maxIndex("sessions"); idx != 2 {
|
||||
t.Fatalf("bad index: %d", err)
|
||||
t.Fatalf("bad index: %s", err)
|
||||
}
|
||||
|
||||
// Retrieve the session again
|
||||
@ -1814,3 +1834,44 @@ func TestStateStore_ACLDelete(t *testing.T) {
|
||||
t.Fatalf("expected nil, got: %#v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateStore_ACL_Watches(t *testing.T) {
|
||||
s := testStateStore(t)
|
||||
ch := make(chan struct{})
|
||||
|
||||
s.GetWatchManager("acls").Start(ch)
|
||||
go func() {
|
||||
if err := s.ACLSet(1, &structs.ACL{ID: "acl1"}); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatalf("watch was not notified")
|
||||
}
|
||||
|
||||
s.GetWatchManager("acls").Start(ch)
|
||||
go func() {
|
||||
if err := s.ACLDelete(2, "acl1"); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatalf("watch was not notified")
|
||||
}
|
||||
|
||||
s.GetWatchManager("acls").Start(ch)
|
||||
go func() {
|
||||
if err := s.ACLRestore(&structs.ACL{ID: "acl1"}); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatalf("watch was not notified")
|
||||
}
|
||||
}
|
||||
|
35
consul/state/watch.go
Normal file
35
consul/state/watch.go
Normal file
@ -0,0 +1,35 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/go-memdb"
|
||||
)
|
||||
|
||||
type WatchManager interface {
|
||||
Start(notifyCh chan struct{})
|
||||
Stop(notifyCh chan struct{})
|
||||
Notify()
|
||||
}
|
||||
|
||||
type FullTableWatch struct {
|
||||
notify NotifyGroup
|
||||
}
|
||||
|
||||
func (w *FullTableWatch) Start(notifyCh chan struct{}) {
|
||||
w.notify.Wait(notifyCh)
|
||||
}
|
||||
|
||||
func (w *FullTableWatch) Stop(notifyCh chan struct{}) {
|
||||
w.notify.Clear(notifyCh)
|
||||
}
|
||||
|
||||
func (w *FullTableWatch) Notify() {
|
||||
w.notify.Notify()
|
||||
}
|
||||
|
||||
func newWatchManagers(schema *memdb.DBSchema) (map[string]WatchManager, error) {
|
||||
watches := make(map[string]WatchManager)
|
||||
for table, _ := range schema.Tables {
|
||||
watches[table] = &FullTableWatch{}
|
||||
}
|
||||
return watches, nil
|
||||
}
|
@ -24,7 +24,6 @@ const (
|
||||
dbTombstone = "tombstones"
|
||||
dbSessions = "sessions"
|
||||
dbSessionChecks = "sessionChecks"
|
||||
dbACLs = "acls"
|
||||
dbMaxMapSize32bit uint64 = 128 * 1024 * 1024 // 128MB maximum size
|
||||
dbMaxMapSize64bit uint64 = 32 * 1024 * 1024 * 1024 // 32GB maximum size
|
||||
dbMaxReaders uint = 4096 // 4K, default is 126
|
||||
@ -59,7 +58,6 @@ type StateStore struct {
|
||||
tombstoneTable *MDBTable
|
||||
sessionTable *MDBTable
|
||||
sessionCheckTable *MDBTable
|
||||
aclTable *MDBTable
|
||||
tables MDBTables
|
||||
watch map[*MDBTable]*NotifyGroup
|
||||
queryTables map[string]MDBTables
|
||||
@ -361,27 +359,9 @@ func (s *StateStore) initialize() error {
|
||||
},
|
||||
}
|
||||
|
||||
s.aclTable = &MDBTable{
|
||||
Name: dbACLs,
|
||||
Indexes: map[string]*MDBIndex{
|
||||
"id": &MDBIndex{
|
||||
Unique: true,
|
||||
Fields: []string{"ID"},
|
||||
},
|
||||
},
|
||||
Decoder: func(buf []byte) interface{} {
|
||||
out := new(structs.ACL)
|
||||
if err := structs.Decode(buf, out); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return out
|
||||
},
|
||||
}
|
||||
|
||||
// Store the set of tables
|
||||
s.tables = []*MDBTable{s.nodeTable, s.serviceTable, s.checkTable,
|
||||
s.kvsTable, s.tombstoneTable, s.sessionTable, s.sessionCheckTable,
|
||||
s.aclTable}
|
||||
s.kvsTable, s.tombstoneTable, s.sessionTable, s.sessionCheckTable}
|
||||
for _, table := range s.tables {
|
||||
table.Env = s.env
|
||||
table.Encoder = encoder
|
||||
@ -408,8 +388,6 @@ func (s *StateStore) initialize() error {
|
||||
"SessionGet": MDBTables{s.sessionTable},
|
||||
"SessionList": MDBTables{s.sessionTable},
|
||||
"NodeSessions": MDBTables{s.sessionTable},
|
||||
"ACLGet": MDBTables{s.aclTable},
|
||||
"ACLList": MDBTables{s.aclTable},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -1945,109 +1923,6 @@ func (s *StateStore) deleteLocks(index uint64, tx *MDBTxn,
|
||||
return nil
|
||||
}
|
||||
|
||||
// ACLSet is used to create or update an ACL entry
|
||||
func (s *StateStore) ACLSet(index uint64, acl *structs.ACL) error {
|
||||
// Check for an ID
|
||||
if acl.ID == "" {
|
||||
return fmt.Errorf("Missing ACL ID")
|
||||
}
|
||||
|
||||
// Start a new txn
|
||||
tx, err := s.tables.StartTxn(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Abort()
|
||||
|
||||
// Look for the existing node
|
||||
res, err := s.aclTable.GetTxn(tx, "id", acl.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch len(res) {
|
||||
case 0:
|
||||
acl.CreateIndex = index
|
||||
acl.ModifyIndex = index
|
||||
case 1:
|
||||
exist := res[0].(*structs.ACL)
|
||||
acl.CreateIndex = exist.CreateIndex
|
||||
acl.ModifyIndex = index
|
||||
default:
|
||||
panic(fmt.Errorf("Duplicate ACL definition. Internal error"))
|
||||
}
|
||||
|
||||
// Insert the ACL
|
||||
if err := s.aclTable.InsertTxn(tx, acl); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Trigger the update notifications
|
||||
if err := s.aclTable.SetLastIndexTxn(tx, index); err != nil {
|
||||
return err
|
||||
}
|
||||
tx.Defer(func() { s.watch[s.aclTable].Notify() })
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// ACLRestore is used to restore an ACL. It should only be used when
|
||||
// doing a restore, otherwise ACLSet should be used.
|
||||
func (s *StateStore) ACLRestore(acl *structs.ACL) error {
|
||||
// Start a new txn
|
||||
tx, err := s.aclTable.StartTxn(false, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Abort()
|
||||
|
||||
if err := s.aclTable.InsertTxn(tx, acl); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.aclTable.SetMaxLastIndexTxn(tx, acl.ModifyIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// ACLGet is used to get an ACL by ID
|
||||
func (s *StateStore) ACLGet(id string) (uint64, *structs.ACL, error) {
|
||||
idx, res, err := s.aclTable.Get("id", id)
|
||||
var d *structs.ACL
|
||||
if len(res) > 0 {
|
||||
d = res[0].(*structs.ACL)
|
||||
}
|
||||
return idx, d, err
|
||||
}
|
||||
|
||||
// ACLList is used to list all the acls
|
||||
func (s *StateStore) ACLList() (uint64, []*structs.ACL, error) {
|
||||
idx, res, err := s.aclTable.Get("id")
|
||||
out := make([]*structs.ACL, len(res))
|
||||
for i, raw := range res {
|
||||
out[i] = raw.(*structs.ACL)
|
||||
}
|
||||
return idx, out, err
|
||||
}
|
||||
|
||||
// ACLDelete is used to remove an ACL
|
||||
func (s *StateStore) ACLDelete(index uint64, id string) error {
|
||||
tx, err := s.tables.StartTxn(false)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("Failed to start txn: %v", err))
|
||||
}
|
||||
defer tx.Abort()
|
||||
|
||||
if n, err := s.aclTable.DeleteTxn(tx, "id", id); err != nil {
|
||||
return err
|
||||
} else if n > 0 {
|
||||
if err := s.aclTable.SetLastIndexTxn(tx, index); err != nil {
|
||||
return err
|
||||
}
|
||||
tx.Defer(func() { s.watch[s.aclTable].Notify() })
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// Snapshot is used to create a point in time snapshot
|
||||
func (s *StateStore) Snapshot() (*StateSnapshot, error) {
|
||||
// Begin a new txn on all tables
|
||||
@ -2128,13 +2003,3 @@ func (s *StateSnapshot) SessionList() ([]*structs.Session, error) {
|
||||
}
|
||||
return out, err
|
||||
}
|
||||
|
||||
// ACLList is used to list all of the ACLs
|
||||
func (s *StateSnapshot) ACLList() ([]*structs.ACL, error) {
|
||||
res, err := s.store.aclTable.GetTxn(s.tx, "id")
|
||||
out := make([]*structs.ACL, len(res))
|
||||
for i, raw := range res {
|
||||
out[i] = raw.(*structs.ACL)
|
||||
}
|
||||
return out, err
|
||||
}
|
||||
|
@ -762,24 +762,6 @@ func TestStoreSnapshot(t *testing.T) {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
a1 := &structs.ACL{
|
||||
ID: generateUUID(),
|
||||
Name: "User token",
|
||||
Type: structs.ACLTypeClient,
|
||||
}
|
||||
if err := store.ACLSet(21, a1); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
a2 := &structs.ACL{
|
||||
ID: generateUUID(),
|
||||
Name: "User token",
|
||||
Type: structs.ACLTypeClient,
|
||||
}
|
||||
if err := store.ACLSet(22, a2); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Take a snapshot
|
||||
snap, err := store.Snapshot()
|
||||
if err != nil {
|
||||
@ -884,15 +866,6 @@ func TestStoreSnapshot(t *testing.T) {
|
||||
t.Fatalf("Wrong number of sessions with TTL")
|
||||
}
|
||||
|
||||
// Check for an acl
|
||||
acls, err := snap.ACLList()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(acls) != 2 {
|
||||
t.Fatalf("missing acls")
|
||||
}
|
||||
|
||||
// Make some changes!
|
||||
if err := store.EnsureService(23, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8000}); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
@ -918,11 +891,6 @@ func TestStoreSnapshot(t *testing.T) {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Nuke an ACL
|
||||
if err := store.ACLDelete(29, a1.ID); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Check snapshot has old values
|
||||
nodes = snap.Nodes()
|
||||
if len(nodes) != 2 {
|
||||
@ -1003,15 +971,6 @@ func TestStoreSnapshot(t *testing.T) {
|
||||
if len(sessions) != 3 {
|
||||
t.Fatalf("missing sessions")
|
||||
}
|
||||
|
||||
// Check for an acl
|
||||
acls, err = snap.ACLList()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(acls) != 2 {
|
||||
t.Fatalf("missing acls")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureCheck(t *testing.T) {
|
||||
@ -2880,148 +2839,3 @@ func TestSessionInvalidate_KeyDelete(t *testing.T) {
|
||||
t.Fatalf("Bad: %v", expires)
|
||||
}
|
||||
}
|
||||
|
||||
func TestACLSet_Get(t *testing.T) {
|
||||
store, err := testStateStore()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
idx, out, err := store.ACLGet("1234")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if idx != 0 {
|
||||
t.Fatalf("bad: %v", idx)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatalf("bad: %v", out)
|
||||
}
|
||||
|
||||
a := &structs.ACL{
|
||||
ID: generateUUID(),
|
||||
Name: "User token",
|
||||
Type: structs.ACLTypeClient,
|
||||
Rules: "",
|
||||
}
|
||||
if err := store.ACLSet(50, a); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if a.CreateIndex != 50 {
|
||||
t.Fatalf("Bad: %v", a)
|
||||
}
|
||||
if a.ModifyIndex != 50 {
|
||||
t.Fatalf("Bad: %v", a)
|
||||
}
|
||||
if a.ID == "" {
|
||||
t.Fatalf("Bad: %v", a)
|
||||
}
|
||||
|
||||
idx, out, err = store.ACLGet(a.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if idx != 50 {
|
||||
t.Fatalf("bad: %v", idx)
|
||||
}
|
||||
if !reflect.DeepEqual(out, a) {
|
||||
t.Fatalf("bad: %v", out)
|
||||
}
|
||||
|
||||
// Update
|
||||
a.Rules = "foo bar baz"
|
||||
if err := store.ACLSet(52, a); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if a.CreateIndex != 50 {
|
||||
t.Fatalf("Bad: %v", a)
|
||||
}
|
||||
if a.ModifyIndex != 52 {
|
||||
t.Fatalf("Bad: %v", a)
|
||||
}
|
||||
|
||||
idx, out, err = store.ACLGet(a.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if idx != 52 {
|
||||
t.Fatalf("bad: %v", idx)
|
||||
}
|
||||
if !reflect.DeepEqual(out, a) {
|
||||
t.Fatalf("bad: %v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestACLDelete(t *testing.T) {
|
||||
store, err := testStateStore()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
a := &structs.ACL{
|
||||
ID: generateUUID(),
|
||||
Name: "User token",
|
||||
Type: structs.ACLTypeClient,
|
||||
Rules: "",
|
||||
}
|
||||
if err := store.ACLSet(50, a); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if err := store.ACLDelete(52, a.ID); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if err := store.ACLDelete(53, a.ID); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
idx, out, err := store.ACLGet(a.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if idx != 52 {
|
||||
t.Fatalf("bad: %v", idx)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatalf("bad: %v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestACLList(t *testing.T) {
|
||||
store, err := testStateStore()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
a1 := &structs.ACL{
|
||||
ID: generateUUID(),
|
||||
Name: "User token",
|
||||
Type: structs.ACLTypeClient,
|
||||
}
|
||||
if err := store.ACLSet(50, a1); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
a2 := &structs.ACL{
|
||||
ID: generateUUID(),
|
||||
Name: "User token",
|
||||
Type: structs.ACLTypeClient,
|
||||
}
|
||||
if err := store.ACLSet(51, a2); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
idx, out, err := store.ACLList()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if idx != 51 {
|
||||
t.Fatalf("bad: %v", idx)
|
||||
}
|
||||
if len(out) != 2 {
|
||||
t.Fatalf("bad: %v", out)
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user