Preserve PeeringState on upsert (#13666)

Fixes a bug where if the generate token is called twice, the second call upserts the zero-value (undefined) of PeeringState.
This commit is contained in:
Chris S. Kim 2022-07-25 14:37:56 -04:00 committed by GitHub
parent 5786309356
commit 73a84f256f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 144 additions and 82 deletions

View File

@ -213,6 +213,13 @@ func (s *Store) PeeringWrite(idx uint64, p *pbpeering.Peering) error {
return fmt.Errorf("cannot write to peering that is marked for deletion")
}
if p.State == pbpeering.PeeringState_UNDEFINED {
p.State = existing.State
}
// TODO(peering): Confirm behavior when /peering/token is called more than once.
// We may need to avoid clobbering existing values.
p.ImportedServiceCount = existing.ImportedServiceCount
p.ExportedServiceCount = existing.ExportedServiceCount
p.CreateIndex = existing.CreateIndex
p.ModifyIndex = idx
} else {

View File

@ -12,6 +12,7 @@ import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/structs"
@ -148,6 +149,71 @@ func TestHTTP_Peering_GenerateToken(t *testing.T) {
})
}
// Test for GenerateToken calls at various points in a peer's lifecycle
func TestHTTP_Peering_GenerateToken_EdgeCases(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Parallel()
a := NewTestAgent(t, "")
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
body := &pbpeering.GenerateTokenRequest{
PeerName: "peering-a",
}
bodyBytes, err := json.Marshal(body)
require.NoError(t, err)
getPeering := func(t *testing.T) *api.Peering {
t.Helper()
// Check state of peering
req, err := http.NewRequest("GET", "/v1/peering/peering-a", bytes.NewReader(bodyBytes))
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusOK, resp.Code, "expected 200, got %d: %v", resp.Code, resp.Body.String())
var p *api.Peering
require.NoError(t, json.NewDecoder(resp.Body).Decode(&p))
return p
}
{
// Call once
req, err := http.NewRequest("POST", "/v1/peering/token", bytes.NewReader(bodyBytes))
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusOK, resp.Code, "expected 200, got %d: %v", resp.Code, resp.Body.String())
// Assertions tested in TestHTTP_Peering_GenerateToken
}
if !t.Run("generate token called again", func(t *testing.T) {
before := getPeering(t)
require.Equal(t, api.PeeringStatePending, before.State)
// Call again
req, err := http.NewRequest("POST", "/v1/peering/token", bytes.NewReader(bodyBytes))
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusOK, resp.Code, "expected 200, got %d: %v", resp.Code, resp.Body.String())
after := getPeering(t)
assert.NotEqual(t, before.ModifyIndex, after.ModifyIndex)
// blank out modify index so we can compare rest of struct
before.ModifyIndex, after.ModifyIndex = 0, 0
assert.Equal(t, before, after)
}) {
t.FailNow()
}
}
func TestHTTP_Peering_Establish(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")

View File

@ -211,6 +211,53 @@ func (s *Server) GenerateToken(
return nil, err
}
var peering *pbpeering.Peering
// This loop ensures at most one retry in the case of a race condition.
for canRetry := true; canRetry; canRetry = false {
peering, err = s.getExistingPeering(req.PeerName, entMeta.PartitionOrDefault())
if err != nil {
return nil, err
}
if peering == nil {
id, err := lib.GenerateUUID(s.Backend.CheckPeeringUUID)
if err != nil {
return resp, err
}
peering = &pbpeering.Peering{
ID: id,
Name: req.PeerName,
Meta: req.Meta,
// PartitionOrEmpty is used to avoid writing "default" in OSS.
Partition: entMeta.PartitionOrEmpty(),
}
} else {
// validate that this peer name is not being used as a dialer already
if err := validatePeer(peering, false); err != nil {
return nil, err
}
}
writeReq := pbpeering.PeeringWriteRequest{
Peering: peering,
}
if err := s.Backend.PeeringWrite(&writeReq); err != nil {
// There's a possible race where two servers call Generate Token at the
// same time with the same peer name for the first time. They both
// generate an ID and try to insert and only one wins. This detects the
// collision and forces the loser to discard its generated ID and use
// the one from the other server.
if strings.Contains(err.Error(), "A peering already exists with the name") {
// retry to fetch existing peering
continue
}
return nil, fmt.Errorf("failed to write peering: %w", err)
}
// write succeeded, break loop early
break
}
ca, err := s.Backend.GetAgentCACertificates()
if err != nil {
return nil, err
@ -227,57 +274,6 @@ func (s *Server) GenerateToken(
}
}
peeringOrNil, err := s.getExistingPeering(req.PeerName, entMeta.PartitionOrDefault())
if err != nil {
return nil, err
}
// validate that this peer name is not being used as a dialer already
if err = validatePeer(peeringOrNil, false); err != nil {
return nil, err
}
canRetry := true
RETRY_ONCE:
id, err := s.getExistingOrCreateNewPeerID(req.PeerName, entMeta.PartitionOrDefault())
if err != nil {
return nil, err
}
writeReq := pbpeering.PeeringWriteRequest{
Peering: &pbpeering.Peering{
ID: id,
Name: req.PeerName,
Meta: req.Meta,
// PartitionOrEmpty is used to avoid writing "default" in OSS.
Partition: entMeta.PartitionOrEmpty(),
},
}
if err := s.Backend.PeeringWrite(&writeReq); err != nil {
// There's a possible race where two servers call Generate Token at the
// same time with the same peer name for the first time. They both
// generate an ID and try to insert and only one wins. This detects the
// collision and forces the loser to discard its generated ID and use
// the one from the other server.
if canRetry && strings.Contains(err.Error(), "A peering already exists with the name") {
canRetry = false
goto RETRY_ONCE
}
return nil, fmt.Errorf("failed to write peering: %w", err)
}
q := state.Query{
Value: strings.ToLower(req.PeerName),
EnterpriseMeta: *entMeta,
}
_, peering, err := s.Backend.Store().PeeringRead(nil, q)
if err != nil {
return nil, err
}
if peering == nil {
return nil, fmt.Errorf("peering was deleted while token generation request was in flight")
}
tok := structs.PeeringToken{
// Store the UUID so that we can do a global search when handling inbound streams.
PeerID: peering.ID,
@ -345,24 +341,24 @@ func (s *Server) Establish(
return nil, err
}
peeringOrNil, err := s.getExistingPeering(req.PeerName, entMeta.PartitionOrDefault())
peering, err := s.getExistingPeering(req.PeerName, entMeta.PartitionOrDefault())
if err != nil {
return nil, err
}
// validate that this peer name is not being used as an acceptor already
if err = validatePeer(peeringOrNil, true); err != nil {
return nil, err
}
var id string
if peeringOrNil != nil {
id = peeringOrNil.ID
} else {
if peering == nil {
id, err = lib.GenerateUUID(s.Backend.CheckPeeringUUID)
if err != nil {
return nil, err
}
} else {
id = peering.ID
}
// validate that this peer name is not being used as an acceptor already
if err := validatePeer(peering, true); err != nil {
return nil, err
}
// convert ServiceAddress values to strings
@ -392,10 +388,10 @@ func (s *Server) Establish(
Partition: entMeta.PartitionOrEmpty(),
},
}
if err = s.Backend.PeeringWrite(writeReq); err != nil {
if err := s.Backend.PeeringWrite(writeReq); err != nil {
return nil, fmt.Errorf("failed to write peering: %w", err)
}
// resp.Status == 0
// TODO(peering): low prio: consider adding response details
return resp, nil
}
@ -564,10 +560,19 @@ func (s *Server) PeeringWrite(ctx context.Context, req *pbpeering.PeeringWriteRe
return nil, fmt.Errorf("missing required peering body")
}
id, err := s.getExistingOrCreateNewPeerID(req.Peering.Name, entMeta.PartitionOrDefault())
var id string
peering, err := s.getExistingPeering(req.Peering.Name, entMeta.PartitionOrDefault())
if err != nil {
return nil, err
}
if peering == nil {
id, err = lib.GenerateUUID(s.Backend.CheckPeeringUUID)
if err != nil {
return nil, err
}
} else {
id = peering.ID
}
req.Peering.ID = id
err = s.Backend.PeeringWrite(req)
@ -759,22 +764,6 @@ func (s *Server) TrustBundleListByService(ctx context.Context, req *pbpeering.Tr
return &pbpeering.TrustBundleListByServiceResponse{Index: idx, Bundles: bundles}, nil
}
func (s *Server) getExistingOrCreateNewPeerID(peerName, partition string) (string, error) {
peeringOrNil, err := s.getExistingPeering(peerName, partition)
if err != nil {
return "", err
}
if peeringOrNil != nil {
return peeringOrNil.ID, nil
}
id, err := lib.GenerateUUID(s.Backend.CheckPeeringUUID)
if err != nil {
return "", err
}
return id, nil
}
func (s *Server) getExistingPeering(peerName, partition string) (*pbpeering.Peering, error) {
q := state.Query{
Value: strings.ToLower(peerName),
@ -793,9 +782,9 @@ func (s *Server) getExistingPeering(peerName, partition string) (*pbpeering.Peer
//
// We define a DIALER as a peering that has server addresses (or a peering that is created via the Establish endpoint)
// Conversely, we define an ACCEPTOR as a peering that is created via the GenerateToken endpoint
func validatePeer(peering *pbpeering.Peering, allowedToDial bool) error {
if peering != nil && peering.ShouldDial() != allowedToDial {
if allowedToDial {
func validatePeer(peering *pbpeering.Peering, shouldDial bool) error {
if peering != nil && peering.ShouldDial() != shouldDial {
if shouldDial {
return fmt.Errorf("cannot create peering with name: %q; there is an existing peering expecting to be dialed", peering.Name)
} else {
return fmt.Errorf("cannot create peering with name: %q; there is already an established peering", peering.Name)