diff --git a/agent/consul/state/acl.go b/agent/consul/state/acl.go index 003a6230c1..c1bca5f6aa 100644 --- a/agent/consul/state/acl.go +++ b/agent/consul/state/acl.go @@ -163,7 +163,7 @@ func (s *TokenExpirationIndex) FromArgs(args ...interface{}) ([]byte, error) { // ACLTokens is used when saving a snapshot func (s *Snapshot) ACLTokens() (memdb.ResultIterator, error) { - iter, err := s.tx.Get("acl-tokens", "id") + iter, err := s.tx.Get(tableACLTokens, "id") if err != nil { return nil, err } @@ -772,7 +772,7 @@ func (s *Store) ACLTokenBatchGet(ws memdb.WatchSet, accessors []string) (uint64, } } - idx := maxIndexTxn(tx, "acl-tokens") + idx := maxIndexTxn(tx, tableACLTokens) return idx, tokens, nil } @@ -884,7 +884,7 @@ func (s *Store) ACLTokenListUpgradeable(max int) (structs.ACLTokens, <-chan stru tx := s.db.Txn(false) defer tx.Abort() - iter, err := tx.Get("acl-tokens", "needs-upgrade", true) + iter, err := tx.Get(tableACLTokens, "needs-upgrade", true) if err != nil { return nil, nil, fmt.Errorf("failed acl token listing: %v", err) } @@ -906,7 +906,7 @@ func (s *Store) ACLTokenMinExpirationTime(local bool) (time.Time, error) { tx := s.db.Txn(false) defer tx.Abort() - item, err := tx.First("acl-tokens", s.expiresIndexName(local)) + item, err := tx.First(tableACLTokens, s.expiresIndexName(local)) if err != nil { return time.Time{}, fmt.Errorf("failed acl token listing: %v", err) } @@ -926,7 +926,7 @@ func (s *Store) ACLTokenListExpired(local bool, asOf time.Time, max int) (struct tx := s.db.Txn(false) defer tx.Abort() - iter, err := tx.Get("acl-tokens", s.expiresIndexName(local)) + iter, err := tx.Get(tableACLTokens, s.expiresIndexName(local)) if err != nil { return nil, nil, fmt.Errorf("failed acl token listing: %v", err) } diff --git a/agent/consul/state/acl_events.go b/agent/consul/state/acl_events.go index 4c49711c4c..3d219f1fa5 100644 --- a/agent/consul/state/acl_events.go +++ b/agent/consul/state/acl_events.go @@ -16,7 +16,7 @@ func aclChangeUnsubscribeEvent(tx ReadTxn, changes Changes) ([]stream.Event, err for _, change := range changes.Changes { switch change.Table { - case "acl-tokens": + case tableACLTokens: token := changeObject(change).(*structs.ACLToken) secretIDs = append(secretIDs, token.SecretID) diff --git a/agent/consul/state/acl_oss.go b/agent/consul/state/acl_oss.go index 60b9b43585..b480d9f589 100644 --- a/agent/consul/state/acl_oss.go +++ b/agent/consul/state/acl_oss.go @@ -57,12 +57,12 @@ func (s *Store) ACLPolicyUpsertValidateEnterprise(*structs.ACLPolicy, *structs.A func aclTokenInsert(tx WriteTxn, token *structs.ACLToken) error { // insert the token into memdb - if err := tx.Insert("acl-tokens", token); err != nil { + if err := tx.Insert(tableACLTokens, token); err != nil { return fmt.Errorf("failed inserting acl token: %v", err) } // update the overall acl-tokens index - if err := indexUpdateMaxTxn(tx, token.ModifyIndex, "acl-tokens"); err != nil { + if err := indexUpdateMaxTxn(tx, token.ModifyIndex, tableACLTokens); err != nil { return fmt.Errorf("failed updating acl tokens index: %v", err) } @@ -70,48 +70,48 @@ func aclTokenInsert(tx WriteTxn, token *structs.ACLToken) error { } func aclTokenGetFromIndex(tx ReadTxn, id string, index string, entMeta *structs.EnterpriseMeta) (<-chan struct{}, interface{}, error) { - return tx.FirstWatch("acl-tokens", index, id) + return tx.FirstWatch(tableACLTokens, index, id) } func aclTokenListAll(tx ReadTxn, _ *structs.EnterpriseMeta) (memdb.ResultIterator, error) { - return tx.Get("acl-tokens", "id") + return tx.Get(tableACLTokens, "id") } func aclTokenListLocal(tx ReadTxn, _ *structs.EnterpriseMeta) (memdb.ResultIterator, error) { - return tx.Get("acl-tokens", "local", true) + return tx.Get(tableACLTokens, "local", true) } func aclTokenListGlobal(tx ReadTxn, _ *structs.EnterpriseMeta) (memdb.ResultIterator, error) { - return tx.Get("acl-tokens", "local", false) + return tx.Get(tableACLTokens, "local", false) } func aclTokenListByPolicy(tx ReadTxn, policy string, _ *structs.EnterpriseMeta) (memdb.ResultIterator, error) { - return tx.Get("acl-tokens", "policies", policy) + return tx.Get(tableACLTokens, "policies", policy) } func aclTokenListByRole(tx ReadTxn, role string, _ *structs.EnterpriseMeta) (memdb.ResultIterator, error) { - return tx.Get("acl-tokens", "roles", role) + return tx.Get(tableACLTokens, "roles", role) } func aclTokenListByAuthMethod(tx ReadTxn, authMethod string, _, _ *structs.EnterpriseMeta) (memdb.ResultIterator, error) { - return tx.Get("acl-tokens", "authmethod", authMethod) + return tx.Get(tableACLTokens, "authmethod", authMethod) } func aclTokenDeleteWithToken(tx WriteTxn, token *structs.ACLToken, idx uint64) error { // remove the token - if err := tx.Delete("acl-tokens", token); err != nil { + if err := tx.Delete(tableACLTokens, token); err != nil { return fmt.Errorf("failed deleting acl token: %v", err) } // update the overall acl-tokens index - if err := indexUpdateMaxTxn(tx, idx, "acl-tokens"); err != nil { + if err := indexUpdateMaxTxn(tx, idx, tableACLTokens); err != nil { return fmt.Errorf("failed updating acl tokens index: %v", err) } return nil } func aclTokenMaxIndex(tx ReadTxn, _ *structs.ACLToken, entMeta *structs.EnterpriseMeta) uint64 { - return maxIndexTxn(tx, "acl-tokens") + return maxIndexTxn(tx, tableACLTokens) } func aclTokenUpsertValidateEnterprise(tx WriteTxn, token *structs.ACLToken, existing *structs.ACLToken) error { diff --git a/agent/consul/state/acl_schema.go b/agent/consul/state/acl_schema.go index 5ec9eb6643..254b2d5fd7 100644 --- a/agent/consul/state/acl_schema.go +++ b/agent/consul/state/acl_schema.go @@ -42,9 +42,9 @@ func tokensTableSchema() *memdb.TableSchema { Name: indexID, AllowMissing: false, Unique: true, - Indexer: &memdb.StringFieldIndex{ - Field: "SecretID", - Lowercase: false, + Indexer: indexerSingle{ + readIndex: readIndex(indexFromStringCaseSensitive), + writeIndex: writeIndex(indexSecretIDFromACLToken), }, }, indexPolicies: { @@ -324,3 +324,29 @@ func indexAccessorIDFromACLToken(raw interface{}) ([]byte, error) { b.Raw(uuid) return b.Bytes(), nil } + +func indexSecretIDFromACLToken(raw interface{}) ([]byte, error) { + p, ok := raw.(*structs.ACLToken) + if !ok { + return nil, fmt.Errorf("unexpected type %T for structs.ACLToken index", raw) + } + + if p.SecretID == "" { + return nil, errMissingValueForIndex + } + + var b indexBuilder + b.String(p.SecretID) + return b.Bytes(), nil +} + +func indexFromStringCaseSensitive(raw interface{}) ([]byte, error) { + q, ok := raw.(string) + if !ok { + return nil, fmt.Errorf("unexpected type %T for string prefix query", raw) + } + + var b indexBuilder + b.String(q) + return b.Bytes(), nil +} diff --git a/agent/consul/state/acl_test.go b/agent/consul/state/acl_test.go index 0efbf4a069..7962198e3b 100644 --- a/agent/consul/state/acl_test.go +++ b/agent/consul/state/acl_test.go @@ -3737,7 +3737,7 @@ func TestStateStore_ACLTokens_Snapshot_Restore(t *testing.T) { require.NoError(t, err) require.Equal(t, uint64(4), idx) require.ElementsMatch(t, tokens, res) - require.Equal(t, uint64(4), s.maxIndex("acl-tokens")) + require.Equal(t, uint64(4), s.maxIndex(tableACLTokens)) }() }