From 3f607d9ef09034caf37352f50f3c70b6803bf848 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Wed, 20 May 2020 12:43:33 -0400 Subject: [PATCH] state: use an error to indicate compare failed Errors are values. We can use the error value to identify the 'comparison failed' case which makes the function easier to use and should make it harder to miss handle the error case --- agent/consul/state/catalog.go | 20 +++++++++----------- agent/consul/state/catalog_test.go | 22 ++++++++++------------ agent/consul/state/txn.go | 8 +++++--- 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/agent/consul/state/catalog.go b/agent/consul/state/catalog.go index eb9673ca60..82ea05028a 100644 --- a/agent/consul/state/catalog.go +++ b/agent/consul/state/catalog.go @@ -1,6 +1,7 @@ package state import ( + "errors" "fmt" "reflect" "strings" @@ -736,34 +737,31 @@ func (s *Store) EnsureService(idx uint64, node string, svc *structs.NodeService) return nil } +var errCASCompareFailed = errors.New("compare-and-set: comparison failed") + // ensureServiceCASTxn updates a service only if the existing index matches the given index. // Returns a bool indicating if a write happened and any error. -func (s *Store) ensureServiceCASTxn(tx *memdb.Txn, idx uint64, node string, svc *structs.NodeService) (bool, error) { +func (s *Store) ensureServiceCASTxn(tx *memdb.Txn, idx uint64, node string, svc *structs.NodeService) error { // Retrieve the existing service. _, existing, err := firstWatchCompoundWithTxn(tx, "services", "id", &svc.EnterpriseMeta, node, svc.ID) if err != nil { - return false, fmt.Errorf("failed service lookup: %s", err) + return fmt.Errorf("failed service lookup: %s", err) } // Check if the we should do the set. A ModifyIndex of 0 means that // we are doing a set-if-not-exists. if svc.ModifyIndex == 0 && existing != nil { - return false, nil + return errCASCompareFailed } if svc.ModifyIndex != 0 && existing == nil { - return false, nil + return errCASCompareFailed } e, ok := existing.(*structs.ServiceNode) if ok && svc.ModifyIndex != 0 && svc.ModifyIndex != e.ModifyIndex { - return false, nil + return errCASCompareFailed } - // Perform the update. - if err := s.ensureServiceTxn(tx, idx, node, svc); err != nil { - return false, err - } - - return true, nil + return s.ensureServiceTxn(tx, idx, node, svc) } // ensureServiceTxn is used to upsert a service registration within an diff --git a/agent/consul/state/catalog_test.go b/agent/consul/state/catalog_test.go index 2627eab614..b92ada564d 100644 --- a/agent/consul/state/catalog_test.go +++ b/agent/consul/state/catalog_test.go @@ -2,6 +2,11 @@ package state import ( "fmt" + "reflect" + "sort" + "strings" + "testing" + "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/lib" @@ -11,10 +16,6 @@ import ( "github.com/pascaldekloe/goe/verify" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "reflect" - "sort" - "strings" - "testing" ) func makeRandomNodeID(t *testing.T) types.NodeID { @@ -4395,9 +4396,8 @@ func TestStateStore_ensureServiceCASTxn(t *testing.T) { // attempt to update with a 0 index tx := s.db.Txn(true) - update, err := s.ensureServiceCASTxn(tx, 3, "node1", &ns) - require.False(t, update) - require.NoError(t, err) + err := s.ensureServiceCASTxn(tx, 3, "node1", &ns) + require.Equal(t, err, errCASCompareFailed) tx.Commit() // ensure no update happened @@ -4411,9 +4411,8 @@ func TestStateStore_ensureServiceCASTxn(t *testing.T) { ns.ModifyIndex = 99 // attempt to update with a non-matching index tx = s.db.Txn(true) - update, err = s.ensureServiceCASTxn(tx, 4, "node1", &ns) - require.False(t, update) - require.NoError(t, err) + err = s.ensureServiceCASTxn(tx, 4, "node1", &ns) + require.Equal(t, err, errCASCompareFailed) tx.Commit() // ensure no update happened @@ -4427,8 +4426,7 @@ func TestStateStore_ensureServiceCASTxn(t *testing.T) { ns.ModifyIndex = 2 // update with the matching modify index tx = s.db.Txn(true) - update, err = s.ensureServiceCASTxn(tx, 7, "node1", &ns) - require.True(t, update) + err = s.ensureServiceCASTxn(tx, 7, "node1", &ns) require.NoError(t, err) tx.Commit() diff --git a/agent/consul/state/txn.go b/agent/consul/state/txn.go index 06f20681e9..f8c02e25a9 100644 --- a/agent/consul/state/txn.go +++ b/agent/consul/state/txn.go @@ -230,11 +230,13 @@ func (s *Store) txnService(tx *memdb.Txn, idx uint64, op *structs.TxnServiceOp) return newTxnResultFromNodeServiceEntry(entry), err case api.ServiceCAS: - ok, err := s.ensureServiceCASTxn(tx, idx, op.Node, &op.Service) - // TODO: err != nil case is ignored - if !ok && err == nil { + err := s.ensureServiceCASTxn(tx, idx, op.Node, &op.Service) + switch { + case err == errCASCompareFailed: err := fmt.Errorf("failed to set service %q on node %q, index is stale", op.Service.ID, op.Node) return nil, err + case err != nil: + return nil, err } entry, err := s.getNodeServiceTxn(tx, op.Node, op.Service.ID, &op.Service.EnterpriseMeta)