mirror of
https://github.com/status-im/consul.git
synced 2025-02-23 02:48:19 +00:00
Adds fine-grained watch support to ACL endpoints.
This commit is contained in:
parent
eaa8fde298
commit
ec90404df0
@ -62,7 +62,7 @@ func (s *Server) aclLocalFault(id string) (string, string, error) {
|
||||
|
||||
// Query the state store.
|
||||
state := s.fsm.State()
|
||||
_, acl, err := state.ACLGet(id)
|
||||
_, acl, err := state.ACLGet(nil, id)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/go-memdb"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
)
|
||||
|
||||
@ -108,7 +109,7 @@ func (a *ACL) Apply(args *structs.ACLRequest, reply *string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
_, acl, err := state.ACLGet(args.ACL.ID)
|
||||
_, acl, err := state.ACLGet(nil, args.ACL.ID)
|
||||
if err != nil {
|
||||
a.srv.logger.Printf("[ERR] consul.acl: ACL lookup failed: %v", err)
|
||||
return err
|
||||
@ -146,11 +147,10 @@ func (a *ACL) Get(args *structs.ACLSpecificRequest,
|
||||
|
||||
// Get the local state
|
||||
state := a.srv.fsm.State()
|
||||
return a.srv.blockingRPC(&args.QueryOptions,
|
||||
return a.srv.blockingQuery(&args.QueryOptions,
|
||||
&reply.QueryMeta,
|
||||
state.GetQueryWatch("ACLGet"),
|
||||
func() error {
|
||||
index, acl, err := state.ACLGet(args.ACL)
|
||||
func(ws memdb.WatchSet) error {
|
||||
index, acl, err := state.ACLGet(ws, args.ACL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -226,11 +226,10 @@ func (a *ACL) List(args *structs.DCSpecificRequest,
|
||||
|
||||
// Get the local state
|
||||
state := a.srv.fsm.State()
|
||||
return a.srv.blockingRPC(&args.QueryOptions,
|
||||
return a.srv.blockingQuery(&args.QueryOptions,
|
||||
&reply.QueryMeta,
|
||||
state.GetQueryWatch("ACLList"),
|
||||
func() error {
|
||||
index, acls, err := state.ACLList()
|
||||
func(ws memdb.WatchSet) error {
|
||||
index, acls, err := state.ACLList(ws)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -41,7 +41,7 @@ func TestACLEndpoint_Apply(t *testing.T) {
|
||||
|
||||
// Verify
|
||||
state := s1.fsm.State()
|
||||
_, s, err := state.ACLGet(out)
|
||||
_, s, err := state.ACLGet(nil, out)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
@ -63,7 +63,7 @@ func TestACLEndpoint_Apply(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify
|
||||
_, s, err = state.ACLGet(id)
|
||||
_, s, err = state.ACLGet(nil, id)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
@ -182,7 +182,7 @@ func TestACLEndpoint_Apply_CustomID(t *testing.T) {
|
||||
|
||||
// Verify
|
||||
state := s1.fsm.State()
|
||||
_, s, err := state.ACLGet(out)
|
||||
_, s, err := state.ACLGet(nil, out)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -139,7 +139,7 @@ func reconcileACLs(local, remote structs.ACLs, lastRemoteIndex uint64) structs.A
|
||||
|
||||
// FetchLocalACLs returns the ACLs in the local state store.
|
||||
func (s *Server) fetchLocalACLs() (structs.ACLs, error) {
|
||||
_, local, err := s.fsm.State().ACLList()
|
||||
_, local, err := s.fsm.State().ACLList(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -364,11 +364,11 @@ func TestACLReplication(t *testing.T) {
|
||||
}
|
||||
|
||||
checkSame := func() (bool, error) {
|
||||
index, remote, err := s1.fsm.State().ACLList()
|
||||
index, remote, err := s1.fsm.State().ACLList(nil)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
_, local, err := s2.fsm.State().ACLList()
|
||||
_, local, err := s2.fsm.State().ACLList(nil)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -688,14 +688,14 @@ func TestACL_Replication(t *testing.T) {
|
||||
|
||||
// Wait for replication to occur.
|
||||
testutil.WaitForResult(func() (bool, error) {
|
||||
_, acl, err := s2.fsm.State().ACLGet(id)
|
||||
_, acl, err := s2.fsm.State().ACLGet(nil, id)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if acl == nil {
|
||||
return false, nil
|
||||
}
|
||||
_, acl, err = s3.fsm.State().ACLGet(id)
|
||||
_, acl, err = s3.fsm.State().ACLGet(nil, id)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -512,7 +512,7 @@ func TestFSM_SnapshotRestore(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify ACL is restored
|
||||
_, a, err := fsm2.state.ACLGet(acl.ID)
|
||||
_, a, err := fsm2.state.ACLGet(nil, acl.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
@ -1053,7 +1053,7 @@ func TestFSM_ACL_Set_Delete(t *testing.T) {
|
||||
|
||||
// Get the ACL
|
||||
id := resp.(string)
|
||||
_, acl, err := fsm.state.ACLGet(id)
|
||||
_, acl, err := fsm.state.ACLGet(nil, id)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
@ -1089,7 +1089,7 @@ func TestFSM_ACL_Set_Delete(t *testing.T) {
|
||||
t.Fatalf("resp: %v", resp)
|
||||
}
|
||||
|
||||
_, acl, err = fsm.state.ACLGet(id)
|
||||
_, acl, err = fsm.state.ACLGet(nil, id)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -185,7 +185,7 @@ func (s *Server) initializeACL() error {
|
||||
|
||||
// Look for the anonymous token
|
||||
state := s.fsm.State()
|
||||
_, acl, err := state.ACLGet(anonymousToken)
|
||||
_, acl, err := state.ACLGet(nil, anonymousToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get anonymous token: %v", err)
|
||||
}
|
||||
@ -214,7 +214,7 @@ func (s *Server) initializeACL() error {
|
||||
}
|
||||
|
||||
// Look for the master token
|
||||
_, acl, err = state.ACLGet(master)
|
||||
_, acl, err = state.ACLGet(nil, master)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get master token: %v", err)
|
||||
}
|
||||
|
@ -80,18 +80,20 @@ func (s *StateStore) aclSetTxn(tx *memdb.Txn, idx uint64, acl *structs.ACL) erro
|
||||
}
|
||||
|
||||
// ACLGet is used to look up an existing ACL by ID.
|
||||
func (s *StateStore) ACLGet(aclID string) (uint64, *structs.ACL, error) {
|
||||
func (s *StateStore) ACLGet(ws memdb.WatchSet, aclID string) (uint64, *structs.ACL, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the table index.
|
||||
idx := maxIndexTxn(tx, s.getWatchTables("ACLGet")...)
|
||||
idx := maxIndexTxn(tx, "acls")
|
||||
|
||||
// Query for the existing ACL
|
||||
acl, err := tx.First("acls", "id", aclID)
|
||||
watchCh, acl, err := tx.FirstWatch("acls", "id", aclID)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed acl lookup: %s", err)
|
||||
}
|
||||
ws.Add(watchCh)
|
||||
|
||||
if acl != nil {
|
||||
return idx, acl.(*structs.ACL), nil
|
||||
}
|
||||
@ -99,15 +101,15 @@ func (s *StateStore) ACLGet(aclID string) (uint64, *structs.ACL, error) {
|
||||
}
|
||||
|
||||
// ACLList is used to list out all of the ACLs in the state store.
|
||||
func (s *StateStore) ACLList() (uint64, structs.ACLs, error) {
|
||||
func (s *StateStore) ACLList(ws memdb.WatchSet) (uint64, structs.ACLs, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the table index.
|
||||
idx := maxIndexTxn(tx, s.getWatchTables("ACLList")...)
|
||||
idx := maxIndexTxn(tx, "acls")
|
||||
|
||||
// Return the ACLs.
|
||||
acls, err := s.aclListTxn(tx)
|
||||
acls, err := s.aclListTxn(tx, ws)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed acl lookup: %s", err)
|
||||
}
|
||||
@ -116,16 +118,17 @@ func (s *StateStore) ACLList() (uint64, structs.ACLs, error) {
|
||||
|
||||
// aclListTxn is used to list out all of the ACLs in the state store. This is a
|
||||
// function vs. a method so it can be called from the snapshotter.
|
||||
func (s *StateStore) aclListTxn(tx *memdb.Txn) (structs.ACLs, error) {
|
||||
func (s *StateStore) aclListTxn(tx *memdb.Txn, ws memdb.WatchSet) (structs.ACLs, error) {
|
||||
// Query all of the ACLs in the state store
|
||||
acls, err := tx.Get("acls", "id")
|
||||
iter, err := tx.Get("acls", "id")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed acl lookup: %s", err)
|
||||
}
|
||||
ws.Add(iter.WatchCh())
|
||||
|
||||
// Go over all of the ACLs and build the response
|
||||
var result structs.ACLs
|
||||
for acl := acls.Next(); acl != nil; acl = acls.Next() {
|
||||
for acl := iter.Next(); acl != nil; acl = iter.Next() {
|
||||
a := acl.(*structs.ACL)
|
||||
result = append(result, a)
|
||||
}
|
||||
|
@ -5,13 +5,15 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/go-memdb"
|
||||
)
|
||||
|
||||
func TestStateStore_ACLSet_ACLGet(t *testing.T) {
|
||||
s := testStateStore(t)
|
||||
|
||||
// Querying ACLs with no results returns nil
|
||||
idx, res, err := s.ACLGet("nope")
|
||||
ws := memdb.NewWatchSet()
|
||||
idx, res, err := s.ACLGet(ws, "nope")
|
||||
if idx != 0 || res != nil || err != nil {
|
||||
t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err)
|
||||
}
|
||||
@ -20,6 +22,9 @@ func TestStateStore_ACLSet_ACLGet(t *testing.T) {
|
||||
if err := s.ACLSet(1, &structs.ACL{}); err == nil {
|
||||
t.Fatalf("expected %#v, got: %#v", ErrMissingACLID, err)
|
||||
}
|
||||
if watchFired(ws) {
|
||||
t.Fatalf("bad")
|
||||
}
|
||||
|
||||
// Index is not updated if nothing is saved
|
||||
if idx := s.maxIndex("acls"); idx != 0 {
|
||||
@ -36,6 +41,9 @@ func TestStateStore_ACLSet_ACLGet(t *testing.T) {
|
||||
if err := s.ACLSet(1, acl); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if !watchFired(ws) {
|
||||
t.Fatalf("bad")
|
||||
}
|
||||
|
||||
// Check that the index was updated
|
||||
if idx := s.maxIndex("acls"); idx != 1 {
|
||||
@ -43,7 +51,8 @@ func TestStateStore_ACLSet_ACLGet(t *testing.T) {
|
||||
}
|
||||
|
||||
// Retrieve the ACL again
|
||||
idx, result, err := s.ACLGet("acl1")
|
||||
ws = memdb.NewWatchSet()
|
||||
idx, result, err := s.ACLGet(ws, "acl1")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@ -76,6 +85,9 @@ func TestStateStore_ACLSet_ACLGet(t *testing.T) {
|
||||
if err := s.ACLSet(2, acl); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if !watchFired(ws) {
|
||||
t.Fatalf("bad")
|
||||
}
|
||||
|
||||
// Index was updated
|
||||
if idx := s.maxIndex("acls"); idx != 2 {
|
||||
@ -102,7 +114,8 @@ func TestStateStore_ACLList(t *testing.T) {
|
||||
s := testStateStore(t)
|
||||
|
||||
// Listing when no ACLs exist returns nil
|
||||
idx, res, err := s.ACLList()
|
||||
ws := memdb.NewWatchSet()
|
||||
idx, res, err := s.ACLList(ws)
|
||||
if idx != 0 || res != nil || err != nil {
|
||||
t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err)
|
||||
}
|
||||
@ -133,9 +146,12 @@ func TestStateStore_ACLList(t *testing.T) {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
if !watchFired(ws) {
|
||||
t.Fatalf("bad")
|
||||
}
|
||||
|
||||
// Query the ACLs
|
||||
idx, res, err = s.ACLList()
|
||||
idx, res, err = s.ACLList(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@ -255,7 +271,7 @@ func TestStateStore_ACL_Snapshot_Restore(t *testing.T) {
|
||||
restore.Commit()
|
||||
|
||||
// Read the restored ACLs back out and verify that they match.
|
||||
idx, res, err := s.ACLList()
|
||||
idx, res, err := s.ACLList(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user