From fea61d629b89c22cfe8ec441af69495b0950d2ba Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Tue, 5 Aug 2014 16:43:57 -0700 Subject: [PATCH] consul: Adding ACLs to the state store --- consul/state_store.go | 147 +++++++++++++++++++++++++++- consul/state_store_test.go | 192 +++++++++++++++++++++++++++++++++++-- consul/structs/structs.go | 11 ++- 3 files changed, 340 insertions(+), 10 deletions(-) diff --git a/consul/state_store.go b/consul/state_store.go index f95b0554e0..39d79c2f2b 100644 --- a/consul/state_store.go +++ b/consul/state_store.go @@ -21,6 +21,7 @@ const ( dbKVS = "kvs" dbSessions = "sessions" dbSessionChecks = "sessionChecks" + dbACLs = "acls" dbMaxMapSize32bit uint64 = 512 * 1024 * 1024 // 512MB maximum size dbMaxMapSize64bit uint64 = 32 * 1024 * 1024 * 1024 // 32GB maximum size ) @@ -53,6 +54,7 @@ type StateStore struct { kvsTable *MDBTable sessionTable *MDBTable sessionCheckTable *MDBTable + aclTable *MDBTable tables MDBTables watch map[*MDBTable]*NotifyGroup queryTables map[string]MDBTables @@ -306,9 +308,26 @@ func (s *StateStore) initialize() error { }, } + s.aclTable = &MDBTable{ + Name: dbACLs, + Indexes: map[string]*MDBIndex{ + "id": &MDBIndex{ + Unique: true, + Fields: []string{"ID"}, + }, + }, + Decoder: func(buf []byte) interface{} { + out := new(structs.ACL) + if err := structs.Decode(buf, out); err != nil { + panic(err) + } + return out + }, + } + // Store the set of tables s.tables = []*MDBTable{s.nodeTable, s.serviceTable, s.checkTable, - s.kvsTable, s.sessionTable, s.sessionCheckTable} + s.kvsTable, s.sessionTable, s.sessionCheckTable, s.aclTable} for _, table := range s.tables { table.Env = s.env table.Encoder = encoder @@ -1249,8 +1268,8 @@ func (s *StateStore) SessionCreate(index uint64, session *structs.Session) error } // Generate a new session ID, verify uniqueness - session.ID = generateUUID() for { + session.ID = generateUUID() res, err = s.sessionTable.GetTxn(tx, "id", session.ID) if err != nil { return err @@ -1346,7 +1365,7 @@ func (s *StateStore) NodeSessions(node string) (uint64, []*structs.Session, erro return idx, out, err } -// SessionDelete is used to destroy a session. +// SessionDestroy is used to destroy a session. func (s *StateStore) SessionDestroy(index uint64, id string) error { tx, err := s.tables.StartTxn(false) if err != nil { @@ -1482,6 +1501,118 @@ func (s *StateStore) invalidateLocks(index uint64, tx *MDBTxn, return nil } +// ACLSet is used to create or update an ACL entry +func (s *StateStore) ACLSet(index uint64, acl *structs.ACL) error { + // Start a new txn + tx, err := s.tables.StartTxn(false) + if err != nil { + return err + } + defer tx.Abort() + + // Generate a new session ID + if acl.ID == "" { + for { + acl.ID = generateUUID() + res, err := s.aclTable.GetTxn(tx, "id", acl.ID) + if err != nil { + return err + } + // Quit if this ID is unique + if len(res) == 0 { + break + } + } + acl.CreateIndex = index + acl.ModifyIndex = index + + } else { + // Look for the existing node + res, err := s.aclTable.GetTxn(tx, "id", acl.ID) + if err != nil { + return err + } + + switch len(res) { + case 0: + return fmt.Errorf("Invalid ACL") + case 1: + exist := res[0].(*structs.ACL) + acl.CreateIndex = exist.CreateIndex + acl.ModifyIndex = index + default: + panic(fmt.Errorf("Duplicate ACL definition. Internal error")) + } + } + + // Insert the ACL + if err := s.aclTable.InsertTxn(tx, acl); err != nil { + return err + } + + // Trigger the update notifications + if err := s.aclTable.SetLastIndexTxn(tx, index); err != nil { + return err + } + tx.Defer(func() { s.watch[s.aclTable].Notify() }) + return tx.Commit() +} + +// ACLRestore is used to restore an ACL. It should only be used when +// doing a restore, otherwise ACLSet should be used. +func (s *StateStore) ACLRestore(acl *structs.ACL) error { + // Start a new txn + tx, err := s.aclTable.StartTxn(false, nil) + if err != nil { + return err + } + defer tx.Abort() + + if err := s.aclTable.InsertTxn(tx, acl); err != nil { + return err + } + return tx.Commit() +} + +// ACLGet is used to get an ACL by ID +func (s *StateStore) ACLGet(id string) (uint64, *structs.ACL, error) { + idx, res, err := s.aclTable.Get("id", id) + var d *structs.ACL + if len(res) > 0 { + d = res[0].(*structs.ACL) + } + return idx, d, err +} + +// ACLList is used to list all the acls +func (s *StateStore) ACLList() (uint64, []*structs.ACL, error) { + idx, res, err := s.aclTable.Get("id") + out := make([]*structs.ACL, len(res)) + for i, raw := range res { + out[i] = raw.(*structs.ACL) + } + return idx, out, err +} + +// ACLDelete is used to remove an ACL +func (s *StateStore) ACLDelete(index uint64, id string) error { + tx, err := s.tables.StartTxn(false) + if err != nil { + panic(fmt.Errorf("Failed to start txn: %v", err)) + } + defer tx.Abort() + + if n, err := s.aclTable.DeleteTxn(tx, "id", id); err != nil { + return err + } else if n > 0 { + if err := s.aclTable.SetLastIndexTxn(tx, index); err != nil { + return err + } + tx.Defer(func() { s.watch[s.aclTable].Notify() }) + } + return tx.Commit() +} + // Snapshot is used to create a point in time snapshot func (s *StateStore) Snapshot() (*StateSnapshot, error) { // Begin a new txn on all tables @@ -1555,3 +1686,13 @@ func (s *StateSnapshot) SessionList() ([]*structs.Session, error) { } return out, err } + +// ACLList is used to list all of the ACLs +func (s *StateSnapshot) ACLList() ([]*structs.ACL, error) { + res, err := s.store.aclTable.GetTxn(s.tx, "id") + out := make([]*structs.ACL, len(res)) + for i, raw := range res { + out[i] = raw.(*structs.ACL) + } + return out, err +} diff --git a/consul/state_store_test.go b/consul/state_store_test.go index a5130131a7..c1705f5df0 100644 --- a/consul/state_store_test.go +++ b/consul/state_store_test.go @@ -652,6 +652,22 @@ func TestStoreSnapshot(t *testing.T) { t.Fatalf("err: %v", err) } + a1 := &structs.ACL{ + Name: "User token", + Type: structs.ACLTypeClient, + } + if err := store.ACLSet(19, a1); err != nil { + t.Fatalf("err: %v", err) + } + + a2 := &structs.ACL{ + Name: "User token", + Type: structs.ACLTypeClient, + } + if err := store.ACLSet(20, a2); err != nil { + t.Fatalf("err: %v", err) + } + // Take a snapshot snap, err := store.Snapshot() if err != nil { @@ -660,7 +676,7 @@ func TestStoreSnapshot(t *testing.T) { defer snap.Close() // Check the last nodes - if idx := snap.LastIndex(); idx != 18 { + if idx := snap.LastIndex(); idx != 20 { t.Fatalf("bad: %v", idx) } @@ -724,14 +740,23 @@ func TestStoreSnapshot(t *testing.T) { t.Fatalf("missing sessions") } + // Check for an acl + acls, err := snap.ACLList() + if err != nil { + t.Fatalf("err: %v", err) + } + if len(acls) != 2 { + t.Fatalf("missing acls") + } + // Make some changes! - if err := store.EnsureService(19, "foo", &structs.NodeService{"db", "db", []string{"slave"}, 8000}); err != nil { + if err := store.EnsureService(21, "foo", &structs.NodeService{"db", "db", []string{"slave"}, 8000}); err != nil { t.Fatalf("err: %v", err) } - if err := store.EnsureService(20, "bar", &structs.NodeService{"db", "db", []string{"master"}, 8000}); err != nil { + if err := store.EnsureService(22, "bar", &structs.NodeService{"db", "db", []string{"master"}, 8000}); err != nil { t.Fatalf("err: %v", err) } - if err := store.EnsureNode(21, structs.Node{"baz", "127.0.0.3"}); err != nil { + if err := store.EnsureNode(23, structs.Node{"baz", "127.0.0.3"}); err != nil { t.Fatalf("err: %v", err) } checkAfter := &structs.HealthCheck{ @@ -741,11 +766,16 @@ func TestStoreSnapshot(t *testing.T) { Status: structs.HealthCritical, ServiceID: "db", } - if err := store.EnsureCheck(22, checkAfter); err != nil { + if err := store.EnsureCheck(24, checkAfter); err != nil { t.Fatalf("err: %v", err) } - if err := store.KVSDelete(23, "/web/b"); err != nil { + if err := store.KVSDelete(25, "/web/b"); err != nil { + t.Fatalf("err: %v", err) + } + + // Nuke an ACL + if err := store.ACLDelete(26, a1.ID); err != nil { t.Fatalf("err: %v", err) } @@ -807,6 +837,15 @@ func TestStoreSnapshot(t *testing.T) { if len(sessions) != 2 { t.Fatalf("missing sessions") } + + // Check for an acl + acls, err = snap.ACLList() + if err != nil { + t.Fatalf("err: %v", err) + } + if len(acls) != 2 { + t.Fatalf("missing acls") + } } func TestEnsureCheck(t *testing.T) { @@ -2117,3 +2156,144 @@ func TestSessionInvalidate_KeyUnlock(t *testing.T) { t.Fatalf("Bad: %v", expires) } } + +func TestACLSet_Get(t *testing.T) { + store, err := testStateStore() + if err != nil { + t.Fatalf("err: %v", err) + } + defer store.Close() + + idx, out, err := store.ACLGet("1234") + if err != nil { + t.Fatalf("err: %v", err) + } + if idx != 0 { + t.Fatalf("bad: %v", idx) + } + if out != nil { + t.Fatalf("bad: %v", out) + } + + a := &structs.ACL{ + Name: "User token", + Type: structs.ACLTypeClient, + Rules: "", + } + if err := store.ACLSet(50, a); err != nil { + t.Fatalf("err: %v", err) + } + if a.CreateIndex != 50 { + t.Fatalf("Bad: %v", a) + } + if a.ModifyIndex != 50 { + t.Fatalf("Bad: %v", a) + } + if a.ID == "" { + t.Fatalf("Bad: %v", a) + } + + idx, out, err = store.ACLGet(a.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if idx != 50 { + t.Fatalf("bad: %v", idx) + } + if !reflect.DeepEqual(out, a) { + t.Fatalf("bad: %v", out) + } + + // Update + a.Rules = "foo bar baz" + if err := store.ACLSet(52, a); err != nil { + t.Fatalf("err: %v", err) + } + if a.CreateIndex != 50 { + t.Fatalf("Bad: %v", a) + } + if a.ModifyIndex != 52 { + t.Fatalf("Bad: %v", a) + } + + idx, out, err = store.ACLGet(a.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if idx != 52 { + t.Fatalf("bad: %v", idx) + } + if !reflect.DeepEqual(out, a) { + t.Fatalf("bad: %v", out) + } +} + +func TestACLDelete(t *testing.T) { + store, err := testStateStore() + if err != nil { + t.Fatalf("err: %v", err) + } + defer store.Close() + + a := &structs.ACL{ + Name: "User token", + Type: structs.ACLTypeClient, + Rules: "", + } + if err := store.ACLSet(50, a); err != nil { + t.Fatalf("err: %v", err) + } + + if err := store.ACLDelete(52, a.ID); err != nil { + t.Fatalf("err: %v", err) + } + if err := store.ACLDelete(53, a.ID); err != nil { + t.Fatalf("err: %v", err) + } + + idx, out, err := store.ACLGet(a.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if idx != 52 { + t.Fatalf("bad: %v", idx) + } + if out != nil { + t.Fatalf("bad: %v", out) + } +} + +func TestACLList(t *testing.T) { + store, err := testStateStore() + if err != nil { + t.Fatalf("err: %v", err) + } + defer store.Close() + + a1 := &structs.ACL{ + Name: "User token", + Type: structs.ACLTypeClient, + } + if err := store.ACLSet(50, a1); err != nil { + t.Fatalf("err: %v", err) + } + + a2 := &structs.ACL{ + Name: "User token", + Type: structs.ACLTypeClient, + } + if err := store.ACLSet(51, a2); err != nil { + t.Fatalf("err: %v", err) + } + + idx, out, err := store.ACLList() + if err != nil { + t.Fatalf("err: %v", err) + } + if idx != 51 { + t.Fatalf("bad: %v", idx) + } + if len(out) != 2 { + t.Fatalf("bad: %v", out) + } +} diff --git a/consul/structs/structs.go b/consul/structs/structs.go index 64995704e3..27cffd0b28 100644 --- a/consul/structs/structs.go +++ b/consul/structs/structs.go @@ -20,6 +20,7 @@ const ( DeregisterRequestType KVSRequestType SessionRequestType + ACLRequestType ) const ( @@ -32,6 +33,15 @@ const ( HealthCritical = "critical" ) +const ( + // Client tokens have rules applied + ACLTypeClient = "client" + + // Management tokens have an always allow policy. + // They are used for token management. + ACLTypeManagement = "management" +) + const ( // MaxLockDelay provides a maximum LockDelay value for // a session. Any value above this will not be respected. @@ -421,7 +431,6 @@ type ACL struct { Name string Type string Rules string - TTL time.Duration } type ACLs []*ACL