txn: clean up some state store/acl code

This commit is contained in:
Kyle Havlovitz 2019-01-09 11:59:23 -08:00
parent 995e728ea0
commit c07c5446a8
3 changed files with 40 additions and 37 deletions

View File

@ -1404,9 +1404,12 @@ func vetServiceTxnOp(op *structs.TxnServiceOp, rule acl.Authorizer) error {
Port: service.Port, Port: service.Port,
EnableTagOverride: service.EnableTagOverride, EnableTagOverride: service.EnableTagOverride,
} }
scope := func() map[string]interface{} { var scope func() map[string]interface{}
if op.Verb != api.ServiceDelete && op.Verb != api.ServiceDeleteCAS {
scope = func() map[string]interface{} {
return sentinel.ScopeCatalogUpsert(n, svc) return sentinel.ScopeCatalogUpsert(n, svc)
} }
}
if !rule.ServiceWrite(service.Service, scope) { if !rule.ServiceWrite(service.Service, scope) {
return acl.ErrPermissionDenied return acl.ErrPermissionDenied
} }
@ -1427,19 +1430,24 @@ func vetCheckTxnOp(op *structs.TxnCheckOp, rule acl.Authorizer) error {
Service: op.Check.ServiceID, Service: op.Check.ServiceID,
Tags: op.Check.ServiceTags, Tags: op.Check.ServiceTags,
} }
var scope func() map[string]interface{}
if op.Check.ServiceID == "" { if op.Check.ServiceID == "" {
// Node-level check. // Node-level check.
scope := func() map[string]interface{} { if op.Verb == api.CheckDelete || op.Verb == api.CheckDeleteCAS {
scope = func() map[string]interface{} {
return sentinel.ScopeCatalogUpsert(n, svc) return sentinel.ScopeCatalogUpsert(n, svc)
} }
}
if !rule.NodeWrite(op.Check.Node, scope) { if !rule.NodeWrite(op.Check.Node, scope) {
return acl.ErrPermissionDenied return acl.ErrPermissionDenied
} }
} else { } else {
// Service-level check. // Service-level check.
scope := func() map[string]interface{} { if op.Verb == api.CheckDelete || op.Verb == api.CheckDeleteCAS {
scope = func() map[string]interface{} {
return sentinel.ScopeCatalogUpsert(n, svc) return sentinel.ScopeCatalogUpsert(n, svc)
} }
}
if !rule.ServiceWrite(op.Check.ServiceName, scope) { if !rule.ServiceWrite(op.Check.ServiceName, scope) {
return acl.ErrPermissionDenied return acl.ErrPermissionDenied
} }

View File

@ -377,9 +377,9 @@ func (s *Store) ensureNoNodeWithSimilarNameTxn(tx *memdb.Txn, node *structs.Node
// Returns a bool indicating if a write happened and any error. // Returns a bool indicating if a write happened and any error.
func (s *Store) ensureNodeCASTxn(tx *memdb.Txn, idx uint64, node *structs.Node) (bool, error) { func (s *Store) ensureNodeCASTxn(tx *memdb.Txn, idx uint64, node *structs.Node) (bool, error) {
// Retrieve the existing entry. // Retrieve the existing entry.
existing, err := tx.First("nodes", "id", node.Node) existing, err := getNodeTxn(tx, node.Node)
if err != nil { if err != nil {
return false, fmt.Errorf("node lookup failed: %s", err) return false, err
} }
// Check if the we should do the set. A ModifyIndex of 0 means that // Check if the we should do the set. A ModifyIndex of 0 means that
@ -390,8 +390,7 @@ func (s *Store) ensureNodeCASTxn(tx *memdb.Txn, idx uint64, node *structs.Node)
if node.ModifyIndex != 0 && existing == nil { if node.ModifyIndex != 0 && existing == nil {
return false, nil return false, nil
} }
e, ok := existing.(*structs.Node) if existing != nil && node.ModifyIndex != 0 && node.ModifyIndex != existing.ModifyIndex {
if ok && node.ModifyIndex != 0 && node.ModifyIndex != e.ModifyIndex {
return false, nil return false, nil
} }
@ -612,9 +611,9 @@ func (s *Store) DeleteNode(idx uint64, nodeName string) error {
// the given check, then the call is a noop, otherwise a normal check delete is invoked. // the given check, then the call is a noop, otherwise a normal check delete is invoked.
func (s *Store) deleteNodeCASTxn(tx *memdb.Txn, idx, cidx uint64, nodeName string) (bool, error) { func (s *Store) deleteNodeCASTxn(tx *memdb.Txn, idx, cidx uint64, nodeName string) (bool, error) {
// Look up the node. // Look up the node.
node, err := tx.First("nodes", "id", nodeName) node, err := getNodeTxn(tx, nodeName)
if err != nil { if err != nil {
return false, fmt.Errorf("check lookup failed: %s", err) return false, err
} }
if node == nil { if node == nil {
return false, nil return false, nil
@ -623,9 +622,8 @@ func (s *Store) deleteNodeCASTxn(tx *memdb.Txn, idx, cidx uint64, nodeName strin
// If the existing index does not match the provided CAS // If the existing index does not match the provided CAS
// index arg, then we shouldn't update anything and can safely // index arg, then we shouldn't update anything and can safely
// return early here. // return early here.
existing, ok := node.(*structs.Node) if node.ModifyIndex != cidx {
if !ok || existing.ModifyIndex != cidx { return false, nil
return existing == nil, nil
} }
// Call the actual deletion if the above passed. // Call the actual deletion if the above passed.
@ -1149,7 +1147,7 @@ func (s *Store) NodeService(nodeName string, serviceID string) (uint64, *structs
idx := maxIndexTxn(tx, "services") idx := maxIndexTxn(tx, "services")
// Query the service // Query the service
service, err := s.nodeServiceTxn(tx, nodeName, serviceID) service, err := s.getNodeServiceTxn(tx, nodeName, serviceID)
if err != nil { if err != nil {
return 0, nil, fmt.Errorf("failed querying service for node %q: %s", nodeName, err) return 0, nil, fmt.Errorf("failed querying service for node %q: %s", nodeName, err)
} }
@ -1157,7 +1155,7 @@ func (s *Store) NodeService(nodeName string, serviceID string) (uint64, *structs
return idx, service, nil return idx, service, nil
} }
func (s *Store) nodeServiceTxn(tx *memdb.Txn, nodeName, serviceID string) (*structs.NodeService, error) { func (s *Store) getNodeServiceTxn(tx *memdb.Txn, nodeName, serviceID string) (*structs.NodeService, error) {
// Query the service // Query the service
service, err := tx.First("services", "id", nodeName, serviceID) service, err := tx.First("services", "id", nodeName, serviceID)
if err != nil { if err != nil {
@ -1268,9 +1266,9 @@ func serviceIndexName(name string) string {
// the given service, then the call is a noop, otherwise a normal delete is invoked. // the given service, then the call is a noop, otherwise a normal delete is invoked.
func (s *Store) deleteServiceCASTxn(tx *memdb.Txn, idx, cidx uint64, nodeName, serviceID string) (bool, error) { func (s *Store) deleteServiceCASTxn(tx *memdb.Txn, idx, cidx uint64, nodeName, serviceID string) (bool, error) {
// Look up the service. // Look up the service.
service, err := tx.First("services", "id", nodeName, serviceID) service, err := s.getNodeServiceTxn(tx, nodeName, serviceID)
if err != nil { if err != nil {
return false, fmt.Errorf("check lookup failed: %s", err) return false, fmt.Errorf("service lookup failed: %s", err)
} }
if service == nil { if service == nil {
return false, nil return false, nil
@ -1279,9 +1277,8 @@ func (s *Store) deleteServiceCASTxn(tx *memdb.Txn, idx, cidx uint64, nodeName, s
// If the existing index does not match the provided CAS // If the existing index does not match the provided CAS
// index arg, then we shouldn't update anything and can safely // index arg, then we shouldn't update anything and can safely
// return early here. // return early here.
existing, ok := service.(*structs.ServiceNode) if service.ModifyIndex != cidx {
if !ok || existing.ModifyIndex != cidx { return false, nil
return existing == nil, nil
} }
// Call the actual deletion if the above passed. // Call the actual deletion if the above passed.
@ -1391,7 +1388,7 @@ func (s *Store) updateAllServiceIndexesOfNode(tx *memdb.Txn, idx uint64, nodeID
// Returns a bool indicating if a write happened and any error. // Returns a bool indicating if a write happened and any error.
func (s *Store) ensureCheckCASTxn(tx *memdb.Txn, idx uint64, hc *structs.HealthCheck) (bool, error) { func (s *Store) ensureCheckCASTxn(tx *memdb.Txn, idx uint64, hc *structs.HealthCheck) (bool, error) {
// Retrieve the existing entry. // Retrieve the existing entry.
existing, err := tx.First("checks", "id", hc.Node, string(hc.CheckID)) _, existing, err := s.getNodeCheckTxn(tx, hc.Node, hc.CheckID)
if err != nil { if err != nil {
return false, fmt.Errorf("failed health check lookup: %s", err) return false, fmt.Errorf("failed health check lookup: %s", err)
} }
@ -1404,8 +1401,7 @@ func (s *Store) ensureCheckCASTxn(tx *memdb.Txn, idx uint64, hc *structs.HealthC
if hc.ModifyIndex != 0 && existing == nil { if hc.ModifyIndex != 0 && existing == nil {
return false, nil return false, nil
} }
e, ok := existing.(*structs.HealthCheck) if existing != nil && hc.ModifyIndex != 0 && hc.ModifyIndex != existing.ModifyIndex {
if ok && hc.ModifyIndex != 0 && hc.ModifyIndex != e.ModifyIndex {
return false, nil return false, nil
} }
@ -1533,12 +1529,12 @@ func (s *Store) NodeCheck(nodeName string, checkID types.CheckID) (uint64, *stru
tx := s.db.Txn(false) tx := s.db.Txn(false)
defer tx.Abort() defer tx.Abort()
return s.nodeCheckTxn(tx, nodeName, checkID) return s.getNodeCheckTxn(tx, nodeName, checkID)
} }
// nodeCheckTxn is used as the inner method to handle reading a health check // nodeCheckTxn is used as the inner method to handle reading a health check
// from the state store. // from the state store.
func (s *Store) nodeCheckTxn(tx *memdb.Txn, nodeName string, checkID types.CheckID) (uint64, *structs.HealthCheck, error) { func (s *Store) getNodeCheckTxn(tx *memdb.Txn, nodeName string, checkID types.CheckID) (uint64, *structs.HealthCheck, error) {
// Get the table index. // Get the table index.
idx := maxIndexTxn(tx, "checks") idx := maxIndexTxn(tx, "checks")
@ -1733,7 +1729,7 @@ func (s *Store) DeleteCheck(idx uint64, node string, checkID types.CheckID) erro
// the given check, then the call is a noop, otherwise a normal check delete is invoked. // the given check, then the call is a noop, otherwise a normal check delete is invoked.
func (s *Store) deleteCheckCASTxn(tx *memdb.Txn, idx, cidx uint64, node string, checkID types.CheckID) (bool, error) { func (s *Store) deleteCheckCASTxn(tx *memdb.Txn, idx, cidx uint64, node string, checkID types.CheckID) (bool, error) {
// Try to retrieve the existing health check. // Try to retrieve the existing health check.
hc, err := tx.First("checks", "id", node, string(checkID)) _, hc, err := s.getNodeCheckTxn(tx, node, checkID)
if err != nil { if err != nil {
return false, fmt.Errorf("check lookup failed: %s", err) return false, fmt.Errorf("check lookup failed: %s", err)
} }
@ -1744,9 +1740,8 @@ func (s *Store) deleteCheckCASTxn(tx *memdb.Txn, idx, cidx uint64, node string,
// If the existing index does not match the provided CAS // If the existing index does not match the provided CAS
// index arg, then we shouldn't update anything and can safely // index arg, then we shouldn't update anything and can safely
// return early here. // return early here.
existing, ok := hc.(*structs.HealthCheck) if hc.ModifyIndex != cidx {
if !ok || existing.ModifyIndex != cidx { return false, nil
return existing == nil, nil
} }
// Call the actual deletion if the above passed. // Call the actual deletion if the above passed.

View File

@ -198,14 +198,14 @@ func (s *Store) txnService(tx *memdb.Txn, idx uint64, op *structs.TxnServiceOp)
switch op.Verb { switch op.Verb {
case api.ServiceGet: case api.ServiceGet:
entry, err = s.nodeServiceTxn(tx, op.Node, op.Service.ID) entry, err = s.getNodeServiceTxn(tx, op.Node, op.Service.ID)
if entry == nil && err == nil { if entry == nil && err == nil {
err = fmt.Errorf("service %q on node %q doesn't exist", op.Service.ID, op.Node) err = fmt.Errorf("service %q on node %q doesn't exist", op.Service.ID, op.Node)
} }
case api.ServiceSet: case api.ServiceSet:
err = s.ensureServiceTxn(tx, idx, op.Node, &op.Service) err = s.ensureServiceTxn(tx, idx, op.Node, &op.Service)
entry, err = s.nodeServiceTxn(tx, op.Node, op.Service.ID) entry, err = s.getNodeServiceTxn(tx, op.Node, op.Service.ID)
case api.ServiceCAS: case api.ServiceCAS:
var ok bool var ok bool
@ -214,7 +214,7 @@ func (s *Store) txnService(tx *memdb.Txn, idx uint64, op *structs.TxnServiceOp)
err = fmt.Errorf("failed to set service %q on node %q, index is stale", op.Service.ID, op.Node) err = fmt.Errorf("failed to set service %q on node %q, index is stale", op.Service.ID, op.Node)
break break
} }
entry, err = s.nodeServiceTxn(tx, op.Node, op.Service.ID) entry, err = s.getNodeServiceTxn(tx, op.Node, op.Service.ID)
case api.ServiceDelete: case api.ServiceDelete:
err = s.deleteServiceTxn(tx, idx, op.Node, op.Service.ID) err = s.deleteServiceTxn(tx, idx, op.Node, op.Service.ID)
@ -257,7 +257,7 @@ func (s *Store) txnCheck(tx *memdb.Txn, idx uint64, op *structs.TxnCheckOp) (str
switch op.Verb { switch op.Verb {
case api.CheckGet: case api.CheckGet:
_, entry, err = s.nodeCheckTxn(tx, op.Check.Node, op.Check.CheckID) _, entry, err = s.getNodeCheckTxn(tx, op.Check.Node, op.Check.CheckID)
if entry == nil && err == nil { if entry == nil && err == nil {
err = fmt.Errorf("check %q on node %q doesn't exist", op.Check.CheckID, op.Check.Node) err = fmt.Errorf("check %q on node %q doesn't exist", op.Check.CheckID, op.Check.Node)
} }
@ -265,7 +265,7 @@ func (s *Store) txnCheck(tx *memdb.Txn, idx uint64, op *structs.TxnCheckOp) (str
case api.CheckSet: case api.CheckSet:
err = s.ensureCheckTxn(tx, idx, &op.Check) err = s.ensureCheckTxn(tx, idx, &op.Check)
if err == nil { if err == nil {
_, entry, err = s.nodeCheckTxn(tx, op.Check.Node, op.Check.CheckID) _, entry, err = s.getNodeCheckTxn(tx, op.Check.Node, op.Check.CheckID)
} }
case api.CheckCAS: case api.CheckCAS:
@ -276,7 +276,7 @@ func (s *Store) txnCheck(tx *memdb.Txn, idx uint64, op *structs.TxnCheckOp) (str
err = fmt.Errorf("failed to set check %q on node %q, index is stale", entry.CheckID, entry.Node) err = fmt.Errorf("failed to set check %q on node %q, index is stale", entry.CheckID, entry.Node)
break break
} }
_, entry, err = s.nodeCheckTxn(tx, op.Check.Node, op.Check.CheckID) _, entry, err = s.getNodeCheckTxn(tx, op.Check.Node, op.Check.CheckID)
case api.CheckDelete: case api.CheckDelete:
err = s.deleteCheckTxn(tx, idx, op.Check.Node, op.Check.CheckID) err = s.deleteCheckTxn(tx, idx, op.Check.Node, op.Check.CheckID)