diff --git a/agent/consul/state/acl.go b/agent/consul/state/acl.go index 55b5926607..02b8f03ea1 100644 --- a/agent/consul/state/acl.go +++ b/agent/consul/state/acl.go @@ -266,6 +266,55 @@ func (s *Store) resolveTokenPolicyLinks(tx *memdb.Txn, token *structs.ACLToken, return nil } +// fixupTokenPolicyLinks is to be used when retrieving tokens from memdb. The policy links could have gotten +// stale when a linked policy was deleted or renamed. This will correct them and generate a newly allocated +// token only when fixes are needed. If the policy links are still accurate then we just return the original +// token. +func (s *Store) fixupTokenPolicyLinks(tx *memdb.Txn, original *structs.ACLToken) (*structs.ACLToken, error) { + owned := false + token := original + + cloneToken := func(t *structs.ACLToken, copyNumLinks int) *structs.ACLToken { + clone := *t + clone.Policies = make([]structs.ACLTokenPolicyLink, copyNumLinks) + copy(clone.Policies, t.Policies[:copyNumLinks]) + return &clone + } + + for linkIndex, link := range original.Policies { + if link.ID == "" { + return nil, fmt.Errorf("Detected corrupted token within the state store - missing policy link ID") + } + + policy, err := s.getPolicyWithTxn(tx, nil, link.ID, "id") + + if err != nil { + return nil, err + } + + if policy == nil { + if !owned { + // clone the token as we cannot touch the original + token = cloneToken(original, linkIndex) + owned = true + } + // if already owned then we just don't append it. + } else if policy.Name != link.Name { + if !owned { + token = cloneToken(original, linkIndex) + owned = true + } + + // append the corrected policy + token.Policies = append(token.Policies, structs.ACLTokenPolicyLink{ID: link.ID, Name: policy.Name}) + } else if owned { + token.Policies = append(token.Policies, link) + } + } + + return token, nil +} + // ACLTokenSet is used to insert an ACL rule into the state store. func (s *Store) ACLTokenSet(idx uint64, token *structs.ACLToken, legacy bool) error { tx := s.db.Txn(true) @@ -446,8 +495,8 @@ func (s *Store) aclTokenGetTxn(tx *memdb.Txn, ws memdb.WatchSet, value, index st ws.Add(watchCh) if rawToken != nil { - token := rawToken.(*structs.ACLToken) - if err := s.resolveTokenPolicyLinks(tx, token, true); err != nil { + token, err := s.fixupTokenPolicyLinks(tx, rawToken.(*structs.ACLToken)) + if err != nil { return nil, err } return token, nil @@ -501,8 +550,9 @@ func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy strin var result structs.ACLTokens for raw := iter.Next(); raw != nil; raw = iter.Next() { - token := raw.(*structs.ACLToken) - if err := s.resolveTokenPolicyLinks(tx, token, true); err != nil { + token, err := s.fixupTokenPolicyLinks(tx, raw.(*structs.ACLToken)) + + if err != nil { return 0, nil, err } result = append(result, token) diff --git a/agent/consul/state/acl_test.go b/agent/consul/state/acl_test.go index 6328ef32e7..547c107ecd 100644 --- a/agent/consul/state/acl_test.go +++ b/agent/consul/state/acl_test.go @@ -801,6 +801,132 @@ func TestStateStore_ACLToken_List(t *testing.T) { } } +func TestStateStore_ACLToken_FixupPolicyLinks(t *testing.T) { + // This test wants to ensure a couple of things. + // + // 1. Doing a token list/get should never modify the data + // tracked by memdb + // 2. Token list/get operations should return an accurate set + // of policy links + t.Parallel() + s := testACLTokensStateStore(t) + + // the policy specific token + token := &structs.ACLToken{ + AccessorID: "47eea4da-bda1-48a6-901c-3e36d2d9262f", + SecretID: "548bdb8e-c0d6-477b-bcc4-67fb836e9e61", + Policies: []structs.ACLTokenPolicyLink{ + structs.ACLTokenPolicyLink{ + ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + }, + }, + } + + require.NoError(t, s.ACLTokenSet(2, token, false)) + + _, retrieved, err := s.ACLTokenGetByAccessor(nil, token.AccessorID) + require.NoError(t, err) + // pointer equality check these should be identical + require.True(t, token == retrieved) + require.Len(t, retrieved.Policies, 1) + require.Equal(t, "node-read", retrieved.Policies[0].Name) + + // rename the policy + renamed := &structs.ACLPolicy{ + ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + Name: "node-read-renamed", + Description: "Allows reading all node information", + Rules: `node_prefix "" { policy = "read" }`, + Syntax: acl.SyntaxCurrent, + } + renamed.SetHash(true) + require.NoError(t, s.ACLPolicySet(3, renamed)) + + // retrieve the token again + _, retrieved, err = s.ACLTokenGetByAccessor(nil, token.AccessorID) + require.NoError(t, err) + // pointer equality check these should be different if we cloned things appropriately + require.True(t, token != retrieved) + require.Len(t, retrieved.Policies, 1) + require.Equal(t, "node-read-renamed", retrieved.Policies[0].Name) + + // list tokens without stale links + _, tokens, err := s.ACLTokenList(nil, true, true, "") + require.NoError(t, err) + + found := false + for _, tok := range tokens { + if tok.AccessorID == token.AccessorID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, tok != token) + require.Len(t, tok.Policies, 1) + require.Equal(t, "node-read-renamed", tok.Policies[0].Name) + found = true + break + } + } + require.True(t, found) + + // batch get without stale links + _, tokens, err = s.ACLTokenBatchGet(nil, []string{token.AccessorID}) + require.NoError(t, err) + + found = false + for _, tok := range tokens { + if tok.AccessorID == token.AccessorID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, tok != token) + require.Len(t, tok.Policies, 1) + require.Equal(t, "node-read-renamed", tok.Policies[0].Name) + found = true + break + } + } + require.True(t, found) + + // delete the policy + require.NoError(t, s.ACLPolicyDeleteByID(4, "a0625e95-9b3e-42de-a8d6-ceef5b6f3286")) + + // retrieve the token again + _, retrieved, err = s.ACLTokenGetByAccessor(nil, token.AccessorID) + require.NoError(t, err) + // pointer equality check these should be different if we cloned things appropriately + require.True(t, token != retrieved) + require.Len(t, retrieved.Policies, 0) + + // list tokens without stale links + _, tokens, err = s.ACLTokenList(nil, true, true, "") + require.NoError(t, err) + + found = false + for _, tok := range tokens { + if tok.AccessorID == token.AccessorID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, tok != token) + require.Len(t, tok.Policies, 0) + found = true + break + } + } + require.True(t, found) + + // batch get without stale links + _, tokens, err = s.ACLTokenBatchGet(nil, []string{token.AccessorID}) + require.NoError(t, err) + + found = false + for _, tok := range tokens { + if tok.AccessorID == token.AccessorID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, tok != token) + require.Len(t, tok.Policies, 0) + found = true + break + } + } + require.True(t, found) +} + func TestStateStore_ACLToken_Delete(t *testing.T) { t.Parallel() @@ -1344,6 +1470,29 @@ func TestStateStore_ACLPolicy_Delete(t *testing.T) { func TestStateStore_ACLTokens_Snapshot_Restore(t *testing.T) { s := testStateStore(t) + policies := structs.ACLPolicies{ + &structs.ACLPolicy{ + ID: "ca1fc52c-3676-4050-82ed-ca223e38b2c9", + Name: "policy1", + Description: "policy1", + Rules: `node_prefix "" { policy = "read" }`, + Syntax: acl.SyntaxCurrent, + }, + &structs.ACLPolicy{ + ID: "7b70fa0f-58cd-412d-93c3-a0f17bb19a3e", + Name: "policy2", + Description: "policy2", + Rules: `acl = "read"`, + Syntax: acl.SyntaxCurrent, + }, + } + + for _, policy := range policies { + policy.SetHash(true) + } + + require.NoError(t, s.ACLPolicyBatchSet(2, policies)) + tokens := structs.ACLTokens{ &structs.ACLToken{ AccessorID: "68016c3d-835b-450c-a6f9-75db9ba740be", @@ -1411,6 +1560,9 @@ func TestStateStore_ACLTokens_Snapshot_Restore(t *testing.T) { } restore.Commit() + // need to ensure we have the policies or else the links will be removed + require.NoError(t, s.ACLPolicyBatchSet(2, policies)) + // Read the restored ACLs back out and verify that they match. idx, res, err := s.ACLTokenList(nil, true, true, "") require.NoError(t, err)