diff --git a/agent/consul/peering_backend.go b/agent/consul/peering_backend.go index 9ec6639c5f..84e1676bdb 100644 --- a/agent/consul/peering_backend.go +++ b/agent/consul/peering_backend.go @@ -129,6 +129,11 @@ func (a *peeringApply) PeeringTerminateByID(req *pbpeering.PeeringTerminateByIDR return err } +func (a *peeringApply) PeeringTrustBundleWrite(req *pbpeering.PeeringTrustBundleWriteRequest) error { + _, err := a.srv.raftApplyProtobuf(structs.PeeringTrustBundleWriteType, req) + return err +} + func (a *peeringApply) CatalogRegister(req *structs.RegisterRequest) error { _, err := a.srv.leaderRaftApply("Catalog.Register", structs.RegisterRequestType, req) return err diff --git a/agent/rpc/peering/replication.go b/agent/rpc/peering/replication.go index 14513b01bb..1f546bdd32 100644 --- a/agent/rpc/peering/replication.go +++ b/agent/rpc/peering/replication.go @@ -5,6 +5,7 @@ import ( "fmt" "strings" + "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "github.com/hashicorp/go-hclog" "google.golang.org/genproto/googleapis/rpc/code" @@ -29,20 +30,17 @@ import ( request a resync operation. */ -// pushService response handles sending exported service instance updates to the peer cluster. +// makeServiceResponse handles preparing exported service instance updates to the peer cluster. // Each cache.UpdateEvent will contain all instances for a service name. // If there are no instances in the event, we consider that to be a de-registration. -func pushServiceResponse( +func makeServiceResponse( logger hclog.Logger, - stream BidirectionalStream, - status *lockableStreamStatus, update cache.UpdateEvent, -) error { - csn, ok := update.Result.(*pbservice.IndexedCheckServiceNodes) - if !ok { - logger.Error(fmt.Sprintf("invalid type for response: %T, expected *pbservice.IndexedCheckServiceNodes", update.Result)) - - // Skip this update to avoid locking up peering due to a bad service update. +) *pbpeering.ReplicationMessage { + any, csn, err := marshalToProtoAny[*pbservice.IndexedCheckServiceNodes](update.Result) + if err != nil { + // Log the error and skip this response to avoid locking up peering due to a bad update event. + logger.Error("failed to marshal", "error", err) return nil } @@ -72,21 +70,10 @@ func pushServiceResponse( }, }, } - logTraceSend(logger, resp) - if err := stream.Send(resp); err != nil { - status.trackSendError(err.Error()) - return fmt.Errorf("failed to send to stream: %v", err) - } - return nil + return resp } // If there are nodes in the response, we push them as an UPSERT operation. - any, err := ptypes.MarshalAny(csn) - if err != nil { - // Log the error and skip this response to avoid locking up peering due to a bad update event. 
- logger.Error("failed to marshal service endpoints", "error", err) - return nil - } resp := &pbpeering.ReplicationMessage{ Payload: &pbpeering.ReplicationMessage_Response_{ Response: &pbpeering.ReplicationMessage_Response{ @@ -99,16 +86,58 @@ func pushServiceResponse( }, }, } - logTraceSend(logger, resp) - if err := stream.Send(resp); err != nil { - status.trackSendError(err.Error()) - return fmt.Errorf("failed to send to stream: %v", err) - } - return nil + return resp } -func (s *Service) processResponse(peerName string, partition string, resp *pbpeering.ReplicationMessage_Response) (*pbpeering.ReplicationMessage, error) { - if resp.ResourceURL != pbpeering.TypeURLService { +func makeCARootsResponse( + logger hclog.Logger, + update cache.UpdateEvent, +) *pbpeering.ReplicationMessage { + any, _, err := marshalToProtoAny[*pbpeering.PeeringTrustBundle](update.Result) + if err != nil { + // Log the error and skip this response to avoid locking up peering due to a bad update event. + logger.Error("failed to marshal", "error", err) + return nil + } + + resp := &pbpeering.ReplicationMessage{ + Payload: &pbpeering.ReplicationMessage_Response_{ + Response: &pbpeering.ReplicationMessage_Response{ + ResourceURL: pbpeering.TypeURLRoots, + // TODO(peering): Nonce management + Nonce: "", + ResourceID: "roots", + Operation: pbpeering.ReplicationMessage_Response_UPSERT, + Resource: any, + }, + }, + } + return resp +} + +// marshalToProtoAny takes any input and returns: +// the protobuf.Any type, the asserted T type, and any errors +// during marshalling or type assertion. +// `in` MUST be of type T or it returns an error. +func marshalToProtoAny[T proto.Message](in any) (*anypb.Any, T, error) { + typ, ok := in.(T) + if !ok { + var outType T + return nil, typ, fmt.Errorf("input type is not %T: %T", outType, in) + } + any, err := ptypes.MarshalAny(typ) + if err != nil { + return nil, typ, err + } + return any, typ, nil +} + +func (s *Service) processResponse( + peerName string, + partition string, + resp *pbpeering.ReplicationMessage_Response, +) (*pbpeering.ReplicationMessage, error) { + if !pbpeering.KnownTypeURL(resp.ResourceURL) { err := fmt.Errorf("received response for unknown resource type %q", resp.ResourceURL) return makeReply( resp.ResourceURL, @@ -186,6 +215,15 @@ func (s *Service) handleUpsert( } return s.handleUpsertService(peerName, partition, sn, csn) + + case pbpeering.TypeURLRoots: + roots := &pbpeering.PeeringTrustBundle{} + if err := ptypes.UnmarshalAny(resource, roots); err != nil { + return fmt.Errorf("failed to unmarshal resource: %w", err) + } + + return s.handleUpsertRoots(peerName, partition, roots) + default: return fmt.Errorf("unexpected resourceURL: %s", resourceURL) } @@ -249,6 +287,21 @@ func (s *Service) handleUpsertService( return nil } +func (s *Service) handleUpsertRoots( + peerName string, + partition string, + trustBundle *pbpeering.PeeringTrustBundle, +) error { + // We override the partition and peer name so that the trust bundle gets stored + // in the importing partition with a reference to the peer it was imported from. 
+ trustBundle.Partition = partition + trustBundle.PeerName = peerName + req := &pbpeering.PeeringTrustBundleWriteRequest{ + PeeringTrustBundle: trustBundle, + } + return s.Backend.Apply().PeeringTrustBundleWrite(req) +} + func (s *Service) handleDelete( peerName string, partition string, diff --git a/agent/rpc/peering/service.go b/agent/rpc/peering/service.go index 81243eae3e..c50c825eb7 100644 --- a/agent/rpc/peering/service.go +++ b/agent/rpc/peering/service.go @@ -118,7 +118,7 @@ type Store interface { ExportedServicesForPeer(ws memdb.WatchSet, peerID string) (uint64, *structs.ExportedServiceList, error) PeeringsForService(ws memdb.WatchSet, serviceName string, entMeta acl.EnterpriseMeta) (uint64, []*pbpeering.Peering, error) ServiceDump(ws memdb.WatchSet, kind structs.ServiceKind, useKind bool, entMeta *acl.EnterpriseMeta, peerName string) (uint64, structs.CheckServiceNodes, error) - CAConfig(memdb.WatchSet) (uint64, *structs.CAConfiguration, error) + CAConfig(ws memdb.WatchSet) (uint64, *structs.CAConfiguration, error) AbandonCh() <-chan struct{} } @@ -127,6 +127,7 @@ type Apply interface { PeeringWrite(req *pbpeering.PeeringWriteRequest) error PeeringDelete(req *pbpeering.PeeringDeleteRequest) error PeeringTerminateByID(req *pbpeering.PeeringTerminateByIDRequest) error + PeeringTrustBundleWrite(req *pbpeering.PeeringTrustBundleWriteRequest) error CatalogRegister(req *structs.RegisterRequest) error } @@ -469,7 +470,7 @@ func (s *Service) StreamResources(stream pbpeering.PeeringService_StreamResource if req.Nonce != "" { return grpcstatus.Error(codes.InvalidArgument, "initial subscription request must not contain a nonce") } - if req.ResourceURL != pbpeering.TypeURLService { + if !pbpeering.KnownTypeURL(req.ResourceURL) { return grpcstatus.Error(codes.InvalidArgument, fmt.Sprintf("subscription request to unknown resource URL: %s", req.ResourceURL)) } @@ -680,18 +681,30 @@ func (s *Service) HandleStream(req HandleStreamRequest) error { } case update := <-subCh: + var resp *pbpeering.ReplicationMessage switch { case strings.HasPrefix(update.CorrelationID, subExportedService), strings.HasPrefix(update.CorrelationID, subExportedProxyService): - if err := pushServiceResponse(logger, req.Stream, status, update); err != nil { - return fmt.Errorf("failed to push data for %q: %w", update.CorrelationID, err) - } + resp = makeServiceResponse(logger, update) + case strings.HasPrefix(update.CorrelationID, subMeshGateway): - //TODO(Peering): figure out how to sync this separately + // TODO(Peering): figure out how to sync this separately + + case update.CorrelationID == subCARoot: + resp = makeCARootsResponse(logger, update) + default: logger.Warn("unrecognized update type from subscription manager: " + update.CorrelationID) continue } + if resp == nil { + continue + } + logTraceSend(logger, resp) + if err := req.Stream.Send(resp); err != nil { + status.trackSendError(err.Error()) + return fmt.Errorf("failed to push data for %q: %w", update.CorrelationID, err) + } } } } diff --git a/agent/rpc/peering/service_test.go b/agent/rpc/peering/service_test.go index 587d1ac407..77bcd5c7af 100644 --- a/agent/rpc/peering/service_test.go +++ b/agent/rpc/peering/service_test.go @@ -11,15 +11,14 @@ import ( "testing" "time" - "github.com/hashicorp/consul/agent/consul/state" - - "github.com/golang/protobuf/ptypes" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-uuid" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" gogrpc "google.golang.org/grpc" + 
"google.golang.org/protobuf/types/known/anypb" + "github.com/hashicorp/consul/agent/consul/state" grpc "github.com/hashicorp/consul/agent/grpc/private" "github.com/hashicorp/consul/agent/grpc/private/resolver" "github.com/hashicorp/consul/api" @@ -725,6 +724,12 @@ func Test_StreamHandler_UpsertServices(t *testing.T) { _, err = client.Recv() require.NoError(t, err) + // Receive first roots replication message + receiveRoots, err := client.Recv() + require.NoError(t, err) + require.NotNil(t, receiveRoots.GetResponse()) + require.Equal(t, pbpeering.TypeURLRoots, receiveRoots.GetResponse().ResourceURL) + remoteEntMeta := structs.DefaultEnterpriseMetaInPartition("remote-partition") localEntMeta := acl.DefaultEnterpriseMeta() localPeerName := "my-peer" @@ -752,7 +757,7 @@ func Test_StreamHandler_UpsertServices(t *testing.T) { pbCSN.Nodes = append(pbCSN.Nodes, pbservice.NewCheckServiceNodeFromStructs(&csn)) } - any, err := ptypes.MarshalAny(pbCSN) + any, err := anypb.New(pbCSN) require.NoError(t, err) tc.msg.Resource = any diff --git a/agent/rpc/peering/stream_test.go b/agent/rpc/peering/stream_test.go index 6f58a9dbf4..27016a1582 100644 --- a/agent/rpc/peering/stream_test.go +++ b/agent/rpc/peering/stream_test.go @@ -112,7 +112,7 @@ func TestStreamResources_Server_LeaderBecomesFollower(t *testing.T) { peerID := p.ID // Set the initial roots and CA configuration. - _ = writeInitialRootsAndCA(t, store) + _, _ = writeInitialRootsAndCA(t, store) // Receive a subscription from a peer sub := &pbpeering.ReplicationMessage{ @@ -130,6 +130,11 @@ func TestStreamResources_Server_LeaderBecomesFollower(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, msg) + receiveRoots, err := client.Recv() + require.NoError(t, err) + require.NotNil(t, receiveRoots.GetResponse()) + require.Equal(t, pbpeering.TypeURLRoots, receiveRoots.GetResponse().ResourceURL) + input2 := &pbpeering.ReplicationMessage{ Payload: &pbpeering.ReplicationMessage_Request_{ Request: &pbpeering.ReplicationMessage_Request{ @@ -280,19 +285,6 @@ func TestStreamResources_Server_Terminate(t *testing.T) { } srv.streams.timeNow = it.Now - client := NewMockClient(context.Background()) - - errCh := make(chan error, 1) - client.ErrCh = errCh - - go func() { - // Pass errors from server handler into ErrCh so that they can be seen by the client on Recv(). - // This matches gRPC's behavior when an error is returned by a server. - if err := srv.StreamResources(client.ReplicationStream); err != nil { - errCh <- err - } - }() - p := writeInitiatedPeering(t, store, 1, "my-peer") var ( peerID = p.ID // for Send @@ -300,19 +292,17 @@ func TestStreamResources_Server_Terminate(t *testing.T) { ) // Set the initial roots and CA configuration. - _ = writeInitialRootsAndCA(t, store) + _, _ = writeInitialRootsAndCA(t, store) - // Receive a subscription from a peer - sub := &pbpeering.ReplicationMessage{ - Payload: &pbpeering.ReplicationMessage_Request_{ - Request: &pbpeering.ReplicationMessage_Request{ - PeerID: peerID, - ResourceURL: pbpeering.TypeURLService, - }, - }, - } - err := client.Send(sub) + client := makeClient(t, srv, peerID, remotePeerID) + + // TODO(peering): test fails if we don't drain the stream with this call because the + // server gets blocked sending the termination message. Figure out a way to let + // messages queue and filter replication messages. 
+ receiveRoots, err := client.Recv() require.NoError(t, err) + require.NotNil(t, receiveRoots.GetResponse()) + require.Equal(t, pbpeering.TypeURLRoots, receiveRoots.GetResponse().ResourceURL) testutil.RunStep(t, "new stream gets tracked", func(t *testing.T) { retry.Run(t, func(r *retry.R) { @@ -322,20 +312,6 @@ func TestStreamResources_Server_Terminate(t *testing.T) { }) }) - // Receive subscription to my-peer-B's resources - receivedSub, err := client.Recv() - require.NoError(t, err) - - expect := &pbpeering.ReplicationMessage{ - Payload: &pbpeering.ReplicationMessage_Request_{ - Request: &pbpeering.ReplicationMessage_Request{ - ResourceURL: pbpeering.TypeURLService, - PeerID: remotePeerID, - }, - }, - } - prototest.AssertDeepEqual(t, expect, receivedSub) - testutil.RunStep(t, "terminate the stream", func(t *testing.T) { done := srv.ConnectedStreams()[peerID] close(done) @@ -348,7 +324,7 @@ func TestStreamResources_Server_Terminate(t *testing.T) { receivedTerm, err := client.Recv() require.NoError(t, err) - expect = &pbpeering.ReplicationMessage{ + expect := &pbpeering.ReplicationMessage{ Payload: &pbpeering.ReplicationMessage_Terminated_{ Terminated: &pbpeering.ReplicationMessage_Terminated{}, }, @@ -375,12 +351,8 @@ func TestStreamResources_Server_StreamTracker(t *testing.T) { } srv.streams.timeNow = it.Now - client := NewMockClient(context.Background()) - - errCh := make(chan error, 1) - go func() { - errCh <- srv.StreamResources(client.ReplicationStream) - }() + // Set the initial roots and CA configuration. + _, rootA := writeInitialRootsAndCA(t, store) p := writeInitiatedPeering(t, store, 1, "my-peer") var ( @@ -388,20 +360,7 @@ func TestStreamResources_Server_StreamTracker(t *testing.T) { remotePeerID = p.PeerID // for Recv ) - // Set the initial roots and CA configuration. 
- _ = writeInitialRootsAndCA(t, store) - - // Receive a subscription from a peer - sub := &pbpeering.ReplicationMessage{ - Payload: &pbpeering.ReplicationMessage_Request_{ - Request: &pbpeering.ReplicationMessage_Request{ - PeerID: peerID, - ResourceURL: pbpeering.TypeURLService, - }, - }, - } - err := client.Send(sub) - require.NoError(t, err) + client := makeClient(t, srv, peerID, remotePeerID) testutil.RunStep(t, "new stream gets tracked", func(t *testing.T) { retry.Run(t, func(r *retry.R) { @@ -411,22 +370,6 @@ func TestStreamResources_Server_StreamTracker(t *testing.T) { }) }) - testutil.RunStep(t, "client receives initial subscription", func(t *testing.T) { - ack, err := client.Recv() - require.NoError(t, err) - - expectAck := &pbpeering.ReplicationMessage{ - Payload: &pbpeering.ReplicationMessage_Request_{ - Request: &pbpeering.ReplicationMessage_Request{ - ResourceURL: pbpeering.TypeURLService, - PeerID: remotePeerID, - Nonce: "", - }, - }, - } - prototest.AssertDeepEqual(t, expectAck, ack) - }) - var sequence uint64 var lastSendSuccess time.Time @@ -516,6 +459,24 @@ func TestStreamResources_Server_StreamTracker(t *testing.T) { require.NoError(t, err) sequence++ + expectRoots := &pbpeering.ReplicationMessage{ + Payload: &pbpeering.ReplicationMessage_Response_{ + Response: &pbpeering.ReplicationMessage_Response{ + ResourceURL: pbpeering.TypeURLRoots, + ResourceID: "roots", + Resource: makeAnyPB(t, &pbpeering.PeeringTrustBundle{ + TrustDomain: connect.TestTrustDomain, + RootPEMs: []string{rootA.RootCert}, + }), + Operation: pbpeering.ReplicationMessage_Response_UPSERT, + }, + }, + } + + roots, err := client.Recv() + require.NoError(t, err) + prototest.AssertDeepEqual(t, expectRoots, roots) + ack, err := client.Recv() require.NoError(t, err) @@ -629,14 +590,6 @@ func TestStreamResources_Server_StreamTracker(t *testing.T) { require.Equal(r, expect, status) }) }) - - select { - case err := <-errCh: - // Client disconnect is not an error, but should make the handler return. - require.NoError(t, err) - case <-time.After(50 * time.Millisecond): - t.Fatalf("timed out waiting for handler to finish") - } } func TestStreamResources_Server_ServiceUpdates(t *testing.T) { @@ -652,13 +605,9 @@ func testStreamResources_Server_ServiceUpdates(t *testing.T, disableMeshGateways // Create a peering var lastIdx uint64 = 1 p := writeInitiatedPeering(t, store, lastIdx, "my-peering") - var ( - peerID = p.ID // for Send - remotePeerID = p.PeerID // for Recv - ) // Set the initial roots and CA configuration. - _ = writeInitialRootsAndCA(t, store) + _, _ = writeInitialRootsAndCA(t, store) srv := NewService( testutil.Logger(t), @@ -670,44 +619,7 @@ func testStreamResources_Server_ServiceUpdates(t *testing.T, disableMeshGateways store: store, pub: publisher, }) - - client := NewMockClient(context.Background()) - - errCh := make(chan error, 1) - client.ErrCh = errCh - - go func() { - // Pass errors from server handler into ErrCh so that they can be seen by the client on Recv(). - // This matches gRPC's behavior when an error is returned by a server. 
- if err := srv.StreamResources(client.ReplicationStream); err != nil { - errCh <- err - } - }() - - // Issue a services subscription to server - init := &pbpeering.ReplicationMessage{ - Payload: &pbpeering.ReplicationMessage_Request_{ - Request: &pbpeering.ReplicationMessage_Request{ - PeerID: peerID, - ResourceURL: pbpeering.TypeURLService, - }, - }, - } - require.NoError(t, client.Send(init)) - - // Receive a services subscription from server - receivedSub, err := client.Recv() - require.NoError(t, err) - - expect := &pbpeering.ReplicationMessage{ - Payload: &pbpeering.ReplicationMessage_Request_{ - Request: &pbpeering.ReplicationMessage_Request{ - ResourceURL: pbpeering.TypeURLService, - PeerID: remotePeerID, - }, - }, - } - prototest.AssertDeepEqual(t, expect, receivedSub) + client := makeClient(t, srv, p.ID, p.PeerID) // Register a service that is not yet exported mysql := &structs.CheckServiceNode{ @@ -751,6 +663,10 @@ func testStreamResources_Server_ServiceUpdates(t *testing.T, disableMeshGateways require.NoError(t, store.EnsureConfigEntry(lastIdx, entry)) expectReplEvents(t, client, + func(t *testing.T, msg *pbpeering.ReplicationMessage) { + require.Equal(t, pbpeering.TypeURLRoots, msg.GetResponse().ResourceURL) + // Roots tested in TestStreamResources_Server_CARootUpdates + }, func(t *testing.T, msg *pbpeering.ReplicationMessage) { require.Equal(t, pbpeering.TypeURLService, msg.GetResponse().ResourceURL) require.Equal(t, mongoSN, msg.GetResponse().ResourceID) @@ -820,7 +736,7 @@ func testStreamResources_Server_ServiceUpdates(t *testing.T, disableMeshGateways }, } lastIdx++ - err = store.EnsureConfigEntry(lastIdx, entry) + err := store.EnsureConfigEntry(lastIdx, entry) require.NoError(t, err) retry.Run(t, func(r *retry.R) { @@ -834,7 +750,7 @@ func testStreamResources_Server_ServiceUpdates(t *testing.T, disableMeshGateways testutil.RunStep(t, "deleting the config entry leads to a DELETE event for mongo", func(t *testing.T) { lastIdx++ - err = store.DeleteConfigEntry(lastIdx, structs.ExportedServices, "default", nil) + err := store.DeleteConfigEntry(lastIdx, structs.ExportedServices, "default", nil) require.NoError(t, err) retry.Run(t, func(r *retry.R) { @@ -847,6 +763,128 @@ func testStreamResources_Server_ServiceUpdates(t *testing.T, disableMeshGateways }) } +func TestStreamResources_Server_CARootUpdates(t *testing.T) { + publisher := stream.NewEventPublisher(10 * time.Second) + + store := newStateStore(t, publisher) + + // Create a peering + var lastIdx uint64 = 1 + p := writeInitiatedPeering(t, store, lastIdx, "my-peering") + + srv := NewService( + testutil.Logger(t), + Config{ + Datacenter: "dc1", + ConnectEnabled: true, + }, &testStreamBackend{ + store: store, + pub: publisher, + }) + + // Set the initial roots and CA configuration. 
+ clusterID, rootA := writeInitialRootsAndCA(t, store) + + client := makeClient(t, srv, p.ID, p.PeerID) + + testutil.RunStep(t, "initial CA Roots replication", func(t *testing.T) { + expectReplEvents(t, client, + func(t *testing.T, msg *pbpeering.ReplicationMessage) { + require.Equal(t, pbpeering.TypeURLRoots, msg.GetResponse().ResourceURL) + require.Equal(t, "roots", msg.GetResponse().ResourceID) + require.Equal(t, pbpeering.ReplicationMessage_Response_UPSERT, msg.GetResponse().Operation) + + var trustBundle pbpeering.PeeringTrustBundle + require.NoError(t, ptypes.UnmarshalAny(msg.GetResponse().Resource, &trustBundle)) + + require.ElementsMatch(t, []string{rootA.RootCert}, trustBundle.RootPEMs) + expect := connect.SpiffeIDSigningForCluster(clusterID).Host() + require.Equal(t, expect, trustBundle.TrustDomain) + }, + ) + }) + + testutil.RunStep(t, "CA root rotation sends upsert event", func(t *testing.T) { + // get max index for CAS operation + cidx, _, err := store.CARoots(nil) + require.NoError(t, err) + + rootB := connect.TestCA(t, nil) + rootC := connect.TestCA(t, nil) + rootC.Active = false // there can only be one active root + lastIdx++ + set, err := store.CARootSetCAS(lastIdx, cidx, []*structs.CARoot{rootB, rootC}) + require.True(t, set) + require.NoError(t, err) + + expectReplEvents(t, client, + func(t *testing.T, msg *pbpeering.ReplicationMessage) { + require.Equal(t, pbpeering.TypeURLRoots, msg.GetResponse().ResourceURL) + require.Equal(t, "roots", msg.GetResponse().ResourceID) + require.Equal(t, pbpeering.ReplicationMessage_Response_UPSERT, msg.GetResponse().Operation) + + var trustBundle pbpeering.PeeringTrustBundle + require.NoError(t, ptypes.UnmarshalAny(msg.GetResponse().Resource, &trustBundle)) + + require.ElementsMatch(t, []string{rootB.RootCert, rootC.RootCert}, trustBundle.RootPEMs) + expect := connect.SpiffeIDSigningForCluster(clusterID).Host() + require.Equal(t, expect, trustBundle.TrustDomain) + }, + ) + }) +} + +// makeClient sets up a *MockClient with the initial subscription +// message handshake. +func makeClient( + t *testing.T, + srv pbpeering.PeeringServiceServer, + peerID string, + remotePeerID string, +) *MockClient { + t.Helper() + + client := NewMockClient(context.Background()) + + errCh := make(chan error, 1) + client.ErrCh = errCh + + go func() { + // Pass errors from server handler into ErrCh so that they can be seen by the client on Recv(). + // This matches gRPC's behavior when an error is returned by a server. 
+ if err := srv.StreamResources(client.ReplicationStream); err != nil {
+ errCh <- err
+ }
+ }()
+
+ // Issue a services subscription to server
+ init := &pbpeering.ReplicationMessage{
+ Payload: &pbpeering.ReplicationMessage_Request_{
+ Request: &pbpeering.ReplicationMessage_Request{
+ PeerID: peerID,
+ ResourceURL: pbpeering.TypeURLService,
+ },
+ },
+ }
+ require.NoError(t, client.Send(init))
+
+ // Receive a services subscription from server
+ receivedSub, err := client.Recv()
+ require.NoError(t, err)
+
+ expect := &pbpeering.ReplicationMessage{
+ Payload: &pbpeering.ReplicationMessage_Request_{
+ Request: &pbpeering.ReplicationMessage_Request{
+ ResourceURL: pbpeering.TypeURLService,
+ PeerID: remotePeerID,
+ },
+ },
+ }
+ prototest.AssertDeepEqual(t, expect, receivedSub)
+
+ return client
+}
+
 type testStreamBackend struct {
 pub state.EventPublisher
 store *state.Store
@@ -1058,7 +1096,7 @@ func writeInitiatedPeering(t *testing.T, store *state.Store, idx uint64, peerNam
 return p
 }

-func writeInitialRootsAndCA(t *testing.T, store *state.Store) string {
+func writeInitialRootsAndCA(t *testing.T, store *state.Store) (string, *structs.CARoot) {
 const clusterID = connect.TestClusterID

 rootA := connect.TestCA(t, nil)
@@ -1068,7 +1106,7 @@ func writeInitialRootsAndCA(t *testing.T, store *state.Store) string {
 err = store.CASetConfig(0, &structs.CAConfiguration{ClusterID: clusterID})
 require.NoError(t, err)

- return clusterID
+ return clusterID, rootA
 }

 func makeAnyPB(t *testing.T, pb proto.Message) *any.Any {
diff --git a/agent/rpc/peering/subscription_manager.go b/agent/rpc/peering/subscription_manager.go
index d4ac7d7699..5a48900b86 100644
--- a/agent/rpc/peering/subscription_manager.go
+++ b/agent/rpc/peering/subscription_manager.go
@@ -2,6 +2,7 @@ package peering

 import (
 "context"
+ "errors"
 "fmt"
 "strings"

@@ -11,10 +12,13 @@ import (
 "github.com/hashicorp/consul/acl"
 "github.com/hashicorp/consul/agent/cache"
 "github.com/hashicorp/consul/agent/connect"
+ "github.com/hashicorp/consul/agent/consul/state"
+ "github.com/hashicorp/consul/agent/consul/stream"
 "github.com/hashicorp/consul/agent/structs"
 "github.com/hashicorp/consul/agent/submatview"
 "github.com/hashicorp/consul/api"
 "github.com/hashicorp/consul/proto/pbcommon"
+ "github.com/hashicorp/consul/proto/pbpeering"
 "github.com/hashicorp/consul/proto/pbservice"
 )

@@ -75,6 +79,11 @@ func (m *subscriptionManager) subscribe(ctx context.Context, peerID, peerName, p
 go m.notifyMeshGatewaysForPartition(ctx, state, state.partition)
 }

+ // If connect is enabled, watch for updates to CA roots.
+ if m.config.ConnectEnabled {
+ go m.notifyRootCAUpdates(ctx, state.updateCh)
+ }
+
 // This goroutine is the only one allowed to manipulate protected
 // subscriptionManager fields.
 go m.handleEvents(ctx, state, updateCh)
@@ -289,6 +298,18 @@ func (m *subscriptionManager) handleEvent(ctx context.Context, state *subscripti
 // TODO(peering): should we ship this down verbatim to the consumer?
state.sendPendingEvents(ctx, m.logger, pending) + case u.CorrelationID == subCARoot: + roots, ok := u.Result.(*pbpeering.PeeringTrustBundle) + if !ok { + return fmt.Errorf("invalid type for response: %T", u.Result) + } + pending := &pendingPayload{} + if err := pending.Add(caRootsPayloadID, u.CorrelationID, roots); err != nil { + return err + } + + state.sendPendingEvents(ctx, m.logger, pending) + default: return fmt.Errorf("unknown correlation ID: %s", u.CorrelationID) } @@ -322,6 +343,106 @@ func filterConnectReferences(orig *pbservice.IndexedCheckServiceNodes) { orig.Nodes = newNodes } +func (m *subscriptionManager) notifyRootCAUpdates(ctx context.Context, updateCh chan<- cache.UpdateEvent) { + var idx uint64 + // TODO(peering): retry logic; fail past a threshold + for { + var err error + // Typically, this function will block inside `m.subscribeCARoots` and only return on error. + // Errors are logged and the watch is retried. + idx, err = m.subscribeCARoots(ctx, idx, updateCh) + if errors.Is(err, stream.ErrSubForceClosed) { + m.logger.Trace("subscription force-closed due to an ACL change or snapshot restore, will attempt resume") + } else if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + m.logger.Warn("failed to subscribe to CA roots, will attempt resume", "error", err.Error()) + } else { + m.logger.Trace(err.Error()) + } + + select { + case <-ctx.Done(): + return + default: + } + } +} + +// subscribeCARoots subscribes to state.EventTopicCARoots for changes to CA roots. +// Upon receiving an event it will send the payload in updateCh. +func (m *subscriptionManager) subscribeCARoots(ctx context.Context, idx uint64, updateCh chan<- cache.UpdateEvent) (uint64, error) { + // following code adapted from connectca/watch_roots.go + sub, err := m.backend.Subscribe(&stream.SubscribeRequest{ + Topic: state.EventTopicCARoots, + Subject: stream.SubjectNone, + Token: "", // using anonymous token for now + Index: idx, + }) + if err != nil { + return 0, fmt.Errorf("failed to subscribe to CA Roots events: %w", err) + } + defer sub.Unsubscribe() + + for { + event, err := sub.Next(ctx) + switch { + case errors.Is(err, stream.ErrSubForceClosed): + // If the subscription was closed because the state store was abandoned (e.g. + // following a snapshot restore) reset idx to ensure we don't skip over the + // new store's events. + select { + case <-m.backend.Store().AbandonCh(): + idx = 0 + default: + } + return idx, err + case errors.Is(err, context.Canceled): + return 0, err + case errors.Is(err, context.DeadlineExceeded): + return 0, err + case err != nil: + return idx, fmt.Errorf("failed to read next event: %w", err) + } + + // Note: this check isn't strictly necessary because the event publishing + // machinery will ensure the index increases monotonically, but it can be + // tricky to faithfully reproduce this in tests (e.g. the EventPublisher + // garbage collects topic buffers and snapshots aggressively when streams + // disconnect) so this avoids a bunch of confusing setup code. + if event.Index <= idx { + continue + } + + idx = event.Index + + // We do not send framing events (e.g. EndOfSnapshot, NewSnapshotToFollow) + // because we send a full list of roots on every event, rather than expecting + // clients to maintain a state-machine in the way they do for service health. 
+ if event.IsFramingEvent() { + continue + } + + payload, ok := event.Payload.(state.EventPayloadCARoots) + if !ok { + return 0, fmt.Errorf("unexpected event payload type: %T", payload) + } + + var rootPems []string + for _, root := range payload.CARoots { + rootPems = append(rootPems, root.RootCert) + } + + updateCh <- cache.UpdateEvent{ + CorrelationID: subCARoot, + Result: &pbpeering.PeeringTrustBundle{ + TrustDomain: m.trustDomain, + RootPEMs: rootPems, + }, + } + } +} + +const subCARoot = "roots" + func (m *subscriptionManager) syncNormalServices( ctx context.Context, state *subscriptionState, diff --git a/agent/rpc/peering/subscription_manager_test.go b/agent/rpc/peering/subscription_manager_test.go index 28583cfaff..299fd9ad81 100644 --- a/agent/rpc/peering/subscription_manager_test.go +++ b/agent/rpc/peering/subscription_manager_test.go @@ -572,6 +572,65 @@ func testSubscriptionManager_InitialSnapshot(t *testing.T, disableMeshGateways b }) } +func TestSubscriptionManager_CARoots(t *testing.T) { + backend := newTestSubscriptionBackend(t) + + // Setup CA-related configs in the store + clusterID, rootA := writeInitialRootsAndCA(t, backend.store) + trustDomain := connect.SpiffeIDSigningForCluster(clusterID).Host() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Create a peering + _, id := backend.ensurePeering(t, "my-peering") + partition := acl.DefaultEnterpriseMeta().PartitionOrEmpty() + + mgr := newSubscriptionManager(ctx, testutil.Logger(t), Config{ + Datacenter: "dc1", + ConnectEnabled: true, + }, connect.TestTrustDomain, backend) + subCh := mgr.subscribe(ctx, id, "my-peering", partition) + + testutil.RunStep(t, "initial events contain trust bundle", func(t *testing.T) { + // events are ordered so we can expect a deterministic list + expectEvents(t, subCh, + func(t *testing.T, got cache.UpdateEvent) { + // mesh-gateway assertions are done in other tests + require.Equal(t, subMeshGateway+partition, got.CorrelationID) + }, + func(t *testing.T, got cache.UpdateEvent) { + require.Equal(t, subCARoot, got.CorrelationID) + roots, ok := got.Result.(*pbpeering.PeeringTrustBundle) + require.True(t, ok) + + require.ElementsMatch(t, []string{rootA.RootCert}, roots.RootPEMs) + + require.Equal(t, trustDomain, roots.TrustDomain) + }, + ) + }) + + testutil.RunStep(t, "updating CA roots triggers event", func(t *testing.T) { + rootB := connect.TestCA(t, nil) + rootC := connect.TestCA(t, nil) + rootC.Active = false // there can only be one active root + backend.ensureCARoots(t, rootB, rootC) + + expectEvents(t, subCh, + func(t *testing.T, got cache.UpdateEvent) { + require.Equal(t, subCARoot, got.CorrelationID) + roots, ok := got.Result.(*pbpeering.PeeringTrustBundle) + require.True(t, ok) + + require.ElementsMatch(t, []string{rootB.RootCert, rootC.RootCert}, roots.RootPEMs) + + require.Equal(t, trustDomain, roots.TrustDomain) + }, + ) + }) +} + type testSubscriptionBackend struct { state.EventPublisher store *state.Store @@ -643,6 +702,24 @@ func (b *testSubscriptionBackend) deleteService(t *testing.T, nodeName, serviceI return b.lastIdx } +func (b *testSubscriptionBackend) ensureCAConfig(t *testing.T, config *structs.CAConfiguration) uint64 { + b.lastIdx++ + require.NoError(t, b.store.CASetConfig(b.lastIdx, config)) + return b.lastIdx +} + +func (b *testSubscriptionBackend) ensureCARoots(t *testing.T, roots ...*structs.CARoot) uint64 { + // Get the max index for Check-and-Set operation + cidx, _, err := b.store.CARoots(nil) + require.NoError(t, err) + + 
b.lastIdx++ + set, err := b.store.CARootSetCAS(b.lastIdx, cidx, roots) + require.True(t, set) + require.NoError(t, err) + return b.lastIdx +} + func setupTestPeering(t *testing.T, store *state.Store, name string, index uint64) string { err := store.PeeringWrite(index, &pbpeering.Peering{ Name: name, @@ -666,6 +743,7 @@ func newStateStore(t *testing.T, publisher *stream.EventPublisher) *state.Store store := state.NewStateStoreWithEventPublisher(gc, publisher) require.NoError(t, publisher.RegisterHandler(state.EventTopicServiceHealth, store.ServiceHealthSnapshot)) require.NoError(t, publisher.RegisterHandler(state.EventTopicServiceHealthConnect, store.ServiceHealthSnapshot)) + require.NoError(t, publisher.RegisterHandler(state.EventTopicCARoots, store.CARootsSnapshot)) go publisher.Run(ctx) return store diff --git a/agent/rpc/peering/subscription_state.go b/agent/rpc/peering/subscription_state.go index 093f3aa6bf..29bbff967c 100644 --- a/agent/rpc/peering/subscription_state.go +++ b/agent/rpc/peering/subscription_state.go @@ -92,6 +92,9 @@ func (s *subscriptionState) cleanupEventVersions(logger hclog.Logger) { case id == meshGatewayPayloadID: keep = true + case id == caRootsPayloadID: + keep = true + case strings.HasPrefix(id, servicePayloadIDPrefix): name := strings.TrimPrefix(id, servicePayloadIDPrefix) sn := structs.ServiceNameFromString(name) @@ -136,6 +139,7 @@ type pendingEvent struct { } const ( + caRootsPayloadID = "roots" meshGatewayPayloadID = "mesh-gateway" servicePayloadIDPrefix = "service:" proxyServicePayloadIDPrefix = "proxy-service:" // TODO(peering): remove diff --git a/proto/pbpeering/types.go b/proto/pbpeering/types.go index 3e6b092e2e..23847e46e7 100644 --- a/proto/pbpeering/types.go +++ b/proto/pbpeering/types.go @@ -2,4 +2,9 @@ package pbpeering const ( TypeURLService = "type.googleapis.com/consul.api.Service" + TypeURLRoots = "type.googleapis.com/consul.api.CARoots" ) + +func KnownTypeURL(s string) bool { + return s == TypeURLService || s == TypeURLRoots +}
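Reviewer note: the `marshalToProtoAny[T proto.Message]` helper in replication.go deduplicates the assert-then-marshal step that `makeServiceResponse` and `makeCARootsResponse` share. Below is a minimal, self-contained sketch of the same pattern; the `wrapperspb` payloads and the `main` harness are stand-ins for illustration rather than Consul code, and it uses `anypb.New` where the patch keeps the older `ptypes.MarshalAny`:

```go
package main

import (
	"fmt"

	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/anypb"
	"google.golang.org/protobuf/types/known/wrapperspb"
)

// marshalToProtoAny asserts an untyped value to a concrete proto.Message
// and wraps it in an Any, returning an error instead of panicking on bad
// input so a replication stream can skip a malformed update event.
func marshalToProtoAny[T proto.Message](in any) (*anypb.Any, T, error) {
	typ, ok := in.(T)
	if !ok {
		var outType T
		return nil, typ, fmt.Errorf("input type is not %T: %T", outType, in)
	}
	a, err := anypb.New(typ)
	if err != nil {
		return nil, typ, err
	}
	return a, typ, nil
}

func main() {
	// Well-typed input: we get both the Any and the typed message back.
	a, msg, err := marshalToProtoAny[*wrapperspb.StringValue](wrapperspb.String("ok"))
	fmt.Println(a.GetTypeUrl(), msg.GetValue(), err)

	// Mismatched input: an error, not a panic.
	_, _, err = marshalToProtoAny[*wrapperspb.StringValue](42)
	fmt.Println(err)
}
```

Returning the asserted value alongside the `Any` matters because callers like `makeServiceResponse` still need the typed message (e.g. to decide between the DELETE and UPSERT branches based on the nodes in the event) without repeating the assertion.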
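Reviewer note: `notifyRootCAUpdates` wraps `subscribeCARoots` in the resume-from-index retry pattern adapted from connectca/watch_roots.go. The inner call blocks consuming events until the subscription dies, reports the last index it saw, and the outer loop resumes from there (resetting to zero only when the state store was abandoned). A runnable sketch of just that control flow, with stand-in names (`subscribeOnce`, `errSubForceClosed`) since Consul's stream machinery isn't included:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// errSubForceClosed is a stand-in for stream.ErrSubForceClosed.
var errSubForceClosed = errors.New("subscription force-closed")

// subscribeOnce simulates the blocking inner watch (subscribeCARoots): it
// consumes events starting at idx and only returns when the subscription
// ends, reporting the last index it reached so the caller can resume.
func subscribeOnce(ctx context.Context, idx uint64) (uint64, error) {
	select {
	case <-ctx.Done():
		return idx, ctx.Err()
	case <-time.After(10 * time.Millisecond):
		// Pretend we processed one event before the server force-closed us.
		return idx + 1, errSubForceClosed
	}
}

// notify mirrors the outer retry loop of notifyRootCAUpdates: classify the
// error, then loop again unless the context has been canceled.
func notify(ctx context.Context) {
	var idx uint64
	for {
		var err error
		idx, err = subscribeOnce(ctx, idx)
		switch {
		case errors.Is(err, errSubForceClosed):
			fmt.Println("force-closed (ACL change or snapshot restore), resuming at index", idx)
		case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
			// Shutting down; the select below exits the loop.
		default:
			fmt.Println("watch failed, will resume:", err)
		}

		select {
		case <-ctx.Done():
			return
		default:
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 35*time.Millisecond)
	defer cancel()
	notify(ctx) // a few force-closed resumes, then exits at the deadline
}
```

The monotonic-index guard in `subscribeCARoots` (`if event.Index <= idx { continue }`) is what makes this resumption safe: replayed or duplicate events after a resume are dropped rather than re-sent to the peer.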