diff --git a/agent/consul/fsm/fsm.go b/agent/consul/fsm/fsm.go index 9dcc5f64f3..a9de91935d 100644 --- a/agent/consul/fsm/fsm.go +++ b/agent/consul/fsm/fsm.go @@ -12,6 +12,7 @@ import ( "github.com/hashicorp/raft" "github.com/hashicorp/consul/agent/consul/state" + "github.com/hashicorp/consul/agent/consul/stream" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/logging" ) @@ -56,6 +57,8 @@ type FSM struct { // Raft side, so doesn't need to lock this. stateLock sync.RWMutex state *state.Store + + publisher *stream.EventPublisher } // New is used to construct a new FSM with a blank state. @@ -77,6 +80,8 @@ type Deps struct { // NewStateStore will be called once when the FSM is created and again any // time Restore() is called. NewStateStore func() *state.Store + + Publisher *stream.EventPublisher } // NewFromDeps creates a new FSM from its dependencies. @@ -101,6 +106,10 @@ func NewFromDeps(deps Deps) *FSM { } fsm.chunker = raftchunking.NewChunkingFSM(fsm, nil) + + // register the streaming snapshot handlers if an event publisher was provided. + fsm.registerStreamSnapshotHandlers() + return fsm } @@ -204,12 +213,28 @@ func (c *FSM) Restore(old io.ReadCloser) error { c.stateLock.Lock() stateOld := c.state c.state = stateNew + + // Tell the EventPublisher to cycle anything watching these topics. Replacement + // of the state store means that indexes could have gone backwards and data changed. + // + // This needs to happen while holding the state lock to ensure its not racey. If we + // did this outside of the locked section closer to where we abandon the old store + // then there would be a possibility for new streams to be opened that would get + // a snapshot from the cache sourced from old data but would be receiving events + // for new data. To prevent that inconsistency we refresh the topics while holding + // the lock which ensures that any subscriptions to topics for FSM generated events + if c.deps.Publisher != nil { + c.deps.Publisher.RefreshTopic(state.EventTopicServiceHealth) + c.deps.Publisher.RefreshTopic(state.EventTopicServiceHealthConnect) + c.deps.Publisher.RefreshTopic(state.EventTopicCARoots) + } c.stateLock.Unlock() // Signal that the old state store has been abandoned. This is required // because we don't operate on it any more, we just throw it away, so // blocking queries won't see any changes and need to be woken up. stateOld.Abandon() + return nil } @@ -244,3 +269,30 @@ func ReadSnapshot(r io.Reader, handler func(header *SnapshotHeader, msg structs. } } } + +func (c *FSM) registerStreamSnapshotHandlers() { + if c.deps.Publisher == nil { + return + } + + err := c.deps.Publisher.RegisterHandler(state.EventTopicServiceHealth, func(req stream.SubscribeRequest, buf stream.SnapshotAppender) (uint64, error) { + return c.State().ServiceHealthSnapshot(req, buf) + }) + if err != nil { + panic(fmt.Errorf("fatal error encountered registering streaming snapshot handlers: %w", err)) + } + + err = c.deps.Publisher.RegisterHandler(state.EventTopicServiceHealthConnect, func(req stream.SubscribeRequest, buf stream.SnapshotAppender) (uint64, error) { + return c.State().ServiceHealthSnapshot(req, buf) + }) + if err != nil { + panic(fmt.Errorf("fatal error encountered registering streaming snapshot handlers: %w", err)) + } + + err = c.deps.Publisher.RegisterHandler(state.EventTopicCARoots, func(req stream.SubscribeRequest, buf stream.SnapshotAppender) (uint64, error) { + return c.State().CARootsSnapshot(req, buf) + }) + if err != nil { + panic(fmt.Errorf("fatal error encountered registering streaming snapshot handlers: %w", err)) + } +} diff --git a/agent/consul/server.go b/agent/consul/server.go index a3effba97a..401954d853 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -39,6 +39,7 @@ import ( "github.com/hashicorp/consul/agent/consul/authmethod/ssoauth" "github.com/hashicorp/consul/agent/consul/fsm" "github.com/hashicorp/consul/agent/consul/state" + "github.com/hashicorp/consul/agent/consul/stream" "github.com/hashicorp/consul/agent/consul/usagemetrics" "github.com/hashicorp/consul/agent/consul/wanfed" agentgrpc "github.com/hashicorp/consul/agent/grpc/private" @@ -343,6 +344,12 @@ type Server struct { // Manager to handle starting/stopping go routines when establishing/revoking raft leadership leaderRoutineManager *routine.Manager + // publisher is the EventPublisher to be shared amongst various server components. Events from + // modifications to the FSM, autopilot and others will flow through here. If in the future we + // need Events generated outside of the Server and all its components, then we could move + // this into the Deps struct and created it much earlier on. + publisher *stream.EventPublisher + // embedded struct to hold all the enterprise specific data EnterpriseServer } @@ -397,6 +404,16 @@ func NewServer(config *Config, flat Deps, publicGRPCServer *grpc.Server) (*Serve insecureRPCServer = rpc.NewServerWithOpts(rpc.WithServerServiceCallInterceptor(flat.GetNetRPCInterceptorFunc(recorder))) } + eventPublisher := stream.NewEventPublisher(10 * time.Second) + + fsmDeps := fsm.Deps{ + Logger: flat.Logger, + NewStateStore: func() *state.Store { + return state.NewStateStoreWithEventPublisher(gc, eventPublisher) + }, + Publisher: eventPublisher, + } + // Create server. s := &Server{ config: config, @@ -422,9 +439,12 @@ func NewServer(config *Config, flat Deps, publicGRPCServer *grpc.Server) (*Serve shutdownCh: shutdownCh, leaderRoutineManager: routine.NewManager(logger.Named(logging.Leader)), aclAuthMethodValidators: authmethod.NewCache(), - fsm: newFSMFromConfig(flat.Logger, gc, config), + fsm: fsm.NewFromDeps(fsmDeps), + publisher: eventPublisher, } + go s.publisher.Run(&lib.StopChannelContext{StopCh: s.shutdownCh}) + if s.config.ConnectMeshGatewayWANFederationEnabled { s.gatewayLocator = NewGatewayLocator( s.logger, @@ -652,6 +672,7 @@ func NewServer(config *Config, flat Deps, publicGRPCServer *grpc.Server) (*Serve // Initialize public gRPC server - register services on public gRPC server. connectca.NewServer(connectca.Config{ + Publisher: s.publisher, GetStore: func() connectca.StateStore { return s.FSM().State() }, Logger: logger.Named("grpc-api.connect-ca"), ACLResolver: plainACLResolver{s.ACLResolver}, @@ -684,21 +705,6 @@ func NewServer(config *Config, flat Deps, publicGRPCServer *grpc.Server) (*Serve return s, nil } -func newFSMFromConfig(logger hclog.Logger, gc *state.TombstoneGC, config *Config) *fsm.FSM { - deps := fsm.Deps{Logger: logger} - if config.RPCConfig.EnableStreaming { - deps.NewStateStore = func() *state.Store { - return state.NewStateStoreWithEventPublisher(gc) - } - return fsm.NewFromDeps(deps) - } - - deps.NewStateStore = func() *state.Store { - return state.NewStateStore(gc) - } - return fsm.NewFromDeps(deps) -} - func newGRPCHandlerFromConfig(deps Deps, config *Config, s *Server) connHandler { register := func(srv *grpc.Server) { if config.RPCConfig.EnableStreaming { diff --git a/agent/consul/state/catalog_events.go b/agent/consul/state/catalog_events.go index 91e1bf361c..13c5c4ba0c 100644 --- a/agent/consul/state/catalog_events.go +++ b/agent/consul/state/catalog_events.go @@ -78,50 +78,48 @@ func (e EventPayloadCheckServiceNode) Subject() stream.Subject { // serviceHealthSnapshot returns a stream.SnapshotFunc that provides a snapshot // of stream.Events that describe the current state of a service health query. -func serviceHealthSnapshot(db ReadDB, topic stream.Topic) stream.SnapshotFunc { - return func(req stream.SubscribeRequest, buf stream.SnapshotAppender) (index uint64, err error) { - tx := db.ReadTxn() - defer tx.Abort() +func (s *Store) ServiceHealthSnapshot(req stream.SubscribeRequest, buf stream.SnapshotAppender) (index uint64, err error) { + tx := s.db.ReadTxn() + defer tx.Abort() - connect := topic == topicServiceHealthConnect + connect := req.Topic == EventTopicServiceHealthConnect - subject, ok := req.Subject.(EventSubjectService) - if !ok { - return 0, fmt.Errorf("expected SubscribeRequest.Subject to be a: state.EventSubjectService, was a: %T", req.Subject) - } - - idx, nodes, err := checkServiceNodesTxn(tx, nil, subject.Key, connect, &subject.EnterpriseMeta) - if err != nil { - return 0, err - } - - for i := range nodes { - n := nodes[i] - event := stream.Event{ - Index: idx, - Topic: topic, - Payload: EventPayloadCheckServiceNode{ - Op: pbsubscribe.CatalogOp_Register, - Value: &n, - }, - } - - if !connect { - // append each event as a separate item so that they can be serialized - // separately, to prevent the encoding of one massive message. - buf.Append([]stream.Event{event}) - continue - } - - events, err := connectEventsByServiceKind(tx, event) - if err != nil { - return idx, err - } - buf.Append(events) - } - - return idx, err + subject, ok := req.Subject.(EventSubjectService) + if !ok { + return 0, fmt.Errorf("expected SubscribeRequest.Subject to be a: state.EventSubjectService, was a: %T", req.Subject) } + + idx, nodes, err := checkServiceNodesTxn(tx, nil, subject.Key, connect, &subject.EnterpriseMeta) + if err != nil { + return 0, err + } + + for i := range nodes { + n := nodes[i] + event := stream.Event{ + Index: idx, + Topic: req.Topic, + Payload: EventPayloadCheckServiceNode{ + Op: pbsubscribe.CatalogOp_Register, + Value: &n, + }, + } + + if !connect { + // append each event as a separate item so that they can be serialized + // separately, to prevent the encoding of one massive message. + buf.Append([]stream.Event{event}) + continue + } + + events, err := connectEventsByServiceKind(tx, event) + if err != nil { + return idx, err + } + buf.Append(events) + } + + return idx, err } // TODO: this could use NodeServiceQuery @@ -355,7 +353,7 @@ func ServiceHealthEventsFromChanges(tx ReadTxn, changes Changes) ([]stream.Event for _, sn := range nodes { e := newServiceHealthEventDeregister(changes.Index, sn) - e.Topic = topicServiceHealthConnect + e.Topic = EventTopicServiceHealthConnect payload := e.Payload.(EventPayloadCheckServiceNode) payload.overrideKey = serviceName.Name if gatewayName.EnterpriseMeta.NamespaceOrDefault() != serviceName.EnterpriseMeta.NamespaceOrDefault() { @@ -388,7 +386,7 @@ func ServiceHealthEventsFromChanges(tx ReadTxn, changes Changes) ([]stream.Event return nil, err } - e.Topic = topicServiceHealthConnect + e.Topic = EventTopicServiceHealthConnect payload := e.Payload.(EventPayloadCheckServiceNode) payload.overrideKey = serviceName.Name if gatewayName.EnterpriseMeta.NamespaceOrDefault() != serviceName.EnterpriseMeta.NamespaceOrDefault() { @@ -426,7 +424,7 @@ func isConnectProxyDestinationServiceChange(idx uint64, before, after *structs.S } e := newServiceHealthEventDeregister(idx, before) - e.Topic = topicServiceHealthConnect + e.Topic = EventTopicServiceHealthConnect payload := e.Payload.(EventPayloadCheckServiceNode) payload.overrideKey = payload.Value.Service.Proxy.DestinationServiceName e.Payload = payload @@ -467,7 +465,7 @@ func serviceHealthToConnectEvents( ) ([]stream.Event, error) { var result []stream.Event for _, event := range events { - if event.Topic != topicServiceHealth { // event.Topic == topicServiceHealthConnect + if event.Topic != EventTopicServiceHealth { // event.Topic == topicServiceHealthConnect // Skip non-health or any events already emitted to Connect topic continue } @@ -490,7 +488,7 @@ func connectEventsByServiceKind(tx ReadTxn, origEvent stream.Event) ([]stream.Ev } event := origEvent // shallow copy the event - event.Topic = topicServiceHealthConnect + event.Topic = EventTopicServiceHealthConnect if node.Service.Connect.Native { return []stream.Event{event}, nil @@ -527,7 +525,7 @@ func connectEventsByServiceKind(tx ReadTxn, origEvent stream.Event) ([]stream.Ev } func copyEventForService(event stream.Event, service structs.ServiceName) stream.Event { - event.Topic = topicServiceHealthConnect + event.Topic = EventTopicServiceHealthConnect payload := event.Payload.(EventPayloadCheckServiceNode) payload.overrideKey = service.Name if payload.Value.Service.EnterpriseMeta.NamespaceOrDefault() != service.EnterpriseMeta.NamespaceOrDefault() { @@ -666,7 +664,7 @@ func newServiceHealthEventRegister( Checks: checks, } return stream.Event{ - Topic: topicServiceHealth, + Topic: EventTopicServiceHealth, Index: idx, Payload: EventPayloadCheckServiceNode{ Op: pbsubscribe.CatalogOp_Register, @@ -697,7 +695,7 @@ func newServiceHealthEventDeregister(idx uint64, sn *structs.ServiceNode) stream } return stream.Event{ - Topic: topicServiceHealth, + Topic: EventTopicServiceHealth, Index: idx, Payload: EventPayloadCheckServiceNode{ Op: pbsubscribe.CatalogOp_Deregister, diff --git a/agent/consul/state/catalog_events_test.go b/agent/consul/state/catalog_events_test.go index b85ea5f76d..1f6f6d885a 100644 --- a/agent/consul/state/catalog_events_test.go +++ b/agent/consul/state/catalog_events_test.go @@ -70,11 +70,10 @@ func TestServiceHealthSnapshot(t *testing.T) { err = store.EnsureRegistration(counter.Next(), testServiceRegistration(t, "web", regNode2)) require.NoError(t, err) - fn := serviceHealthSnapshot((*readDB)(store.db.db), topicServiceHealth) buf := &snapshotAppender{} - req := stream.SubscribeRequest{Subject: EventSubjectService{Key: "web"}} + req := stream.SubscribeRequest{Topic: EventTopicServiceHealth, Subject: EventSubjectService{Key: "web"}} - idx, err := fn(req, buf) + idx, err := store.ServiceHealthSnapshot(req, buf) require.NoError(t, err) require.Equal(t, counter.Last(), idx) @@ -147,11 +146,10 @@ func TestServiceHealthSnapshot_ConnectTopic(t *testing.T) { err = store.EnsureRegistration(counter.Next(), testServiceRegistration(t, "tgate1", regTerminatingGateway)) require.NoError(t, err) - fn := serviceHealthSnapshot((*readDB)(store.db.db), topicServiceHealthConnect) buf := &snapshotAppender{} - req := stream.SubscribeRequest{Subject: EventSubjectService{Key: "web"}, Topic: topicServiceHealthConnect} + req := stream.SubscribeRequest{Subject: EventSubjectService{Key: "web"}, Topic: EventTopicServiceHealthConnect} - idx, err := fn(req, buf) + idx, err := store.ServiceHealthSnapshot(req, buf) require.NoError(t, err) require.Equal(t, counter.Last(), idx) @@ -1743,7 +1741,7 @@ func evServiceTermingGateway(name string) func(e *stream.Event) error { } } - if e.Topic == topicServiceHealthConnect { + if e.Topic == EventTopicServiceHealthConnect { payload := e.Payload.(EventPayloadCheckServiceNode) payload.overrideKey = name e.Payload = payload @@ -2096,7 +2094,7 @@ func evConnectNative(e *stream.Event) error { // depending on which topic they are published to and they determine this from // the event. func evConnectTopic(e *stream.Event) error { - e.Topic = topicServiceHealthConnect + e.Topic = EventTopicServiceHealthConnect return nil } @@ -2135,7 +2133,7 @@ func evSidecar(e *stream.Event) error { csn.Checks[1].ServiceName = svc + "_sidecar_proxy" } - if e.Topic == topicServiceHealthConnect { + if e.Topic == EventTopicServiceHealthConnect { payload := e.Payload.(EventPayloadCheckServiceNode) payload.overrideKey = svc e.Payload = payload @@ -2238,7 +2236,7 @@ func evRenameService(e *stream.Event) error { taggedAddr.Address = "240.0.0.2" csn.Service.TaggedAddresses[structs.TaggedAddressVirtualIP] = taggedAddr - if e.Topic == topicServiceHealthConnect { + if e.Topic == EventTopicServiceHealthConnect { payload := e.Payload.(EventPayloadCheckServiceNode) payload.overrideKey = csn.Service.Proxy.DestinationServiceName e.Payload = payload @@ -2350,7 +2348,7 @@ func newTestEventServiceHealthRegister(index uint64, nodeNum int, svc string) st addr := fmt.Sprintf("10.10.%d.%d", nodeNum/256, nodeNum%256) return stream.Event{ - Topic: topicServiceHealth, + Topic: EventTopicServiceHealth, Index: index, Payload: EventPayloadCheckServiceNode{ Op: pbsubscribe.CatalogOp_Register, @@ -2421,7 +2419,7 @@ func newTestEventServiceHealthRegister(index uint64, nodeNum int, svc string) st // adding too many options to callers. func newTestEventServiceHealthDeregister(index uint64, nodeNum int, svc string) stream.Event { return stream.Event{ - Topic: topicServiceHealth, + Topic: EventTopicServiceHealth, Index: index, Payload: EventPayloadCheckServiceNode{ Op: pbsubscribe.CatalogOp_Deregister, diff --git a/agent/consul/state/connect_ca_events.go b/agent/consul/state/connect_ca_events.go index c6bd135be0..6a0bdb9744 100644 --- a/agent/consul/state/connect_ca_events.go +++ b/agent/consul/state/connect_ca_events.go @@ -65,23 +65,21 @@ func caRootsChangeEvents(tx ReadTxn, changes Changes) ([]stream.Event, error) { // caRootsSnapshot returns a stream.SnapshotFunc that provides a snapshot of // the current active list of CA Roots. -func caRootsSnapshot(db ReadDB) stream.SnapshotFunc { - return func(_ stream.SubscribeRequest, buf stream.SnapshotAppender) (uint64, error) { - tx := db.ReadTxn() - defer tx.Abort() +func (s *Store) CARootsSnapshot(_ stream.SubscribeRequest, buf stream.SnapshotAppender) (uint64, error) { + tx := s.db.ReadTxn() + defer tx.Abort() - idx, roots, err := caRootsTxn(tx, nil) - if err != nil { - return 0, err - } - - buf.Append([]stream.Event{ - { - Topic: EventTopicCARoots, - Index: idx, - Payload: EventPayloadCARoots{CARoots: roots}, - }, - }) - return idx, nil + idx, roots, err := caRootsTxn(tx, nil) + if err != nil { + return 0, err } + + buf.Append([]stream.Event{ + { + Topic: EventTopicCARoots, + Index: idx, + Payload: EventPayloadCARoots{CARoots: roots}, + }, + }) + return idx, nil } diff --git a/agent/consul/state/connect_ca_events_test.go b/agent/consul/state/connect_ca_events_test.go index 9651e2a470..b5062340a1 100644 --- a/agent/consul/state/connect_ca_events_test.go +++ b/agent/consul/state/connect_ca_events_test.go @@ -51,14 +51,13 @@ func TestCARootsEvents(t *testing.T) { func TestCARootsSnapshot(t *testing.T) { store := testStateStore(t) - fn := caRootsSnapshot((*readDB)(store.db.db)) var req stream.SubscribeRequest t.Run("no roots", func(t *testing.T) { buf := &snapshotAppender{} - idx, err := fn(req, buf) + idx, err := store.CARootsSnapshot(req, buf) require.NoError(t, err) require.Equal(t, uint64(0), idx) @@ -77,7 +76,7 @@ func TestCARootsSnapshot(t *testing.T) { _, err := store.CARootSetCAS(1, 0, structs.CARoots{root}) require.NoError(t, err) - idx, err := fn(req, buf) + idx, err := store.CARootsSnapshot(req, buf) require.NoError(t, err) require.Equal(t, uint64(1), idx) diff --git a/agent/consul/state/memdb.go b/agent/consul/state/memdb.go index 936375eb4d..3edca1438f 100644 --- a/agent/consul/state/memdb.go +++ b/agent/consul/state/memdb.go @@ -1,7 +1,6 @@ package state import ( - "context" "fmt" "github.com/hashicorp/go-memdb" @@ -58,7 +57,7 @@ type changeTrackerDB struct { type EventPublisher interface { Publish([]stream.Event) - Run(context.Context) + RegisterHandler(stream.Topic, stream.SnapshotFunc) error Subscribe(*stream.SubscribeRequest) (*stream.Subscription, error) } @@ -179,8 +178,8 @@ func (db *readDB) ReadTxn() AbortTxn { } var ( - topicServiceHealth = pbsubscribe.Topic_ServiceHealth - topicServiceHealthConnect = pbsubscribe.Topic_ServiceHealthConnect + EventTopicServiceHealth = pbsubscribe.Topic_ServiceHealth + EventTopicServiceHealthConnect = pbsubscribe.Topic_ServiceHealthConnect ) func processDBChanges(tx ReadTxn, changes Changes) ([]stream.Event, error) { @@ -200,11 +199,3 @@ func processDBChanges(tx ReadTxn, changes Changes) ([]stream.Event, error) { } return events, nil } - -func newSnapshotHandlers(db ReadDB) stream.SnapshotHandlers { - return stream.SnapshotHandlers{ - topicServiceHealth: serviceHealthSnapshot(db, topicServiceHealth), - topicServiceHealthConnect: serviceHealthSnapshot(db, topicServiceHealthConnect), - EventTopicCARoots: caRootsSnapshot(db), - } -} diff --git a/agent/consul/state/state_store.go b/agent/consul/state/state_store.go index 39a4371efe..e795b68578 100644 --- a/agent/consul/state/state_store.go +++ b/agent/consul/state/state_store.go @@ -1,10 +1,8 @@ package state import ( - "context" "errors" "fmt" - "time" memdb "github.com/hashicorp/go-memdb" @@ -109,10 +107,6 @@ type Store struct { // abandoned (usually during a restore). This is only ever closed. abandonCh chan struct{} - // TODO: refactor abondonCh to use a context so that both can use the same - // cancel mechanism. - stopEventPublisher func() - // kvsGraveyard manages tombstones for the key value store. kvsGraveyard *Graveyard @@ -159,11 +153,10 @@ func NewStateStore(gc *TombstoneGC) *Store { panic(fmt.Sprintf("failed to create state store: %v", err)) } s := &Store{ - schema: schema, - abandonCh: make(chan struct{}), - kvsGraveyard: NewGraveyard(gc), - lockDelay: NewDelay(), - stopEventPublisher: func() {}, + schema: schema, + abandonCh: make(chan struct{}), + kvsGraveyard: NewGraveyard(gc), + lockDelay: NewDelay(), db: &changeTrackerDB{ db: db, publisher: stream.NoOpEventPublisher{}, @@ -173,24 +166,13 @@ func NewStateStore(gc *TombstoneGC) *Store { return s } -func NewStateStoreWithEventPublisher(gc *TombstoneGC) *Store { +func NewStateStoreWithEventPublisher(gc *TombstoneGC, publisher EventPublisher) *Store { store := NewStateStore(gc) - ctx, cancel := context.WithCancel(context.TODO()) - store.stopEventPublisher = cancel + store.db.publisher = publisher - pub := stream.NewEventPublisher(newSnapshotHandlers((*readDB)(store.db.db)), 10*time.Second) - store.db.publisher = pub - - go pub.Run(ctx) return store } -// EventPublisher returns the stream.EventPublisher used by the Store to -// publish events. -func (s *Store) EventPublisher() EventPublisher { - return s.db.publisher -} - // Snapshot is used to create a point-in-time snapshot of the entire db. func (s *Store) Snapshot() *Snapshot { tx := s.db.Txn(false) @@ -277,11 +259,7 @@ func (s *Store) AbandonCh() <-chan struct{} { // Abandon is used to signal that the given state store has been abandoned. // Calling this more than one time will panic. func (s *Store) Abandon() { - // Note: the order of these operations matters. Subscribers may receive on - // abandonCh to determine whether their subscription was closed because the - // store was abandoned, therefore it's important abandonCh is closed first. close(s.abandonCh) - s.stopEventPublisher() } // maxIndex is a helper used to retrieve the highest known index diff --git a/agent/consul/state/store_integration_test.go b/agent/consul/state/store_integration_test.go index 55c3059ce9..421205e142 100644 --- a/agent/consul/state/store_integration_test.go +++ b/agent/consul/state/store_integration_test.go @@ -32,7 +32,8 @@ func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - publisher := stream.NewEventPublisher(newTestSnapshotHandlers(s), 0) + publisher := stream.NewEventPublisher(0) + registerTestSnapshotHandlers(t, s, publisher) go publisher.Run(ctx) s.db.publisher = publisher sub, err := publisher.Subscribe(subscription) @@ -119,7 +120,8 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - publisher := stream.NewEventPublisher(newTestSnapshotHandlers(s), 0) + publisher := stream.NewEventPublisher(0) + registerTestSnapshotHandlers(t, s, publisher) go publisher.Run(ctx) s.db.publisher = publisher sub, err := publisher.Subscribe(subscription) @@ -240,7 +242,8 @@ func TestStore_IntegrationWithEventPublisher_ACLRoleUpdate(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - publisher := stream.NewEventPublisher(newTestSnapshotHandlers(s), 0) + publisher := stream.NewEventPublisher(0) + registerTestSnapshotHandlers(t, s, publisher) go publisher.Run(ctx) s.db.publisher = publisher sub, err := publisher.Subscribe(subscription) @@ -393,27 +396,29 @@ func (t topic) String() string { var topicService topic = "test-topic-service" -func newTestSnapshotHandlers(s *Store) stream.SnapshotHandlers { - return stream.SnapshotHandlers{ - topicService: func(req stream.SubscribeRequest, snap stream.SnapshotAppender) (uint64, error) { - key := req.Subject.String() +func (s *Store) topicServiceTestHandler(req stream.SubscribeRequest, snap stream.SnapshotAppender) (uint64, error) { + key := req.Subject.String() - idx, nodes, err := s.ServiceNodes(nil, key, nil) - if err != nil { - return idx, err - } - - for _, node := range nodes { - event := stream.Event{ - Topic: req.Topic, - Index: node.ModifyIndex, - Payload: nodePayload{node: node, key: key}, - } - snap.Append([]stream.Event{event}) - } - return idx, nil - }, + idx, nodes, err := s.ServiceNodes(nil, key, nil) + if err != nil { + return idx, err } + + for _, node := range nodes { + event := stream.Event{ + Topic: req.Topic, + Index: node.ModifyIndex, + Payload: nodePayload{node: node, key: key}, + } + snap.Append([]stream.Event{event}) + } + return idx, nil +} + +func registerTestSnapshotHandlers(t *testing.T, s *Store, publisher EventPublisher) { + t.Helper() + err := publisher.RegisterHandler(topicService, s.topicServiceTestHandler) + require.NoError(t, err) } type nodePayload struct { @@ -460,7 +465,8 @@ func createTokenAndWaitForACLEventPublish(t *testing.T, s *Store) *structs.ACLTo ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - publisher := stream.NewEventPublisher(newTestSnapshotHandlers(s), 0) + publisher := stream.NewEventPublisher(0) + registerTestSnapshotHandlers(t, s, publisher) go publisher.Run(ctx) s.db.publisher = publisher diff --git a/agent/consul/stream/event_publisher.go b/agent/consul/stream/event_publisher.go index 06b7b03a27..2cd0564ff7 100644 --- a/agent/consul/stream/event_publisher.go +++ b/agent/consul/stream/event_publisher.go @@ -91,7 +91,7 @@ type SnapshotAppender interface { // A goroutine is run in the background to publish events to all subscribes. // Cancelling the context will shutdown the goroutine, to free resources, // and stop all publishing. -func NewEventPublisher(handlers SnapshotHandlers, snapCacheTTL time.Duration) *EventPublisher { +func NewEventPublisher(snapCacheTTL time.Duration) *EventPublisher { e := &EventPublisher{ snapCacheTTL: snapCacheTTL, topicBuffers: make(map[topicSubject]*topicBuffer), @@ -100,12 +100,41 @@ func NewEventPublisher(handlers SnapshotHandlers, snapCacheTTL time.Duration) *E subscriptions: &subscriptions{ byToken: make(map[string]map[*SubscribeRequest]*Subscription), }, - snapshotHandlers: handlers, + snapshotHandlers: make(map[Topic]SnapshotFunc), } return e } +// RegisterHandler will register a new snapshot handler function. The expectation is +// that all handlers get registered prior to the event publisher being Run. Handler +// registration is therefore not concurrency safe and access to handlers is internally +// not synchronized. +func (e *EventPublisher) RegisterHandler(topic Topic, handler SnapshotFunc) error { + if topic.String() == "" { + return fmt.Errorf("the topic cannnot be empty") + } + + if _, found := e.snapshotHandlers[topic]; found { + return fmt.Errorf("a handler is already registered for the topic: %s", topic.String()) + } + + e.snapshotHandlers[topic] = handler + + return nil +} + +func (e *EventPublisher) RefreshTopic(topic Topic) error { + if _, found := e.snapshotHandlers[topic]; !found { + return fmt.Errorf("topic %s is not registered", topic) + } + + e.forceEvictByTopic(topic) + e.subscriptions.closeAllByTopic(topic) + + return nil +} + // Publish events to all subscribers of the event Topic. The events will be shared // with all subscriptions, so the Payload used in Event.Payload must be immutable. func (e *EventPublisher) Publish(events []Event) { @@ -196,14 +225,14 @@ func (e *EventPublisher) bufferForPublishing(key topicSubject) *eventBuffer { // When the caller is finished with the subscription for any reason, it must // call Subscription.Unsubscribe to free ACL tracking resources. func (e *EventPublisher) Subscribe(req *SubscribeRequest) (*Subscription, error) { + e.lock.Lock() + defer e.lock.Unlock() + handler, ok := e.snapshotHandlers[req.Topic] if !ok || req.Topic == nil { return nil, fmt.Errorf("unknown topic %v", req.Topic) } - e.lock.Lock() - defer e.lock.Unlock() - topicBuf := e.bufferForSubscription(req.topicSubject()) topicBuf.refs++ @@ -327,6 +356,19 @@ func (s *subscriptions) closeAll() { } } +func (s *subscriptions) closeAllByTopic(topic Topic) { + s.lock.Lock() + defer s.lock.Unlock() + + for _, byRequest := range s.byToken { + for _, sub := range byRequest { + if sub.req.Topic == topic { + sub.forceClose() + } + } + } +} + // EventPublisher.lock must be held to call this method. func (e *EventPublisher) getCachedSnapshotLocked(req *SubscribeRequest) *eventSnapshot { snap, ok := e.snapCache[req.topicSubject()] @@ -350,3 +392,15 @@ func (e *EventPublisher) setCachedSnapshotLocked(req *SubscribeRequest, snap *ev delete(e.snapCache, req.topicSubject()) }) } + +// forceEvictByTopic will remove all entries from the snapshot cache for a given topic. +// This method should be called while holding the publishers lock. +func (e *EventPublisher) forceEvictByTopic(topic Topic) { + e.lock.Lock() + for key := range e.snapCache { + if key.Topic == topic.String() { + delete(e.snapCache, key) + } + } + e.lock.Unlock() +} diff --git a/agent/consul/stream/event_publisher_test.go b/agent/consul/stream/event_publisher_test.go index c718d58538..fbd253830d 100644 --- a/agent/consul/stream/event_publisher_test.go +++ b/agent/consul/stream/event_publisher_test.go @@ -27,7 +27,8 @@ func TestEventPublisher_SubscribeWithIndex0(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - publisher := NewEventPublisher(newTestSnapshotHandlers(), 0) + publisher := NewEventPublisher(0) + registerTestSnapshotHandlers(t, publisher) go publisher.Run(ctx) sub, err := publisher.Subscribe(req) @@ -83,16 +84,18 @@ func (p simplePayload) HasReadPermission(acl.Authorizer) bool { func (p simplePayload) Subject() Subject { return stringer(p.key) } -func newTestSnapshotHandlers() SnapshotHandlers { - return SnapshotHandlers{ - testTopic: func(req SubscribeRequest, buf SnapshotAppender) (uint64, error) { - if req.Topic != testTopic { - return 0, fmt.Errorf("unexpected topic: %v", req.Topic) - } - buf.Append([]Event{testSnapshotEvent}) - return 1, nil - }, +func registerTestSnapshotHandlers(t *testing.T, publisher *EventPublisher) { + t.Helper() + + testTopicHandler := func(req SubscribeRequest, buf SnapshotAppender) (uint64, error) { + if req.Topic != testTopic { + return 0, fmt.Errorf("unexpected topic: %v", req.Topic) + } + buf.Append([]Event{testSnapshotEvent}) + return 1, nil } + + require.NoError(t, publisher.RegisterHandler(testTopic, testTopicHandler)) } func runSubscription(ctx context.Context, sub *Subscription) <-chan eventOrErr { @@ -143,14 +146,14 @@ func TestEventPublisher_ShutdownClosesSubscriptions(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) - handlers := newTestSnapshotHandlers() fn := func(req SubscribeRequest, buf SnapshotAppender) (uint64, error) { return 0, nil } - handlers[intTopic(22)] = fn - handlers[intTopic(33)] = fn - publisher := NewEventPublisher(handlers, time.Second) + publisher := NewEventPublisher(time.Second) + registerTestSnapshotHandlers(t, publisher) + publisher.RegisterHandler(intTopic(22), fn) + publisher.RegisterHandler(intTopic(33), fn) go publisher.Run(ctx) sub1, err := publisher.Subscribe(&SubscribeRequest{Topic: intTopic(22), Subject: SubjectNone}) @@ -190,7 +193,8 @@ func TestEventPublisher_SubscribeWithIndex0_FromCache(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - publisher := NewEventPublisher(newTestSnapshotHandlers(), time.Second) + publisher := NewEventPublisher(time.Second) + registerTestSnapshotHandlers(t, publisher) go publisher.Run(ctx) sub, err := publisher.Subscribe(req) @@ -235,7 +239,8 @@ func TestEventPublisher_SubscribeWithIndexNotZero_CanResume(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - publisher := NewEventPublisher(newTestSnapshotHandlers(), time.Second) + publisher := NewEventPublisher(time.Second) + registerTestSnapshotHandlers(t, publisher) go publisher.Run(ctx) simulateExistingSubscriber(t, publisher, req) @@ -288,7 +293,8 @@ func TestEventPublisher_SubscribeWithIndexNotZero_NewSnapshot(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - publisher := NewEventPublisher(newTestSnapshotHandlers(), 0) + publisher := NewEventPublisher(0) + registerTestSnapshotHandlers(t, publisher) go publisher.Run(ctx) // Include the same event in the topicBuffer publisher.publishEvent([]Event{testSnapshotEvent}) @@ -344,7 +350,8 @@ func TestEventPublisher_SubscribeWithIndexNotZero_NewSnapshotFromCache(t *testin ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - publisher := NewEventPublisher(newTestSnapshotHandlers(), time.Second) + publisher := NewEventPublisher(time.Second) + registerTestSnapshotHandlers(t, publisher) go publisher.Run(ctx) simulateExistingSubscriber(t, publisher, req) @@ -417,21 +424,20 @@ func TestEventPublisher_SubscribeWithIndexNotZero_NewSnapshot_WithCache(t *testi Payload: simplePayload{key: "sub-key", value: "event-3"}, } - handlers := SnapshotHandlers{ - testTopic: func(req SubscribeRequest, buf SnapshotAppender) (uint64, error) { - if req.Topic != testTopic { - return 0, fmt.Errorf("unexpected topic: %v", req.Topic) - } - buf.Append([]Event{testSnapshotEvent}) - buf.Append([]Event{nextEvent}) - return 3, nil - }, + testTopicHandler := func(req SubscribeRequest, buf SnapshotAppender) (uint64, error) { + if req.Topic != testTopic { + return 0, fmt.Errorf("unexpected topic: %v", req.Topic) + } + buf.Append([]Event{testSnapshotEvent}) + buf.Append([]Event{nextEvent}) + return 3, nil } ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - publisher := NewEventPublisher(handlers, time.Second) + publisher := NewEventPublisher(time.Second) + publisher.RegisterHandler(testTopic, testTopicHandler) go publisher.Run(ctx) simulateExistingSubscriber(t, publisher, req) @@ -498,7 +504,8 @@ func TestEventPublisher_Unsubscribe_ClosesSubscription(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - publisher := NewEventPublisher(newTestSnapshotHandlers(), time.Second) + publisher := NewEventPublisher(time.Second) + registerTestSnapshotHandlers(t, publisher) sub, err := publisher.Subscribe(req) require.NoError(t, err) @@ -518,7 +525,8 @@ func TestEventPublisher_Unsubscribe_FreesResourcesWhenThereAreNoSubscribers(t *t Subject: stringer("sub-key"), } - publisher := NewEventPublisher(newTestSnapshotHandlers(), time.Second) + publisher := NewEventPublisher(time.Second) + registerTestSnapshotHandlers(t, publisher) sub1, err := publisher.Subscribe(req) require.NoError(t, err) diff --git a/agent/consul/stream/noop.go b/agent/consul/stream/noop.go index 1b3282dbfc..84d6a648da 100644 --- a/agent/consul/stream/noop.go +++ b/agent/consul/stream/noop.go @@ -9,6 +9,10 @@ type NoOpEventPublisher struct{} func (NoOpEventPublisher) Publish([]Event) {} +func (NoOpEventPublisher) RegisterHandler(Topic, SnapshotFunc) error { + return fmt.Errorf("stream event publisher is disabled") +} + func (NoOpEventPublisher) Run(context.Context) {} func (NoOpEventPublisher) Subscribe(*SubscribeRequest) (*Subscription, error) { diff --git a/agent/consul/subscribe_backend.go b/agent/consul/subscribe_backend.go index 94b8671f4f..bddbb2e5fa 100644 --- a/agent/consul/subscribe_backend.go +++ b/agent/consul/subscribe_backend.go @@ -31,5 +31,5 @@ func (s subscribeBackend) Forward(info structs.RPCInfo, f func(*grpc.ClientConn) } func (s subscribeBackend) Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error) { - return s.srv.fsm.State().EventPublisher().Subscribe(req) + return s.srv.publisher.Subscribe(req) } diff --git a/agent/grpc/private/services/subscribe/subscribe_test.go b/agent/grpc/private/services/subscribe/subscribe_test.go index d9d8d162d0..c319590575 100644 --- a/agent/grpc/private/services/subscribe/subscribe_test.go +++ b/agent/grpc/private/services/subscribe/subscribe_test.go @@ -32,8 +32,7 @@ import ( ) func TestServer_Subscribe_KeyIsRequired(t *testing.T) { - backend, err := newTestBackend() - require.NoError(t, err) + backend := newTestBackend(t) addr := runTestServer(t, NewServer(backend, hclog.New(nil))) @@ -59,8 +58,7 @@ func TestServer_Subscribe_KeyIsRequired(t *testing.T) { } func TestServer_Subscribe_IntegrationWithBackend(t *testing.T) { - backend, err := newTestBackend() - require.NoError(t, err) + backend := newTestBackend(t) addr := runTestServer(t, NewServer(backend, hclog.New(nil))) ids := newCounter() @@ -312,6 +310,7 @@ func getEvent(t *testing.T, ch chan eventOrError) *pbsubscribe.Event { } type testBackend struct { + publisher *stream.EventPublisher store *state.Store authorizer func(token string, entMeta *acl.EnterpriseMeta) acl.Authorizer forwardConn *gogrpc.ClientConn @@ -333,19 +332,33 @@ func (b testBackend) Forward(_ structs.RPCInfo, fn func(*gogrpc.ClientConn) erro } func (b testBackend) Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error) { - return b.store.EventPublisher().Subscribe(req) + return b.publisher.Subscribe(req) } -func newTestBackend() (*testBackend, error) { +func newTestBackend(t *testing.T) *testBackend { + t.Helper() gc, err := state.NewTombstoneGC(time.Second, time.Millisecond) - if err != nil { - return nil, err - } - store := state.NewStateStoreWithEventPublisher(gc) + require.NoError(t, err) + + publisher := stream.NewEventPublisher(10 * time.Second) + + store := state.NewStateStoreWithEventPublisher(gc, publisher) + + // normally the handlers are registered on the FSM as state stores may come + // and go during snapshot restores. For the purposes of this test backend though we + // just register them directly to + require.NoError(t, publisher.RegisterHandler(state.EventTopicCARoots, store.CARootsSnapshot)) + require.NoError(t, publisher.RegisterHandler(state.EventTopicServiceHealth, store.ServiceHealthSnapshot)) + require.NoError(t, publisher.RegisterHandler(state.EventTopicServiceHealthConnect, store.ServiceHealthSnapshot)) + + ctx, cancel := context.WithCancel(context.Background()) + go publisher.Run(ctx) + t.Cleanup(cancel) + allowAll := func(string, *acl.EnterpriseMeta) acl.Authorizer { return acl.AllowAll() } - return &testBackend{store: store, authorizer: allowAll}, nil + return &testBackend{publisher: publisher, store: store, authorizer: allowAll} } var _ Backend = (*testBackend)(nil) @@ -409,12 +422,10 @@ func raftIndex(ids *counter, created, modified string) *pbcommon.RaftIndex { } func TestServer_Subscribe_IntegrationWithBackend_ForwardToDC(t *testing.T) { - backendLocal, err := newTestBackend() - require.NoError(t, err) + backendLocal := newTestBackend(t) addrLocal := runTestServer(t, NewServer(backendLocal, hclog.New(nil))) - backendRemoteDC, err := newTestBackend() - require.NoError(t, err) + backendRemoteDC := newTestBackend(t) srvRemoteDC := NewServer(backendRemoteDC, hclog.New(nil)) addrRemoteDC := runTestServer(t, srvRemoteDC) @@ -642,8 +653,7 @@ func TestServer_Subscribe_IntegrationWithBackend_FilterEventsByACLToken(t *testi t.Skip("too slow for -short run") } - backend, err := newTestBackend() - require.NoError(t, err) + backend := newTestBackend(t) addr := runTestServer(t, NewServer(backend, hclog.New(nil))) token := "this-token-is-good" @@ -839,8 +849,7 @@ node "node1" { } func TestServer_Subscribe_IntegrationWithBackend_ACLUpdate(t *testing.T) { - backend, err := newTestBackend() - require.NoError(t, err) + backend := newTestBackend(t) addr := runTestServer(t, NewServer(backend, hclog.New(nil))) token := "this-token-is-good" @@ -1100,12 +1109,12 @@ func newPayloadEvents(items ...stream.Event) *stream.PayloadEvents { func newEventFromSubscription(t *testing.T, index uint64) stream.Event { t.Helper() - handlers := map[stream.Topic]stream.SnapshotFunc{ - pbsubscribe.Topic_ServiceHealthConnect: func(stream.SubscribeRequest, stream.SnapshotAppender) (index uint64, err error) { - return 1, nil - }, + serviceHealthConnectHandler := func(stream.SubscribeRequest, stream.SnapshotAppender) (index uint64, err error) { + return 1, nil } - ep := stream.NewEventPublisher(handlers, 0) + + ep := stream.NewEventPublisher(0) + ep.RegisterHandler(pbsubscribe.Topic_ServiceHealthConnect, serviceHealthConnectHandler) req := &stream.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealthConnect, Subject: stream.SubjectNone, Index: index} sub, err := ep.Subscribe(req) diff --git a/agent/grpc/public/services/connectca/server.go b/agent/grpc/public/services/connectca/server.go index 002f8e3448..86edfdb545 100644 --- a/agent/grpc/public/services/connectca/server.go +++ b/agent/grpc/public/services/connectca/server.go @@ -7,7 +7,7 @@ import ( "github.com/hashicorp/go-memdb" "github.com/hashicorp/consul/acl" - "github.com/hashicorp/consul/agent/consul/state" + "github.com/hashicorp/consul/agent/consul/stream" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/proto-public/pbconnectca" ) @@ -17,13 +17,17 @@ type Server struct { } type Config struct { + Publisher EventPublisher GetStore func() StateStore Logger hclog.Logger ACLResolver ACLResolver } +type EventPublisher interface { + Subscribe(*stream.SubscribeRequest) (*stream.Subscription, error) +} + type StateStore interface { - EventPublisher() state.EventPublisher CAConfig(memdb.WatchSet) (uint64, *structs.CAConfiguration, error) AbandonCh() <-chan struct{} } diff --git a/agent/grpc/public/services/connectca/server_test.go b/agent/grpc/public/services/connectca/server_test.go index 6a4d42fa0f..e74b7c0946 100644 --- a/agent/grpc/public/services/connectca/server_test.go +++ b/agent/grpc/public/services/connectca/server_test.go @@ -3,6 +3,7 @@ package connectca import ( "context" "net" + "sync" "testing" "time" @@ -10,16 +11,66 @@ import ( "google.golang.org/grpc" "github.com/hashicorp/consul/agent/consul/state" + "github.com/hashicorp/consul/agent/consul/stream" "github.com/hashicorp/consul/proto-public/pbconnectca" ) -func testStateStore(t *testing.T) *state.Store { +func testStateStore(t *testing.T, publisher state.EventPublisher) *state.Store { t.Helper() gc, err := state.NewTombstoneGC(time.Second, time.Millisecond) require.NoError(t, err) - return state.NewStateStoreWithEventPublisher(gc) + return state.NewStateStoreWithEventPublisher(gc, publisher) +} + +type FakeFSM struct { + lock sync.Mutex + store *state.Store + publisher *stream.EventPublisher +} + +func newFakeFSM(t *testing.T, publisher *stream.EventPublisher) *FakeFSM { + t.Helper() + + store := testStateStore(t, publisher) + + fsm := FakeFSM{store: store, publisher: publisher} + + // register handlers + publisher.RegisterHandler(state.EventTopicCARoots, func(req stream.SubscribeRequest, buf stream.SnapshotAppender) (uint64, error) { + return fsm.GetStore().CARootsSnapshot(req, buf) + }) + + return &fsm +} + +func (f *FakeFSM) GetStore() *state.Store { + f.lock.Lock() + defer f.lock.Unlock() + return f.store +} + +func (f *FakeFSM) ReplaceStore(store *state.Store) { + f.lock.Lock() + defer f.lock.Unlock() + oldStore := f.store + f.store = store + oldStore.Abandon() + f.publisher.RefreshTopic(state.EventTopicCARoots) +} + +func setupFSMAndPublisher(t *testing.T) (*FakeFSM, state.EventPublisher) { + t.Helper() + publisher := stream.NewEventPublisher(10 * time.Second) + + fsm := newFakeFSM(t, publisher) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + go publisher.Run(ctx) + + return fsm, publisher } func testClient(t *testing.T, server *Server) pbconnectca.ConnectCAServiceClient { diff --git a/agent/grpc/public/services/connectca/watch_roots.go b/agent/grpc/public/services/connectca/watch_roots.go index eeaf2d8c8c..7a7430783f 100644 --- a/agent/grpc/public/services/connectca/watch_roots.go +++ b/agent/grpc/public/services/connectca/watch_roots.go @@ -68,7 +68,7 @@ func (s *Server) serveRoots( } // Start the subscription. - sub, err := store.EventPublisher().Subscribe(&stream.SubscribeRequest{ + sub, err := s.Publisher.Subscribe(&stream.SubscribeRequest{ Topic: state.EventTopicCARoots, Subject: stream.SubjectNone, Token: token, diff --git a/agent/grpc/public/services/connectca/watch_roots_test.go b/agent/grpc/public/services/connectca/watch_roots_test.go index d650a4d132..7bce07e1a1 100644 --- a/agent/grpc/public/services/connectca/watch_roots_test.go +++ b/agent/grpc/public/services/connectca/watch_roots_test.go @@ -13,7 +13,6 @@ import ( "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" - "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-uuid" "github.com/hashicorp/consul/acl" @@ -22,19 +21,20 @@ import ( "github.com/hashicorp/consul/agent/grpc/public/testutils" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/proto-public/pbconnectca" + "github.com/hashicorp/consul/sdk/testutil" ) const testACLToken = "acl-token" func TestWatchRoots_Success(t *testing.T) { - store := testStateStore(t) + fsm, publisher := setupFSMAndPublisher(t) // Set the initial roots and CA configuration. rootA := connect.TestCA(t, nil) - _, err := store.CARootSetCAS(1, 0, structs.CARoots{rootA}) + _, err := fsm.GetStore().CARootSetCAS(1, 0, structs.CARoots{rootA}) require.NoError(t, err) - err = store.CASetConfig(0, &structs.CAConfiguration{ClusterID: "cluster-id"}) + err = fsm.GetStore().CASetConfig(0, &structs.CAConfiguration{ClusterID: "cluster-id"}) require.NoError(t, err) // Mock the ACL Resolver to return an authorizer with `service:write`. @@ -45,8 +45,9 @@ func TestWatchRoots_Success(t *testing.T) { ctx := public.ContextWithToken(context.Background(), testACLToken) server := NewServer(Config{ - GetStore: func() StateStore { return store }, - Logger: hclog.NewNullLogger(), + Publisher: publisher, + GetStore: func() StateStore { return fsm.GetStore() }, + Logger: testutil.Logger(t), ACLResolver: aclResolver, }) @@ -65,7 +66,7 @@ func TestWatchRoots_Success(t *testing.T) { // Rotate the roots. rootB := connect.TestCA(t, nil) - _, err = store.CARootSetCAS(2, 1, structs.CARoots{rootB}) + _, err = fsm.GetStore().CARootSetCAS(2, 1, structs.CARoots{rootB}) require.NoError(t, err) // Expect another event containing the new roots. @@ -77,10 +78,10 @@ func TestWatchRoots_Success(t *testing.T) { } func TestWatchRoots_InvalidACLToken(t *testing.T) { - store := testStateStore(t) + fsm, publisher := setupFSMAndPublisher(t) // Set the initial CA configuration. - err := store.CASetConfig(0, &structs.CAConfiguration{ClusterID: "cluster-id"}) + err := fsm.GetStore().CASetConfig(0, &structs.CAConfiguration{ClusterID: "cluster-id"}) require.NoError(t, err) // Mock the ACL resolver to return ErrNotFound. @@ -91,8 +92,9 @@ func TestWatchRoots_InvalidACLToken(t *testing.T) { ctx := public.ContextWithToken(context.Background(), testACLToken) server := NewServer(Config{ - GetStore: func() StateStore { return store }, - Logger: hclog.NewNullLogger(), + Publisher: publisher, + GetStore: func() StateStore { return fsm.GetStore() }, + Logger: testutil.Logger(t), ACLResolver: aclResolver, }) @@ -108,14 +110,14 @@ func TestWatchRoots_InvalidACLToken(t *testing.T) { } func TestWatchRoots_ACLTokenInvalidated(t *testing.T) { - store := testStateStore(t) + fsm, publisher := setupFSMAndPublisher(t) // Set the initial roots and CA configuration. rootA := connect.TestCA(t, nil) - _, err := store.CARootSetCAS(1, 0, structs.CARoots{rootA}) + _, err := fsm.GetStore().CARootSetCAS(1, 0, structs.CARoots{rootA}) require.NoError(t, err) - err = store.CASetConfig(2, &structs.CAConfiguration{ClusterID: "cluster-id"}) + err = fsm.GetStore().CASetConfig(2, &structs.CAConfiguration{ClusterID: "cluster-id"}) require.NoError(t, err) // Mock the ACL Resolver to return an authorizer with `service:write` the @@ -127,8 +129,9 @@ func TestWatchRoots_ACLTokenInvalidated(t *testing.T) { ctx := public.ContextWithToken(context.Background(), testACLToken) server := NewServer(Config{ - GetStore: func() StateStore { return store }, - Logger: hclog.NewNullLogger(), + Publisher: publisher, + GetStore: func() StateStore { return fsm.GetStore() }, + Logger: testutil.Logger(t), ACLResolver: aclResolver, }) @@ -144,7 +147,7 @@ func TestWatchRoots_ACLTokenInvalidated(t *testing.T) { // Update the ACL token to cause the subscription to be force-closed. accessorID, err := uuid.GenerateUUID() require.NoError(t, err) - err = store.ACLTokenSet(1, &structs.ACLToken{ + err = fsm.GetStore().ACLTokenSet(1, &structs.ACLToken{ AccessorID: accessorID, SecretID: testACLToken, }) @@ -152,7 +155,7 @@ func TestWatchRoots_ACLTokenInvalidated(t *testing.T) { // Update the roots. rootB := connect.TestCA(t, nil) - _, err = store.CARootSetCAS(3, 1, structs.CARoots{rootB}) + _, err = fsm.GetStore().CARootSetCAS(3, 1, structs.CARoots{rootB}) require.NoError(t, err) // Expect the stream to remain open and to receive the new roots. @@ -163,7 +166,7 @@ func TestWatchRoots_ACLTokenInvalidated(t *testing.T) { Return(acl.DenyAll(), nil) // Update the ACL token to cause the subscription to be force-closed. - err = store.ACLTokenSet(1, &structs.ACLToken{ + err = fsm.GetStore().ACLTokenSet(1, &structs.ACLToken{ AccessorID: accessorID, SecretID: testACLToken, }) @@ -175,14 +178,14 @@ func TestWatchRoots_ACLTokenInvalidated(t *testing.T) { } func TestWatchRoots_StateStoreAbandoned(t *testing.T) { - storeA := testStateStore(t) + fsm, publisher := setupFSMAndPublisher(t) // Set the initial roots and CA configuration. rootA := connect.TestCA(t, nil) - _, err := storeA.CARootSetCAS(1, 0, structs.CARoots{rootA}) + _, err := fsm.GetStore().CARootSetCAS(1, 0, structs.CARoots{rootA}) require.NoError(t, err) - err = storeA.CASetConfig(0, &structs.CAConfiguration{ClusterID: "cluster-a"}) + err = fsm.GetStore().CASetConfig(0, &structs.CAConfiguration{ClusterID: "cluster-a"}) require.NoError(t, err) // Mock the ACL Resolver to return an authorizer with `service:write`. @@ -193,8 +196,9 @@ func TestWatchRoots_StateStoreAbandoned(t *testing.T) { ctx := public.ContextWithToken(context.Background(), testACLToken) server := NewServer(Config{ - GetStore: func() StateStore { return storeA }, - Logger: hclog.NewNullLogger(), + Publisher: publisher, + GetStore: func() StateStore { return fsm.GetStore() }, + Logger: testutil.Logger(t), ACLResolver: aclResolver, }) @@ -208,7 +212,7 @@ func TestWatchRoots_StateStoreAbandoned(t *testing.T) { mustGetRoots(t, rspCh) // Simulate a snapshot restore. - storeB := testStateStore(t) + storeB := testStateStore(t, publisher) rootB := connect.TestCA(t, nil) _, err = storeB.CARootSetCAS(1, 0, structs.CARoots{rootB}) @@ -217,9 +221,7 @@ func TestWatchRoots_StateStoreAbandoned(t *testing.T) { err = storeB.CASetConfig(0, &structs.CAConfiguration{ClusterID: "cluster-b"}) require.NoError(t, err) - server.GetStore = func() StateStore { return storeB } - - storeA.Abandon() + fsm.ReplaceStore(storeB) // Expect to get the new store's roots. newRoots := mustGetRoots(t, rspCh) diff --git a/agent/submatview/store_integration_test.go b/agent/submatview/store_integration_test.go index 49cb67677a..e8247b8185 100644 --- a/agent/submatview/store_integration_test.go +++ b/agent/submatview/store_integration_test.go @@ -43,10 +43,8 @@ func TestStore_IntegrationWithBackend(t *testing.T) { } sh := snapshotHandler{producers: producers} - handlers := map[stream.Topic]stream.SnapshotFunc{ - pbsubscribe.Topic_ServiceHealth: sh.Snapshot, - } - pub := stream.NewEventPublisher(handlers, 10*time.Millisecond) + pub := stream.NewEventPublisher(10 * time.Millisecond) + pub.RegisterHandler(pbsubscribe.Topic_ServiceHealth, sh.Snapshot) ctx, cancel := context.WithCancel(context.Background()) defer cancel()