Adds fine-grained watch support to ACL endpoints.

This commit is contained in:
James Phillips 2017-01-24 00:00:06 -08:00
parent eaa8fde298
commit ec90404df0
No known key found for this signature in database
GPG Key ID: 77183E682AC5FC11
10 changed files with 55 additions and 37 deletions

View File

@ -62,7 +62,7 @@ func (s *Server) aclLocalFault(id string) (string, string, error) {
// Query the state store.
state := s.fsm.State()
_, acl, err := state.ACLGet(id)
_, acl, err := state.ACLGet(nil, id)
if err != nil {
return "", "", err
}

View File

@ -7,6 +7,7 @@ import (
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/consul/structs"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/go-uuid"
)
@ -108,7 +109,7 @@ func (a *ACL) Apply(args *structs.ACLRequest, reply *string) error {
return err
}
_, acl, err := state.ACLGet(args.ACL.ID)
_, acl, err := state.ACLGet(nil, args.ACL.ID)
if err != nil {
a.srv.logger.Printf("[ERR] consul.acl: ACL lookup failed: %v", err)
return err
@ -146,11 +147,10 @@ func (a *ACL) Get(args *structs.ACLSpecificRequest,
// Get the local state
state := a.srv.fsm.State()
return a.srv.blockingRPC(&args.QueryOptions,
return a.srv.blockingQuery(&args.QueryOptions,
&reply.QueryMeta,
state.GetQueryWatch("ACLGet"),
func() error {
index, acl, err := state.ACLGet(args.ACL)
func(ws memdb.WatchSet) error {
index, acl, err := state.ACLGet(ws, args.ACL)
if err != nil {
return err
}
@ -226,11 +226,10 @@ func (a *ACL) List(args *structs.DCSpecificRequest,
// Get the local state
state := a.srv.fsm.State()
return a.srv.blockingRPC(&args.QueryOptions,
return a.srv.blockingQuery(&args.QueryOptions,
&reply.QueryMeta,
state.GetQueryWatch("ACLList"),
func() error {
index, acls, err := state.ACLList()
func(ws memdb.WatchSet) error {
index, acls, err := state.ACLList(ws)
if err != nil {
return err
}

View File

@ -41,7 +41,7 @@ func TestACLEndpoint_Apply(t *testing.T) {
// Verify
state := s1.fsm.State()
_, s, err := state.ACLGet(out)
_, s, err := state.ACLGet(nil, out)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -63,7 +63,7 @@ func TestACLEndpoint_Apply(t *testing.T) {
}
// Verify
_, s, err = state.ACLGet(id)
_, s, err = state.ACLGet(nil, id)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -182,7 +182,7 @@ func TestACLEndpoint_Apply_CustomID(t *testing.T) {
// Verify
state := s1.fsm.State()
_, s, err := state.ACLGet(out)
_, s, err := state.ACLGet(nil, out)
if err != nil {
t.Fatalf("err: %v", err)
}

View File

@ -139,7 +139,7 @@ func reconcileACLs(local, remote structs.ACLs, lastRemoteIndex uint64) structs.A
// FetchLocalACLs returns the ACLs in the local state store.
func (s *Server) fetchLocalACLs() (structs.ACLs, error) {
_, local, err := s.fsm.State().ACLList()
_, local, err := s.fsm.State().ACLList(nil)
if err != nil {
return nil, err
}

View File

@ -364,11 +364,11 @@ func TestACLReplication(t *testing.T) {
}
checkSame := func() (bool, error) {
index, remote, err := s1.fsm.State().ACLList()
index, remote, err := s1.fsm.State().ACLList(nil)
if err != nil {
return false, err
}
_, local, err := s2.fsm.State().ACLList()
_, local, err := s2.fsm.State().ACLList(nil)
if err != nil {
return false, err
}

View File

@ -688,14 +688,14 @@ func TestACL_Replication(t *testing.T) {
// Wait for replication to occur.
testutil.WaitForResult(func() (bool, error) {
_, acl, err := s2.fsm.State().ACLGet(id)
_, acl, err := s2.fsm.State().ACLGet(nil, id)
if err != nil {
return false, err
}
if acl == nil {
return false, nil
}
_, acl, err = s3.fsm.State().ACLGet(id)
_, acl, err = s3.fsm.State().ACLGet(nil, id)
if err != nil {
return false, err
}

View File

@ -512,7 +512,7 @@ func TestFSM_SnapshotRestore(t *testing.T) {
}
// Verify ACL is restored
_, a, err := fsm2.state.ACLGet(acl.ID)
_, a, err := fsm2.state.ACLGet(nil, acl.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -1053,7 +1053,7 @@ func TestFSM_ACL_Set_Delete(t *testing.T) {
// Get the ACL
id := resp.(string)
_, acl, err := fsm.state.ACLGet(id)
_, acl, err := fsm.state.ACLGet(nil, id)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -1089,7 +1089,7 @@ func TestFSM_ACL_Set_Delete(t *testing.T) {
t.Fatalf("resp: %v", resp)
}
_, acl, err = fsm.state.ACLGet(id)
_, acl, err = fsm.state.ACLGet(nil, id)
if err != nil {
t.Fatalf("err: %v", err)
}

View File

@ -185,7 +185,7 @@ func (s *Server) initializeACL() error {
// Look for the anonymous token
state := s.fsm.State()
_, acl, err := state.ACLGet(anonymousToken)
_, acl, err := state.ACLGet(nil, anonymousToken)
if err != nil {
return fmt.Errorf("failed to get anonymous token: %v", err)
}
@ -214,7 +214,7 @@ func (s *Server) initializeACL() error {
}
// Look for the master token
_, acl, err = state.ACLGet(master)
_, acl, err = state.ACLGet(nil, master)
if err != nil {
return fmt.Errorf("failed to get master token: %v", err)
}

View File

@ -80,18 +80,20 @@ func (s *StateStore) aclSetTxn(tx *memdb.Txn, idx uint64, acl *structs.ACL) erro
}
// ACLGet is used to look up an existing ACL by ID.
func (s *StateStore) ACLGet(aclID string) (uint64, *structs.ACL, error) {
func (s *StateStore) ACLGet(ws memdb.WatchSet, aclID string) (uint64, *structs.ACL, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the table index.
idx := maxIndexTxn(tx, s.getWatchTables("ACLGet")...)
idx := maxIndexTxn(tx, "acls")
// Query for the existing ACL
acl, err := tx.First("acls", "id", aclID)
watchCh, acl, err := tx.FirstWatch("acls", "id", aclID)
if err != nil {
return 0, nil, fmt.Errorf("failed acl lookup: %s", err)
}
ws.Add(watchCh)
if acl != nil {
return idx, acl.(*structs.ACL), nil
}
@ -99,15 +101,15 @@ func (s *StateStore) ACLGet(aclID string) (uint64, *structs.ACL, error) {
}
// ACLList is used to list out all of the ACLs in the state store.
func (s *StateStore) ACLList() (uint64, structs.ACLs, error) {
func (s *StateStore) ACLList(ws memdb.WatchSet) (uint64, structs.ACLs, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the table index.
idx := maxIndexTxn(tx, s.getWatchTables("ACLList")...)
idx := maxIndexTxn(tx, "acls")
// Return the ACLs.
acls, err := s.aclListTxn(tx)
acls, err := s.aclListTxn(tx, ws)
if err != nil {
return 0, nil, fmt.Errorf("failed acl lookup: %s", err)
}
@ -116,16 +118,17 @@ func (s *StateStore) ACLList() (uint64, structs.ACLs, error) {
// aclListTxn is used to list out all of the ACLs in the state store. This is a
// function vs. a method so it can be called from the snapshotter.
func (s *StateStore) aclListTxn(tx *memdb.Txn) (structs.ACLs, error) {
func (s *StateStore) aclListTxn(tx *memdb.Txn, ws memdb.WatchSet) (structs.ACLs, error) {
// Query all of the ACLs in the state store
acls, err := tx.Get("acls", "id")
iter, err := tx.Get("acls", "id")
if err != nil {
return nil, fmt.Errorf("failed acl lookup: %s", err)
}
ws.Add(iter.WatchCh())
// Go over all of the ACLs and build the response
var result structs.ACLs
for acl := acls.Next(); acl != nil; acl = acls.Next() {
for acl := iter.Next(); acl != nil; acl = iter.Next() {
a := acl.(*structs.ACL)
result = append(result, a)
}

View File

@ -5,13 +5,15 @@ import (
"testing"
"github.com/hashicorp/consul/consul/structs"
"github.com/hashicorp/go-memdb"
)
func TestStateStore_ACLSet_ACLGet(t *testing.T) {
s := testStateStore(t)
// Querying ACLs with no results returns nil
idx, res, err := s.ACLGet("nope")
ws := memdb.NewWatchSet()
idx, res, err := s.ACLGet(ws, "nope")
if idx != 0 || res != nil || err != nil {
t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err)
}
@ -20,6 +22,9 @@ func TestStateStore_ACLSet_ACLGet(t *testing.T) {
if err := s.ACLSet(1, &structs.ACL{}); err == nil {
t.Fatalf("expected %#v, got: %#v", ErrMissingACLID, err)
}
if watchFired(ws) {
t.Fatalf("bad")
}
// Index is not updated if nothing is saved
if idx := s.maxIndex("acls"); idx != 0 {
@ -36,6 +41,9 @@ func TestStateStore_ACLSet_ACLGet(t *testing.T) {
if err := s.ACLSet(1, acl); err != nil {
t.Fatalf("err: %s", err)
}
if !watchFired(ws) {
t.Fatalf("bad")
}
// Check that the index was updated
if idx := s.maxIndex("acls"); idx != 1 {
@ -43,7 +51,8 @@ func TestStateStore_ACLSet_ACLGet(t *testing.T) {
}
// Retrieve the ACL again
idx, result, err := s.ACLGet("acl1")
ws = memdb.NewWatchSet()
idx, result, err := s.ACLGet(ws, "acl1")
if err != nil {
t.Fatalf("err: %s", err)
}
@ -76,6 +85,9 @@ func TestStateStore_ACLSet_ACLGet(t *testing.T) {
if err := s.ACLSet(2, acl); err != nil {
t.Fatalf("err: %s", err)
}
if !watchFired(ws) {
t.Fatalf("bad")
}
// Index was updated
if idx := s.maxIndex("acls"); idx != 2 {
@ -102,7 +114,8 @@ func TestStateStore_ACLList(t *testing.T) {
s := testStateStore(t)
// Listing when no ACLs exist returns nil
idx, res, err := s.ACLList()
ws := memdb.NewWatchSet()
idx, res, err := s.ACLList(ws)
if idx != 0 || res != nil || err != nil {
t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err)
}
@ -133,9 +146,12 @@ func TestStateStore_ACLList(t *testing.T) {
t.Fatalf("err: %s", err)
}
}
if !watchFired(ws) {
t.Fatalf("bad")
}
// Query the ACLs
idx, res, err = s.ACLList()
idx, res, err = s.ACLList(nil)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -255,7 +271,7 @@ func TestStateStore_ACL_Snapshot_Restore(t *testing.T) {
restore.Commit()
// Read the restored ACLs back out and verify that they match.
idx, res, err := s.ACLList()
idx, res, err := s.ACLList(nil)
if err != nil {
t.Fatalf("err: %s", err)
}