Add support for streaming CA roots to peers (#13260)

The sender watches for changes to CA roots and sends
them through the replication stream. The receiver saves
the CA roots to the tablePeeringTrustBundle table.
Chris S. Kim 2022-05-26 15:24:09 -04:00 committed by GitHub
parent c052c17d20
commit 6d3bea7129
9 changed files with 493 additions and 171 deletions
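
At a high level: the sending server subscribes to CA root events, wraps the current trust bundle in a protobuf Any, and streams it to the peer as an UPSERT; the receiver unmarshals it and persists it through Raft. A minimal sketch of the message the sender builds, mirroring makeCARootsResponse in the diff below (the helper name and values here are illustrative):

// sketchRootsUpsert is a hypothetical helper showing the shape of a CA roots
// replication message; it assumes the ptypes and pbpeering imports used below.
func sketchRootsUpsert(rootPEM string) (*pbpeering.ReplicationMessage, error) {
	bundle := &pbpeering.PeeringTrustBundle{
		TrustDomain: "example.consul",  // illustrative trust domain
		RootPEMs:    []string{rootPEM}, // the full root set is sent on every update
	}
	anyPB, err := ptypes.MarshalAny(bundle)
	if err != nil {
		return nil, err
	}
	return &pbpeering.ReplicationMessage{
		Payload: &pbpeering.ReplicationMessage_Response_{
			Response: &pbpeering.ReplicationMessage_Response{
				ResourceURL: pbpeering.TypeURLRoots,
				ResourceID:  "roots",
				Operation:   pbpeering.ReplicationMessage_Response_UPSERT,
				Resource:    anyPB,
			},
		},
	}, nil
}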

View File

@ -129,6 +129,11 @@ func (a *peeringApply) PeeringTerminateByID(req *pbpeering.PeeringTerminateByIDR
return err
}
func (a *peeringApply) PeeringTrustBundleWrite(req *pbpeering.PeeringTrustBundleWriteRequest) error {
_, err := a.srv.raftApplyProtobuf(structs.PeeringTrustBundleWriteType, req)
return err
}
func (a *peeringApply) CatalogRegister(req *structs.RegisterRequest) error {
_, err := a.srv.leaderRaftApply("Catalog.Register", structs.RegisterRequestType, req)
return err

View File

@ -5,6 +5,7 @@ import (
"fmt"
"strings"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/go-hclog"
"google.golang.org/genproto/googleapis/rpc/code"
@ -29,20 +30,17 @@ import (
request a resync operation.
*/
// pushService response handles sending exported service instance updates to the peer cluster.
// makeServiceResponse handles preparing exported service instance updates to the peer cluster.
// Each cache.UpdateEvent will contain all instances for a service name.
// If there are no instances in the event, we consider that to be a de-registration.
func pushServiceResponse(
func makeServiceResponse(
logger hclog.Logger,
stream BidirectionalStream,
status *lockableStreamStatus,
update cache.UpdateEvent,
) error {
csn, ok := update.Result.(*pbservice.IndexedCheckServiceNodes)
if !ok {
logger.Error(fmt.Sprintf("invalid type for response: %T, expected *pbservice.IndexedCheckServiceNodes", update.Result))
// Skip this update to avoid locking up peering due to a bad service update.
) *pbpeering.ReplicationMessage {
any, csn, err := marshalToProtoAny[*pbservice.IndexedCheckServiceNodes](update.Result)
if err != nil {
// Log the error and skip this response to avoid locking up peering due to a bad update event.
logger.Error("failed to marshal", "error", err)
return nil
}
@ -72,21 +70,10 @@ func pushServiceResponse(
},
},
}
logTraceSend(logger, resp)
if err := stream.Send(resp); err != nil {
status.trackSendError(err.Error())
return fmt.Errorf("failed to send to stream: %v", err)
}
return nil
return resp
}
// If there are nodes in the response, we push them as an UPSERT operation.
any, err := ptypes.MarshalAny(csn)
if err != nil {
// Log the error and skip this response to avoid locking up peering due to a bad update event.
logger.Error("failed to marshal service endpoints", "error", err)
return nil
}
resp := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Response_{
Response: &pbpeering.ReplicationMessage_Response{
@ -99,16 +86,58 @@ func pushServiceResponse(
},
},
}
logTraceSend(logger, resp)
if err := stream.Send(resp); err != nil {
status.trackSendError(err.Error())
return fmt.Errorf("failed to send to stream: %v", err)
}
return nil
return resp
}
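
Restating the rule from the comment above makeServiceResponse: an update event with no service instances is treated as a de-registration and becomes a DELETE, anything else becomes an UPSERT. A hedged sketch of that branch:

// Sketch of the operation choice (csn is the *pbservice.IndexedCheckServiceNodes
// unpacked from the update event above).
op := pbpeering.ReplicationMessage_Response_UPSERT
if len(csn.Nodes) == 0 {
	op = pbpeering.ReplicationMessage_Response_DELETE
}
_ = op // placed into ReplicationMessage_Response.Operation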
func (s *Service) processResponse(peerName string, partition string, resp *pbpeering.ReplicationMessage_Response) (*pbpeering.ReplicationMessage, error) {
if resp.ResourceURL != pbpeering.TypeURLService {
func makeCARootsResponse(
logger hclog.Logger,
update cache.UpdateEvent,
) *pbpeering.ReplicationMessage {
any, _, err := marshalToProtoAny[*pbpeering.PeeringTrustBundle](update.Result)
if err != nil {
// Log the error and skip this response to avoid locking up peering due to a bad update event.
logger.Error("failed to marshal", "error", err)
return nil
}
resp := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Response_{
Response: &pbpeering.ReplicationMessage_Response{
ResourceURL: pbpeering.TypeURLRoots,
// TODO(peering): Nonce management
Nonce: "",
ResourceID: "roots",
Operation: pbpeering.ReplicationMessage_Response_UPSERT,
Resource: any,
},
},
}
return resp
}
// marshalToProtoAny takes any input and returns:
// the protobuf.Any type, the asserted T type, and any errors
// during marshalling or type assertion.
// `in` MUST be of type T or it returns an error.
func marshalToProtoAny[T proto.Message](in any) (*anypb.Any, T, error) {
typ, ok := in.(T)
if !ok {
var outType T
return nil, typ, fmt.Errorf("input type is not %T: %T", outType, in)
}
any, err := ptypes.MarshalAny(typ)
if err != nil {
return nil, typ, err
}
return any, typ, nil
}
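
For illustration, a hypothetical call to the new generic helper; the input arrives as `any` (for example from cache.UpdateEvent.Result) and must already hold the asserted type:

// Hypothetical usage; sketchInput stands in for update.Result.
var sketchInput any = &pbpeering.PeeringTrustBundle{TrustDomain: "example.consul"}
anyPB, bundle, err := marshalToProtoAny[*pbpeering.PeeringTrustBundle](sketchInput)
if err != nil {
	// Either the type assertion failed (input was not a *PeeringTrustBundle)
	// or marshalling into the protobuf Any failed.
}
_ = anyPB  // goes into ReplicationMessage_Response.Resource
_ = bundle // the concretely typed value, useful for logging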
func (s *Service) processResponse(
peerName string,
partition string,
resp *pbpeering.ReplicationMessage_Response,
) (*pbpeering.ReplicationMessage, error) {
if !pbpeering.KnownTypeURL(resp.ResourceURL) {
err := fmt.Errorf("received response for unknown resource type %q", resp.ResourceURL)
return makeReply(
resp.ResourceURL,
@ -186,6 +215,15 @@ func (s *Service) handleUpsert(
}
return s.handleUpsertService(peerName, partition, sn, csn)
case pbpeering.TypeURLRoots:
roots := &pbpeering.PeeringTrustBundle{}
if err := ptypes.UnmarshalAny(resource, roots); err != nil {
return fmt.Errorf("failed to unmarshal resource: %w", err)
}
return s.handleUpsertRoots(peerName, partition, roots)
default:
return fmt.Errorf("unexpected resourceURL: %s", resourceURL)
}
@ -249,6 +287,21 @@ func (s *Service) handleUpsertService(
return nil
}
func (s *Service) handleUpsertRoots(
peerName string,
partition string,
trustBundle *pbpeering.PeeringTrustBundle,
) error {
// We override the partition and peer name so that the trust bundle gets stored
// in the importing partition with a reference to the peer it was imported from.
trustBundle.Partition = partition
trustBundle.PeerName = peerName
req := &pbpeering.PeeringTrustBundleWriteRequest{
PeeringTrustBundle: trustBundle,
}
return s.Backend.Apply().PeeringTrustBundleWrite(req)
}
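
Concretely, any Partition or PeerName the sender set on the bundle is discarded on receipt; a bundle arriving on the stream for peer "my-peer" in the default partition is stored as follows (hypothetical values):

// Hypothetical: remoteRootPEM is a PEM-encoded root from the peer cluster.
incoming := &pbpeering.PeeringTrustBundle{RootPEMs: []string{remoteRootPEM}}
_ = s.handleUpsertRoots("my-peer", "default", incoming)
// After the call the stored bundle has:
//   incoming.Partition == "default" (the importing partition)
//   incoming.PeerName  == "my-peer" (the peer it was imported from)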
func (s *Service) handleDelete(
peerName string,
partition string,

View File

@ -118,7 +118,7 @@ type Store interface {
ExportedServicesForPeer(ws memdb.WatchSet, peerID string) (uint64, *structs.ExportedServiceList, error)
PeeringsForService(ws memdb.WatchSet, serviceName string, entMeta acl.EnterpriseMeta) (uint64, []*pbpeering.Peering, error)
ServiceDump(ws memdb.WatchSet, kind structs.ServiceKind, useKind bool, entMeta *acl.EnterpriseMeta, peerName string) (uint64, structs.CheckServiceNodes, error)
CAConfig(memdb.WatchSet) (uint64, *structs.CAConfiguration, error)
CAConfig(ws memdb.WatchSet) (uint64, *structs.CAConfiguration, error)
AbandonCh() <-chan struct{}
}
@ -127,6 +127,7 @@ type Apply interface {
PeeringWrite(req *pbpeering.PeeringWriteRequest) error
PeeringDelete(req *pbpeering.PeeringDeleteRequest) error
PeeringTerminateByID(req *pbpeering.PeeringTerminateByIDRequest) error
PeeringTrustBundleWrite(req *pbpeering.PeeringTrustBundleWriteRequest) error
CatalogRegister(req *structs.RegisterRequest) error
}
@ -469,7 +470,7 @@ func (s *Service) StreamResources(stream pbpeering.PeeringService_StreamResource
if req.Nonce != "" {
return grpcstatus.Error(codes.InvalidArgument, "initial subscription request must not contain a nonce")
}
if req.ResourceURL != pbpeering.TypeURLService {
if !pbpeering.KnownTypeURL(req.ResourceURL) {
return grpcstatus.Error(codes.InvalidArgument, fmt.Sprintf("subscription request to unknown resource URL: %s", req.ResourceURL))
}
@ -680,18 +681,30 @@ func (s *Service) HandleStream(req HandleStreamRequest) error {
}
case update := <-subCh:
var resp *pbpeering.ReplicationMessage
switch {
case strings.HasPrefix(update.CorrelationID, subExportedService),
strings.HasPrefix(update.CorrelationID, subExportedProxyService):
if err := pushServiceResponse(logger, req.Stream, status, update); err != nil {
return fmt.Errorf("failed to push data for %q: %w", update.CorrelationID, err)
}
resp = makeServiceResponse(logger, update)
case strings.HasPrefix(update.CorrelationID, subMeshGateway):
//TODO(Peering): figure out how to sync this separately
// TODO(Peering): figure out how to sync this separately
case update.CorrelationID == subCARoot:
resp = makeCARootsResponse(logger, update)
default:
logger.Warn("unrecognized update type from subscription manager: " + update.CorrelationID)
continue
}
if resp == nil {
continue
}
logTraceSend(logger, resp)
if err := req.Stream.Send(resp); err != nil {
status.trackSendError(err.Error())
return fmt.Errorf("failed to push data for %q: %w", update.CorrelationID, err)
}
}
}
}

View File

@ -11,15 +11,14 @@ import (
"testing"
"time"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-uuid"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
gogrpc "google.golang.org/grpc"
"google.golang.org/protobuf/types/known/anypb"
"github.com/hashicorp/consul/agent/consul/state"
grpc "github.com/hashicorp/consul/agent/grpc/private"
"github.com/hashicorp/consul/agent/grpc/private/resolver"
"github.com/hashicorp/consul/api"
@ -725,6 +724,12 @@ func Test_StreamHandler_UpsertServices(t *testing.T) {
_, err = client.Recv()
require.NoError(t, err)
// Receive first roots replication message
receiveRoots, err := client.Recv()
require.NoError(t, err)
require.NotNil(t, receiveRoots.GetResponse())
require.Equal(t, pbpeering.TypeURLRoots, receiveRoots.GetResponse().ResourceURL)
remoteEntMeta := structs.DefaultEnterpriseMetaInPartition("remote-partition")
localEntMeta := acl.DefaultEnterpriseMeta()
localPeerName := "my-peer"
@ -752,7 +757,7 @@ func Test_StreamHandler_UpsertServices(t *testing.T) {
pbCSN.Nodes = append(pbCSN.Nodes, pbservice.NewCheckServiceNodeFromStructs(&csn))
}
any, err := ptypes.MarshalAny(pbCSN)
any, err := anypb.New(pbCSN)
require.NoError(t, err)
tc.msg.Resource = any

View File

@ -112,7 +112,7 @@ func TestStreamResources_Server_LeaderBecomesFollower(t *testing.T) {
peerID := p.ID
// Set the initial roots and CA configuration.
_ = writeInitialRootsAndCA(t, store)
_, _ = writeInitialRootsAndCA(t, store)
// Receive a subscription from a peer
sub := &pbpeering.ReplicationMessage{
@ -130,6 +130,11 @@ func TestStreamResources_Server_LeaderBecomesFollower(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, msg)
receiveRoots, err := client.Recv()
require.NoError(t, err)
require.NotNil(t, receiveRoots.GetResponse())
require.Equal(t, pbpeering.TypeURLRoots, receiveRoots.GetResponse().ResourceURL)
input2 := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
@ -280,19 +285,6 @@ func TestStreamResources_Server_Terminate(t *testing.T) {
}
srv.streams.timeNow = it.Now
client := NewMockClient(context.Background())
errCh := make(chan error, 1)
client.ErrCh = errCh
go func() {
// Pass errors from server handler into ErrCh so that they can be seen by the client on Recv().
// This matches gRPC's behavior when an error is returned by a server.
if err := srv.StreamResources(client.ReplicationStream); err != nil {
errCh <- err
}
}()
p := writeInitiatedPeering(t, store, 1, "my-peer")
var (
peerID = p.ID // for Send
@ -300,19 +292,17 @@ func TestStreamResources_Server_Terminate(t *testing.T) {
)
// Set the initial roots and CA configuration.
_ = writeInitialRootsAndCA(t, store)
_, _ = writeInitialRootsAndCA(t, store)
// Receive a subscription from a peer
sub := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
PeerID: peerID,
ResourceURL: pbpeering.TypeURLService,
},
},
}
err := client.Send(sub)
client := makeClient(t, srv, peerID, remotePeerID)
// TODO(peering): test fails if we don't drain the stream with this call because the
// server gets blocked sending the termination message. Figure out a way to let
// messages queue and filter replication messages.
receiveRoots, err := client.Recv()
require.NoError(t, err)
require.NotNil(t, receiveRoots.GetResponse())
require.Equal(t, pbpeering.TypeURLRoots, receiveRoots.GetResponse().ResourceURL)
testutil.RunStep(t, "new stream gets tracked", func(t *testing.T) {
retry.Run(t, func(r *retry.R) {
@ -322,20 +312,6 @@ func TestStreamResources_Server_Terminate(t *testing.T) {
})
})
// Receive subscription to my-peer-B's resources
receivedSub, err := client.Recv()
require.NoError(t, err)
expect := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
ResourceURL: pbpeering.TypeURLService,
PeerID: remotePeerID,
},
},
}
prototest.AssertDeepEqual(t, expect, receivedSub)
testutil.RunStep(t, "terminate the stream", func(t *testing.T) {
done := srv.ConnectedStreams()[peerID]
close(done)
@ -348,7 +324,7 @@ func TestStreamResources_Server_Terminate(t *testing.T) {
receivedTerm, err := client.Recv()
require.NoError(t, err)
expect = &pbpeering.ReplicationMessage{
expect := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Terminated_{
Terminated: &pbpeering.ReplicationMessage_Terminated{},
},
@ -375,12 +351,8 @@ func TestStreamResources_Server_StreamTracker(t *testing.T) {
}
srv.streams.timeNow = it.Now
client := NewMockClient(context.Background())
errCh := make(chan error, 1)
go func() {
errCh <- srv.StreamResources(client.ReplicationStream)
}()
// Set the initial roots and CA configuration.
_, rootA := writeInitialRootsAndCA(t, store)
p := writeInitiatedPeering(t, store, 1, "my-peer")
var (
@ -388,20 +360,7 @@ func TestStreamResources_Server_StreamTracker(t *testing.T) {
remotePeerID = p.PeerID // for Recv
)
// Set the initial roots and CA configuration.
_ = writeInitialRootsAndCA(t, store)
// Receive a subscription from a peer
sub := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
PeerID: peerID,
ResourceURL: pbpeering.TypeURLService,
},
},
}
err := client.Send(sub)
require.NoError(t, err)
client := makeClient(t, srv, peerID, remotePeerID)
testutil.RunStep(t, "new stream gets tracked", func(t *testing.T) {
retry.Run(t, func(r *retry.R) {
@ -411,22 +370,6 @@ func TestStreamResources_Server_StreamTracker(t *testing.T) {
})
})
testutil.RunStep(t, "client receives initial subscription", func(t *testing.T) {
ack, err := client.Recv()
require.NoError(t, err)
expectAck := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
ResourceURL: pbpeering.TypeURLService,
PeerID: remotePeerID,
Nonce: "",
},
},
}
prototest.AssertDeepEqual(t, expectAck, ack)
})
var sequence uint64
var lastSendSuccess time.Time
@ -516,6 +459,24 @@ func TestStreamResources_Server_StreamTracker(t *testing.T) {
require.NoError(t, err)
sequence++
expectRoots := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Response_{
Response: &pbpeering.ReplicationMessage_Response{
ResourceURL: pbpeering.TypeURLRoots,
ResourceID: "roots",
Resource: makeAnyPB(t, &pbpeering.PeeringTrustBundle{
TrustDomain: connect.TestTrustDomain,
RootPEMs: []string{rootA.RootCert},
}),
Operation: pbpeering.ReplicationMessage_Response_UPSERT,
},
},
}
roots, err := client.Recv()
require.NoError(t, err)
prototest.AssertDeepEqual(t, expectRoots, roots)
ack, err := client.Recv()
require.NoError(t, err)
@ -629,14 +590,6 @@ func TestStreamResources_Server_StreamTracker(t *testing.T) {
require.Equal(r, expect, status)
})
})
select {
case err := <-errCh:
// Client disconnect is not an error, but should make the handler return.
require.NoError(t, err)
case <-time.After(50 * time.Millisecond):
t.Fatalf("timed out waiting for handler to finish")
}
}
func TestStreamResources_Server_ServiceUpdates(t *testing.T) {
@ -652,13 +605,9 @@ func testStreamResources_Server_ServiceUpdates(t *testing.T, disableMeshGateways
// Create a peering
var lastIdx uint64 = 1
p := writeInitiatedPeering(t, store, lastIdx, "my-peering")
var (
peerID = p.ID // for Send
remotePeerID = p.PeerID // for Recv
)
// Set the initial roots and CA configuration.
_ = writeInitialRootsAndCA(t, store)
_, _ = writeInitialRootsAndCA(t, store)
srv := NewService(
testutil.Logger(t),
@ -670,44 +619,7 @@ func testStreamResources_Server_ServiceUpdates(t *testing.T, disableMeshGateways
store: store,
pub: publisher,
})
client := NewMockClient(context.Background())
errCh := make(chan error, 1)
client.ErrCh = errCh
go func() {
// Pass errors from server handler into ErrCh so that they can be seen by the client on Recv().
// This matches gRPC's behavior when an error is returned by a server.
if err := srv.StreamResources(client.ReplicationStream); err != nil {
errCh <- err
}
}()
// Issue a services subscription to server
init := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
PeerID: peerID,
ResourceURL: pbpeering.TypeURLService,
},
},
}
require.NoError(t, client.Send(init))
// Receive a services subscription from server
receivedSub, err := client.Recv()
require.NoError(t, err)
expect := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
ResourceURL: pbpeering.TypeURLService,
PeerID: remotePeerID,
},
},
}
prototest.AssertDeepEqual(t, expect, receivedSub)
client := makeClient(t, srv, p.ID, p.PeerID)
// Register a service that is not yet exported
mysql := &structs.CheckServiceNode{
@ -751,6 +663,10 @@ func testStreamResources_Server_ServiceUpdates(t *testing.T, disableMeshGateways
require.NoError(t, store.EnsureConfigEntry(lastIdx, entry))
expectReplEvents(t, client,
func(t *testing.T, msg *pbpeering.ReplicationMessage) {
require.Equal(t, pbpeering.TypeURLRoots, msg.GetResponse().ResourceURL)
// Roots tested in TestStreamResources_Server_CARootUpdates
},
func(t *testing.T, msg *pbpeering.ReplicationMessage) {
require.Equal(t, pbpeering.TypeURLService, msg.GetResponse().ResourceURL)
require.Equal(t, mongoSN, msg.GetResponse().ResourceID)
@ -820,7 +736,7 @@ func testStreamResources_Server_ServiceUpdates(t *testing.T, disableMeshGateways
},
}
lastIdx++
err = store.EnsureConfigEntry(lastIdx, entry)
err := store.EnsureConfigEntry(lastIdx, entry)
require.NoError(t, err)
retry.Run(t, func(r *retry.R) {
@ -834,7 +750,7 @@ func testStreamResources_Server_ServiceUpdates(t *testing.T, disableMeshGateways
testutil.RunStep(t, "deleting the config entry leads to a DELETE event for mongo", func(t *testing.T) {
lastIdx++
err = store.DeleteConfigEntry(lastIdx, structs.ExportedServices, "default", nil)
err := store.DeleteConfigEntry(lastIdx, structs.ExportedServices, "default", nil)
require.NoError(t, err)
retry.Run(t, func(r *retry.R) {
@ -847,6 +763,128 @@ func testStreamResources_Server_ServiceUpdates(t *testing.T, disableMeshGateways
})
}
func TestStreamResources_Server_CARootUpdates(t *testing.T) {
publisher := stream.NewEventPublisher(10 * time.Second)
store := newStateStore(t, publisher)
// Create a peering
var lastIdx uint64 = 1
p := writeInitiatedPeering(t, store, lastIdx, "my-peering")
srv := NewService(
testutil.Logger(t),
Config{
Datacenter: "dc1",
ConnectEnabled: true,
}, &testStreamBackend{
store: store,
pub: publisher,
})
// Set the initial roots and CA configuration.
clusterID, rootA := writeInitialRootsAndCA(t, store)
client := makeClient(t, srv, p.ID, p.PeerID)
testutil.RunStep(t, "initial CA Roots replication", func(t *testing.T) {
expectReplEvents(t, client,
func(t *testing.T, msg *pbpeering.ReplicationMessage) {
require.Equal(t, pbpeering.TypeURLRoots, msg.GetResponse().ResourceURL)
require.Equal(t, "roots", msg.GetResponse().ResourceID)
require.Equal(t, pbpeering.ReplicationMessage_Response_UPSERT, msg.GetResponse().Operation)
var trustBundle pbpeering.PeeringTrustBundle
require.NoError(t, ptypes.UnmarshalAny(msg.GetResponse().Resource, &trustBundle))
require.ElementsMatch(t, []string{rootA.RootCert}, trustBundle.RootPEMs)
expect := connect.SpiffeIDSigningForCluster(clusterID).Host()
require.Equal(t, expect, trustBundle.TrustDomain)
},
)
})
testutil.RunStep(t, "CA root rotation sends upsert event", func(t *testing.T) {
// get max index for CAS operation
cidx, _, err := store.CARoots(nil)
require.NoError(t, err)
rootB := connect.TestCA(t, nil)
rootC := connect.TestCA(t, nil)
rootC.Active = false // there can only be one active root
lastIdx++
set, err := store.CARootSetCAS(lastIdx, cidx, []*structs.CARoot{rootB, rootC})
require.True(t, set)
require.NoError(t, err)
expectReplEvents(t, client,
func(t *testing.T, msg *pbpeering.ReplicationMessage) {
require.Equal(t, pbpeering.TypeURLRoots, msg.GetResponse().ResourceURL)
require.Equal(t, "roots", msg.GetResponse().ResourceID)
require.Equal(t, pbpeering.ReplicationMessage_Response_UPSERT, msg.GetResponse().Operation)
var trustBundle pbpeering.PeeringTrustBundle
require.NoError(t, ptypes.UnmarshalAny(msg.GetResponse().Resource, &trustBundle))
require.ElementsMatch(t, []string{rootB.RootCert, rootC.RootCert}, trustBundle.RootPEMs)
expect := connect.SpiffeIDSigningForCluster(clusterID).Host()
require.Equal(t, expect, trustBundle.TrustDomain)
},
)
})
}
// makeClient sets up a *MockClient with the initial subscription
// message handshake.
func makeClient(
t *testing.T,
srv pbpeering.PeeringServiceServer,
peerID string,
remotePeerID string,
) *MockClient {
t.Helper()
client := NewMockClient(context.Background())
errCh := make(chan error, 1)
client.ErrCh = errCh
go func() {
// Pass errors from server handler into ErrCh so that they can be seen by the client on Recv().
// This matches gRPC's behavior when an error is returned by a server.
if err := srv.StreamResources(client.ReplicationStream); err != nil {
errCh <- err
}
}()
// Issue a services subscription to server
init := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
PeerID: peerID,
ResourceURL: pbpeering.TypeURLService,
},
},
}
require.NoError(t, client.Send(init))
// Receive a services subscription from server
receivedSub, err := client.Recv()
require.NoError(t, err)
expect := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
ResourceURL: pbpeering.TypeURLService,
PeerID: remotePeerID,
},
},
}
prototest.AssertDeepEqual(t, expect, receivedSub)
return client
}
type testStreamBackend struct {
pub state.EventPublisher
store *state.Store
@ -1058,7 +1096,7 @@ func writeInitiatedPeering(t *testing.T, store *state.Store, idx uint64, peerNam
return p
}
func writeInitialRootsAndCA(t *testing.T, store *state.Store) string {
func writeInitialRootsAndCA(t *testing.T, store *state.Store) (string, *structs.CARoot) {
const clusterID = connect.TestClusterID
rootA := connect.TestCA(t, nil)
@ -1068,7 +1106,7 @@ func writeInitialRootsAndCA(t *testing.T, store *state.Store) string {
err = store.CASetConfig(0, &structs.CAConfiguration{ClusterID: clusterID})
require.NoError(t, err)
return clusterID
return clusterID, rootA
}
func makeAnyPB(t *testing.T, pb proto.Message) *any.Any {

View File

@ -2,6 +2,7 @@ package peering
import (
"context"
"errors"
"fmt"
"strings"
@ -11,10 +12,13 @@ import (
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/submatview"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/proto/pbcommon"
"github.com/hashicorp/consul/proto/pbpeering"
"github.com/hashicorp/consul/proto/pbservice"
)
@ -75,6 +79,11 @@ func (m *subscriptionManager) subscribe(ctx context.Context, peerID, peerName, p
go m.notifyMeshGatewaysForPartition(ctx, state, state.partition)
}
// If connect is enabled, watch for updates to CA roots.
if m.config.ConnectEnabled {
go m.notifyRootCAUpdates(ctx, state.updateCh)
}
// This goroutine is the only one allowed to manipulate protected
// subscriptionManager fields.
go m.handleEvents(ctx, state, updateCh)
@ -289,6 +298,18 @@ func (m *subscriptionManager) handleEvent(ctx context.Context, state *subscripti
// TODO(peering): should we ship this down verbatim to the consumer?
state.sendPendingEvents(ctx, m.logger, pending)
case u.CorrelationID == subCARoot:
roots, ok := u.Result.(*pbpeering.PeeringTrustBundle)
if !ok {
return fmt.Errorf("invalid type for response: %T", u.Result)
}
pending := &pendingPayload{}
if err := pending.Add(caRootsPayloadID, u.CorrelationID, roots); err != nil {
return err
}
state.sendPendingEvents(ctx, m.logger, pending)
default:
return fmt.Errorf("unknown correlation ID: %s", u.CorrelationID)
}
@ -322,6 +343,106 @@ func filterConnectReferences(orig *pbservice.IndexedCheckServiceNodes) {
orig.Nodes = newNodes
}
func (m *subscriptionManager) notifyRootCAUpdates(ctx context.Context, updateCh chan<- cache.UpdateEvent) {
var idx uint64
// TODO(peering): retry logic; fail past a threshold
for {
var err error
// Typically, this function will block inside `m.subscribeCARoots` and only return on error.
// Errors are logged and the watch is retried.
idx, err = m.subscribeCARoots(ctx, idx, updateCh)
if errors.Is(err, stream.ErrSubForceClosed) {
m.logger.Trace("subscription force-closed due to an ACL change or snapshot restore, will attempt resume")
} else if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
m.logger.Warn("failed to subscribe to CA roots, will attempt resume", "error", err.Error())
} else {
m.logger.Trace(err.Error())
}
select {
case <-ctx.Done():
return
default:
}
}
}
// subscribeCARoots subscribes to state.EventTopicCARoots for changes to CA roots.
// Upon receiving an event it will send the payload in updateCh.
func (m *subscriptionManager) subscribeCARoots(ctx context.Context, idx uint64, updateCh chan<- cache.UpdateEvent) (uint64, error) {
// following code adapted from connectca/watch_roots.go
sub, err := m.backend.Subscribe(&stream.SubscribeRequest{
Topic: state.EventTopicCARoots,
Subject: stream.SubjectNone,
Token: "", // using anonymous token for now
Index: idx,
})
if err != nil {
return 0, fmt.Errorf("failed to subscribe to CA Roots events: %w", err)
}
defer sub.Unsubscribe()
for {
event, err := sub.Next(ctx)
switch {
case errors.Is(err, stream.ErrSubForceClosed):
// If the subscription was closed because the state store was abandoned (e.g.
// following a snapshot restore) reset idx to ensure we don't skip over the
// new store's events.
select {
case <-m.backend.Store().AbandonCh():
idx = 0
default:
}
return idx, err
case errors.Is(err, context.Canceled):
return 0, err
case errors.Is(err, context.DeadlineExceeded):
return 0, err
case err != nil:
return idx, fmt.Errorf("failed to read next event: %w", err)
}
// Note: this check isn't strictly necessary because the event publishing
// machinery will ensure the index increases monotonically, but it can be
// tricky to faithfully reproduce this in tests (e.g. the EventPublisher
// garbage collects topic buffers and snapshots aggressively when streams
// disconnect) so this avoids a bunch of confusing setup code.
if event.Index <= idx {
continue
}
idx = event.Index
// We do not send framing events (e.g. EndOfSnapshot, NewSnapshotToFollow)
// because we send a full list of roots on every event, rather than expecting
// clients to maintain a state-machine in the way they do for service health.
if event.IsFramingEvent() {
continue
}
payload, ok := event.Payload.(state.EventPayloadCARoots)
if !ok {
return 0, fmt.Errorf("unexpected event payload type: %T", payload)
}
var rootPems []string
for _, root := range payload.CARoots {
rootPems = append(rootPems, root.RootCert)
}
updateCh <- cache.UpdateEvent{
CorrelationID: subCARoot,
Result: &pbpeering.PeeringTrustBundle{
TrustDomain: m.trustDomain,
RootPEMs: rootPems,
},
}
}
}
const subCARoot = "roots"
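
Because every event carries the complete root list (see the framing-events comment above), the consumer on the other end of updateCh can replace its view wholesale rather than maintaining an incremental state machine. A sketch of the receive side, with currentRoots as assumed caller state:

// Sketch: consuming one CA roots update; updateCh here is the receive side of
// the channel passed into subscribeCARoots.
update := <-updateCh
if bundle, ok := update.Result.(*pbpeering.PeeringTrustBundle); ok && update.CorrelationID == subCARoot {
	currentRoots = bundle.RootPEMs // the full, current set; no diffing required
}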
func (m *subscriptionManager) syncNormalServices(
ctx context.Context,
state *subscriptionState,

View File

@ -572,6 +572,65 @@ func testSubscriptionManager_InitialSnapshot(t *testing.T, disableMeshGateways b
})
}
func TestSubscriptionManager_CARoots(t *testing.T) {
backend := newTestSubscriptionBackend(t)
// Setup CA-related configs in the store
clusterID, rootA := writeInitialRootsAndCA(t, backend.store)
trustDomain := connect.SpiffeIDSigningForCluster(clusterID).Host()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create a peering
_, id := backend.ensurePeering(t, "my-peering")
partition := acl.DefaultEnterpriseMeta().PartitionOrEmpty()
mgr := newSubscriptionManager(ctx, testutil.Logger(t), Config{
Datacenter: "dc1",
ConnectEnabled: true,
}, connect.TestTrustDomain, backend)
subCh := mgr.subscribe(ctx, id, "my-peering", partition)
testutil.RunStep(t, "initial events contain trust bundle", func(t *testing.T) {
// events are ordered so we can expect a deterministic list
expectEvents(t, subCh,
func(t *testing.T, got cache.UpdateEvent) {
// mesh-gateway assertions are done in other tests
require.Equal(t, subMeshGateway+partition, got.CorrelationID)
},
func(t *testing.T, got cache.UpdateEvent) {
require.Equal(t, subCARoot, got.CorrelationID)
roots, ok := got.Result.(*pbpeering.PeeringTrustBundle)
require.True(t, ok)
require.ElementsMatch(t, []string{rootA.RootCert}, roots.RootPEMs)
require.Equal(t, trustDomain, roots.TrustDomain)
},
)
})
testutil.RunStep(t, "updating CA roots triggers event", func(t *testing.T) {
rootB := connect.TestCA(t, nil)
rootC := connect.TestCA(t, nil)
rootC.Active = false // there can only be one active root
backend.ensureCARoots(t, rootB, rootC)
expectEvents(t, subCh,
func(t *testing.T, got cache.UpdateEvent) {
require.Equal(t, subCARoot, got.CorrelationID)
roots, ok := got.Result.(*pbpeering.PeeringTrustBundle)
require.True(t, ok)
require.ElementsMatch(t, []string{rootB.RootCert, rootC.RootCert}, roots.RootPEMs)
require.Equal(t, trustDomain, roots.TrustDomain)
},
)
})
}
type testSubscriptionBackend struct {
state.EventPublisher
store *state.Store
@ -643,6 +702,24 @@ func (b *testSubscriptionBackend) deleteService(t *testing.T, nodeName, serviceI
return b.lastIdx
}
func (b *testSubscriptionBackend) ensureCAConfig(t *testing.T, config *structs.CAConfiguration) uint64 {
b.lastIdx++
require.NoError(t, b.store.CASetConfig(b.lastIdx, config))
return b.lastIdx
}
func (b *testSubscriptionBackend) ensureCARoots(t *testing.T, roots ...*structs.CARoot) uint64 {
// Get the max index for Check-and-Set operation
cidx, _, err := b.store.CARoots(nil)
require.NoError(t, err)
b.lastIdx++
set, err := b.store.CARootSetCAS(b.lastIdx, cidx, roots)
require.True(t, set)
require.NoError(t, err)
return b.lastIdx
}
func setupTestPeering(t *testing.T, store *state.Store, name string, index uint64) string {
err := store.PeeringWrite(index, &pbpeering.Peering{
Name: name,
@ -666,6 +743,7 @@ func newStateStore(t *testing.T, publisher *stream.EventPublisher) *state.Store
store := state.NewStateStoreWithEventPublisher(gc, publisher)
require.NoError(t, publisher.RegisterHandler(state.EventTopicServiceHealth, store.ServiceHealthSnapshot))
require.NoError(t, publisher.RegisterHandler(state.EventTopicServiceHealthConnect, store.ServiceHealthSnapshot))
require.NoError(t, publisher.RegisterHandler(state.EventTopicCARoots, store.CARootsSnapshot))
go publisher.Run(ctx)
return store

View File

@ -92,6 +92,9 @@ func (s *subscriptionState) cleanupEventVersions(logger hclog.Logger) {
case id == meshGatewayPayloadID:
keep = true
case id == caRootsPayloadID:
keep = true
case strings.HasPrefix(id, servicePayloadIDPrefix):
name := strings.TrimPrefix(id, servicePayloadIDPrefix)
sn := structs.ServiceNameFromString(name)
@ -136,6 +139,7 @@ type pendingEvent struct {
}
const (
caRootsPayloadID = "roots"
meshGatewayPayloadID = "mesh-gateway"
servicePayloadIDPrefix = "service:"
proxyServicePayloadIDPrefix = "proxy-service:" // TODO(peering): remove

View File

@ -2,4 +2,9 @@ package pbpeering
const (
TypeURLService = "type.googleapis.com/consul.api.Service"
TypeURLRoots = "type.googleapis.com/consul.api.CARoots"
)
func KnownTypeURL(s string) bool {
return s == TypeURLService || s == TypeURLRoots
}