Move to using a shared EventPublisher (#12673)

Previously we had 1 EventPublisher per state.Store. When a state store was closed/abandoned such as during a consul snapshot restore, this had the behavior of force closing subscriptions for that topic and evicting event snapshots from the cache.

The intention of this commit is to keep all that behavior. To that end, the shared EventPublisher now supports the ability to refresh a topic. That will perform the force close + eviction. The FSM upon abandoning the previous state.Store will call RefreshTopic for all the topics with events generated by the state.Store.
This commit is contained in:
Matt Keeler 2022-04-12 09:47:42 -04:00 committed by GitHub
parent a148ae660f
commit 8bad5105b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 413 additions and 257 deletions

View File

@ -12,6 +12,7 @@ import (
"github.com/hashicorp/raft" "github.com/hashicorp/raft"
"github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/logging" "github.com/hashicorp/consul/logging"
) )
@ -56,6 +57,8 @@ type FSM struct {
// Raft side, so doesn't need to lock this. // Raft side, so doesn't need to lock this.
stateLock sync.RWMutex stateLock sync.RWMutex
state *state.Store state *state.Store
publisher *stream.EventPublisher
} }
// New is used to construct a new FSM with a blank state. // New is used to construct a new FSM with a blank state.
@ -77,6 +80,8 @@ type Deps struct {
// NewStateStore will be called once when the FSM is created and again any // NewStateStore will be called once when the FSM is created and again any
// time Restore() is called. // time Restore() is called.
NewStateStore func() *state.Store NewStateStore func() *state.Store
Publisher *stream.EventPublisher
} }
// NewFromDeps creates a new FSM from its dependencies. // NewFromDeps creates a new FSM from its dependencies.
@ -101,6 +106,10 @@ func NewFromDeps(deps Deps) *FSM {
} }
fsm.chunker = raftchunking.NewChunkingFSM(fsm, nil) fsm.chunker = raftchunking.NewChunkingFSM(fsm, nil)
// register the streaming snapshot handlers if an event publisher was provided.
fsm.registerStreamSnapshotHandlers()
return fsm return fsm
} }
@ -204,12 +213,28 @@ func (c *FSM) Restore(old io.ReadCloser) error {
c.stateLock.Lock() c.stateLock.Lock()
stateOld := c.state stateOld := c.state
c.state = stateNew c.state = stateNew
// Tell the EventPublisher to cycle anything watching these topics. Replacement
// of the state store means that indexes could have gone backwards and data changed.
//
// This needs to happen while holding the state lock to ensure its not racey. If we
// did this outside of the locked section closer to where we abandon the old store
// then there would be a possibility for new streams to be opened that would get
// a snapshot from the cache sourced from old data but would be receiving events
// for new data. To prevent that inconsistency we refresh the topics while holding
// the lock which ensures that any subscriptions to topics for FSM generated events
if c.deps.Publisher != nil {
c.deps.Publisher.RefreshTopic(state.EventTopicServiceHealth)
c.deps.Publisher.RefreshTopic(state.EventTopicServiceHealthConnect)
c.deps.Publisher.RefreshTopic(state.EventTopicCARoots)
}
c.stateLock.Unlock() c.stateLock.Unlock()
// Signal that the old state store has been abandoned. This is required // Signal that the old state store has been abandoned. This is required
// because we don't operate on it any more, we just throw it away, so // because we don't operate on it any more, we just throw it away, so
// blocking queries won't see any changes and need to be woken up. // blocking queries won't see any changes and need to be woken up.
stateOld.Abandon() stateOld.Abandon()
return nil return nil
} }
@ -244,3 +269,30 @@ func ReadSnapshot(r io.Reader, handler func(header *SnapshotHeader, msg structs.
} }
} }
} }
func (c *FSM) registerStreamSnapshotHandlers() {
if c.deps.Publisher == nil {
return
}
err := c.deps.Publisher.RegisterHandler(state.EventTopicServiceHealth, func(req stream.SubscribeRequest, buf stream.SnapshotAppender) (uint64, error) {
return c.State().ServiceHealthSnapshot(req, buf)
})
if err != nil {
panic(fmt.Errorf("fatal error encountered registering streaming snapshot handlers: %w", err))
}
err = c.deps.Publisher.RegisterHandler(state.EventTopicServiceHealthConnect, func(req stream.SubscribeRequest, buf stream.SnapshotAppender) (uint64, error) {
return c.State().ServiceHealthSnapshot(req, buf)
})
if err != nil {
panic(fmt.Errorf("fatal error encountered registering streaming snapshot handlers: %w", err))
}
err = c.deps.Publisher.RegisterHandler(state.EventTopicCARoots, func(req stream.SubscribeRequest, buf stream.SnapshotAppender) (uint64, error) {
return c.State().CARootsSnapshot(req, buf)
})
if err != nil {
panic(fmt.Errorf("fatal error encountered registering streaming snapshot handlers: %w", err))
}
}

View File

@ -39,6 +39,7 @@ import (
"github.com/hashicorp/consul/agent/consul/authmethod/ssoauth" "github.com/hashicorp/consul/agent/consul/authmethod/ssoauth"
"github.com/hashicorp/consul/agent/consul/fsm" "github.com/hashicorp/consul/agent/consul/fsm"
"github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/consul/usagemetrics" "github.com/hashicorp/consul/agent/consul/usagemetrics"
"github.com/hashicorp/consul/agent/consul/wanfed" "github.com/hashicorp/consul/agent/consul/wanfed"
agentgrpc "github.com/hashicorp/consul/agent/grpc/private" agentgrpc "github.com/hashicorp/consul/agent/grpc/private"
@ -343,6 +344,12 @@ type Server struct {
// Manager to handle starting/stopping go routines when establishing/revoking raft leadership // Manager to handle starting/stopping go routines when establishing/revoking raft leadership
leaderRoutineManager *routine.Manager leaderRoutineManager *routine.Manager
// publisher is the EventPublisher to be shared amongst various server components. Events from
// modifications to the FSM, autopilot and others will flow through here. If in the future we
// need Events generated outside of the Server and all its components, then we could move
// this into the Deps struct and created it much earlier on.
publisher *stream.EventPublisher
// embedded struct to hold all the enterprise specific data // embedded struct to hold all the enterprise specific data
EnterpriseServer EnterpriseServer
} }
@ -397,6 +404,16 @@ func NewServer(config *Config, flat Deps, publicGRPCServer *grpc.Server) (*Serve
insecureRPCServer = rpc.NewServerWithOpts(rpc.WithServerServiceCallInterceptor(flat.GetNetRPCInterceptorFunc(recorder))) insecureRPCServer = rpc.NewServerWithOpts(rpc.WithServerServiceCallInterceptor(flat.GetNetRPCInterceptorFunc(recorder)))
} }
eventPublisher := stream.NewEventPublisher(10 * time.Second)
fsmDeps := fsm.Deps{
Logger: flat.Logger,
NewStateStore: func() *state.Store {
return state.NewStateStoreWithEventPublisher(gc, eventPublisher)
},
Publisher: eventPublisher,
}
// Create server. // Create server.
s := &Server{ s := &Server{
config: config, config: config,
@ -422,9 +439,12 @@ func NewServer(config *Config, flat Deps, publicGRPCServer *grpc.Server) (*Serve
shutdownCh: shutdownCh, shutdownCh: shutdownCh,
leaderRoutineManager: routine.NewManager(logger.Named(logging.Leader)), leaderRoutineManager: routine.NewManager(logger.Named(logging.Leader)),
aclAuthMethodValidators: authmethod.NewCache(), aclAuthMethodValidators: authmethod.NewCache(),
fsm: newFSMFromConfig(flat.Logger, gc, config), fsm: fsm.NewFromDeps(fsmDeps),
publisher: eventPublisher,
} }
go s.publisher.Run(&lib.StopChannelContext{StopCh: s.shutdownCh})
if s.config.ConnectMeshGatewayWANFederationEnabled { if s.config.ConnectMeshGatewayWANFederationEnabled {
s.gatewayLocator = NewGatewayLocator( s.gatewayLocator = NewGatewayLocator(
s.logger, s.logger,
@ -652,6 +672,7 @@ func NewServer(config *Config, flat Deps, publicGRPCServer *grpc.Server) (*Serve
// Initialize public gRPC server - register services on public gRPC server. // Initialize public gRPC server - register services on public gRPC server.
connectca.NewServer(connectca.Config{ connectca.NewServer(connectca.Config{
Publisher: s.publisher,
GetStore: func() connectca.StateStore { return s.FSM().State() }, GetStore: func() connectca.StateStore { return s.FSM().State() },
Logger: logger.Named("grpc-api.connect-ca"), Logger: logger.Named("grpc-api.connect-ca"),
ACLResolver: plainACLResolver{s.ACLResolver}, ACLResolver: plainACLResolver{s.ACLResolver},
@ -684,21 +705,6 @@ func NewServer(config *Config, flat Deps, publicGRPCServer *grpc.Server) (*Serve
return s, nil return s, nil
} }
func newFSMFromConfig(logger hclog.Logger, gc *state.TombstoneGC, config *Config) *fsm.FSM {
deps := fsm.Deps{Logger: logger}
if config.RPCConfig.EnableStreaming {
deps.NewStateStore = func() *state.Store {
return state.NewStateStoreWithEventPublisher(gc)
}
return fsm.NewFromDeps(deps)
}
deps.NewStateStore = func() *state.Store {
return state.NewStateStore(gc)
}
return fsm.NewFromDeps(deps)
}
func newGRPCHandlerFromConfig(deps Deps, config *Config, s *Server) connHandler { func newGRPCHandlerFromConfig(deps Deps, config *Config, s *Server) connHandler {
register := func(srv *grpc.Server) { register := func(srv *grpc.Server) {
if config.RPCConfig.EnableStreaming { if config.RPCConfig.EnableStreaming {

View File

@ -78,12 +78,11 @@ func (e EventPayloadCheckServiceNode) Subject() stream.Subject {
// serviceHealthSnapshot returns a stream.SnapshotFunc that provides a snapshot // serviceHealthSnapshot returns a stream.SnapshotFunc that provides a snapshot
// of stream.Events that describe the current state of a service health query. // of stream.Events that describe the current state of a service health query.
func serviceHealthSnapshot(db ReadDB, topic stream.Topic) stream.SnapshotFunc { func (s *Store) ServiceHealthSnapshot(req stream.SubscribeRequest, buf stream.SnapshotAppender) (index uint64, err error) {
return func(req stream.SubscribeRequest, buf stream.SnapshotAppender) (index uint64, err error) { tx := s.db.ReadTxn()
tx := db.ReadTxn()
defer tx.Abort() defer tx.Abort()
connect := topic == topicServiceHealthConnect connect := req.Topic == EventTopicServiceHealthConnect
subject, ok := req.Subject.(EventSubjectService) subject, ok := req.Subject.(EventSubjectService)
if !ok { if !ok {
@ -99,7 +98,7 @@ func serviceHealthSnapshot(db ReadDB, topic stream.Topic) stream.SnapshotFunc {
n := nodes[i] n := nodes[i]
event := stream.Event{ event := stream.Event{
Index: idx, Index: idx,
Topic: topic, Topic: req.Topic,
Payload: EventPayloadCheckServiceNode{ Payload: EventPayloadCheckServiceNode{
Op: pbsubscribe.CatalogOp_Register, Op: pbsubscribe.CatalogOp_Register,
Value: &n, Value: &n,
@ -121,7 +120,6 @@ func serviceHealthSnapshot(db ReadDB, topic stream.Topic) stream.SnapshotFunc {
} }
return idx, err return idx, err
}
} }
// TODO: this could use NodeServiceQuery // TODO: this could use NodeServiceQuery
@ -355,7 +353,7 @@ func ServiceHealthEventsFromChanges(tx ReadTxn, changes Changes) ([]stream.Event
for _, sn := range nodes { for _, sn := range nodes {
e := newServiceHealthEventDeregister(changes.Index, sn) e := newServiceHealthEventDeregister(changes.Index, sn)
e.Topic = topicServiceHealthConnect e.Topic = EventTopicServiceHealthConnect
payload := e.Payload.(EventPayloadCheckServiceNode) payload := e.Payload.(EventPayloadCheckServiceNode)
payload.overrideKey = serviceName.Name payload.overrideKey = serviceName.Name
if gatewayName.EnterpriseMeta.NamespaceOrDefault() != serviceName.EnterpriseMeta.NamespaceOrDefault() { if gatewayName.EnterpriseMeta.NamespaceOrDefault() != serviceName.EnterpriseMeta.NamespaceOrDefault() {
@ -388,7 +386,7 @@ func ServiceHealthEventsFromChanges(tx ReadTxn, changes Changes) ([]stream.Event
return nil, err return nil, err
} }
e.Topic = topicServiceHealthConnect e.Topic = EventTopicServiceHealthConnect
payload := e.Payload.(EventPayloadCheckServiceNode) payload := e.Payload.(EventPayloadCheckServiceNode)
payload.overrideKey = serviceName.Name payload.overrideKey = serviceName.Name
if gatewayName.EnterpriseMeta.NamespaceOrDefault() != serviceName.EnterpriseMeta.NamespaceOrDefault() { if gatewayName.EnterpriseMeta.NamespaceOrDefault() != serviceName.EnterpriseMeta.NamespaceOrDefault() {
@ -426,7 +424,7 @@ func isConnectProxyDestinationServiceChange(idx uint64, before, after *structs.S
} }
e := newServiceHealthEventDeregister(idx, before) e := newServiceHealthEventDeregister(idx, before)
e.Topic = topicServiceHealthConnect e.Topic = EventTopicServiceHealthConnect
payload := e.Payload.(EventPayloadCheckServiceNode) payload := e.Payload.(EventPayloadCheckServiceNode)
payload.overrideKey = payload.Value.Service.Proxy.DestinationServiceName payload.overrideKey = payload.Value.Service.Proxy.DestinationServiceName
e.Payload = payload e.Payload = payload
@ -467,7 +465,7 @@ func serviceHealthToConnectEvents(
) ([]stream.Event, error) { ) ([]stream.Event, error) {
var result []stream.Event var result []stream.Event
for _, event := range events { for _, event := range events {
if event.Topic != topicServiceHealth { // event.Topic == topicServiceHealthConnect if event.Topic != EventTopicServiceHealth { // event.Topic == topicServiceHealthConnect
// Skip non-health or any events already emitted to Connect topic // Skip non-health or any events already emitted to Connect topic
continue continue
} }
@ -490,7 +488,7 @@ func connectEventsByServiceKind(tx ReadTxn, origEvent stream.Event) ([]stream.Ev
} }
event := origEvent // shallow copy the event event := origEvent // shallow copy the event
event.Topic = topicServiceHealthConnect event.Topic = EventTopicServiceHealthConnect
if node.Service.Connect.Native { if node.Service.Connect.Native {
return []stream.Event{event}, nil return []stream.Event{event}, nil
@ -527,7 +525,7 @@ func connectEventsByServiceKind(tx ReadTxn, origEvent stream.Event) ([]stream.Ev
} }
func copyEventForService(event stream.Event, service structs.ServiceName) stream.Event { func copyEventForService(event stream.Event, service structs.ServiceName) stream.Event {
event.Topic = topicServiceHealthConnect event.Topic = EventTopicServiceHealthConnect
payload := event.Payload.(EventPayloadCheckServiceNode) payload := event.Payload.(EventPayloadCheckServiceNode)
payload.overrideKey = service.Name payload.overrideKey = service.Name
if payload.Value.Service.EnterpriseMeta.NamespaceOrDefault() != service.EnterpriseMeta.NamespaceOrDefault() { if payload.Value.Service.EnterpriseMeta.NamespaceOrDefault() != service.EnterpriseMeta.NamespaceOrDefault() {
@ -666,7 +664,7 @@ func newServiceHealthEventRegister(
Checks: checks, Checks: checks,
} }
return stream.Event{ return stream.Event{
Topic: topicServiceHealth, Topic: EventTopicServiceHealth,
Index: idx, Index: idx,
Payload: EventPayloadCheckServiceNode{ Payload: EventPayloadCheckServiceNode{
Op: pbsubscribe.CatalogOp_Register, Op: pbsubscribe.CatalogOp_Register,
@ -697,7 +695,7 @@ func newServiceHealthEventDeregister(idx uint64, sn *structs.ServiceNode) stream
} }
return stream.Event{ return stream.Event{
Topic: topicServiceHealth, Topic: EventTopicServiceHealth,
Index: idx, Index: idx,
Payload: EventPayloadCheckServiceNode{ Payload: EventPayloadCheckServiceNode{
Op: pbsubscribe.CatalogOp_Deregister, Op: pbsubscribe.CatalogOp_Deregister,

View File

@ -70,11 +70,10 @@ func TestServiceHealthSnapshot(t *testing.T) {
err = store.EnsureRegistration(counter.Next(), testServiceRegistration(t, "web", regNode2)) err = store.EnsureRegistration(counter.Next(), testServiceRegistration(t, "web", regNode2))
require.NoError(t, err) require.NoError(t, err)
fn := serviceHealthSnapshot((*readDB)(store.db.db), topicServiceHealth)
buf := &snapshotAppender{} buf := &snapshotAppender{}
req := stream.SubscribeRequest{Subject: EventSubjectService{Key: "web"}} req := stream.SubscribeRequest{Topic: EventTopicServiceHealth, Subject: EventSubjectService{Key: "web"}}
idx, err := fn(req, buf) idx, err := store.ServiceHealthSnapshot(req, buf)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, counter.Last(), idx) require.Equal(t, counter.Last(), idx)
@ -147,11 +146,10 @@ func TestServiceHealthSnapshot_ConnectTopic(t *testing.T) {
err = store.EnsureRegistration(counter.Next(), testServiceRegistration(t, "tgate1", regTerminatingGateway)) err = store.EnsureRegistration(counter.Next(), testServiceRegistration(t, "tgate1", regTerminatingGateway))
require.NoError(t, err) require.NoError(t, err)
fn := serviceHealthSnapshot((*readDB)(store.db.db), topicServiceHealthConnect)
buf := &snapshotAppender{} buf := &snapshotAppender{}
req := stream.SubscribeRequest{Subject: EventSubjectService{Key: "web"}, Topic: topicServiceHealthConnect} req := stream.SubscribeRequest{Subject: EventSubjectService{Key: "web"}, Topic: EventTopicServiceHealthConnect}
idx, err := fn(req, buf) idx, err := store.ServiceHealthSnapshot(req, buf)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, counter.Last(), idx) require.Equal(t, counter.Last(), idx)
@ -1743,7 +1741,7 @@ func evServiceTermingGateway(name string) func(e *stream.Event) error {
} }
} }
if e.Topic == topicServiceHealthConnect { if e.Topic == EventTopicServiceHealthConnect {
payload := e.Payload.(EventPayloadCheckServiceNode) payload := e.Payload.(EventPayloadCheckServiceNode)
payload.overrideKey = name payload.overrideKey = name
e.Payload = payload e.Payload = payload
@ -2096,7 +2094,7 @@ func evConnectNative(e *stream.Event) error {
// depending on which topic they are published to and they determine this from // depending on which topic they are published to and they determine this from
// the event. // the event.
func evConnectTopic(e *stream.Event) error { func evConnectTopic(e *stream.Event) error {
e.Topic = topicServiceHealthConnect e.Topic = EventTopicServiceHealthConnect
return nil return nil
} }
@ -2135,7 +2133,7 @@ func evSidecar(e *stream.Event) error {
csn.Checks[1].ServiceName = svc + "_sidecar_proxy" csn.Checks[1].ServiceName = svc + "_sidecar_proxy"
} }
if e.Topic == topicServiceHealthConnect { if e.Topic == EventTopicServiceHealthConnect {
payload := e.Payload.(EventPayloadCheckServiceNode) payload := e.Payload.(EventPayloadCheckServiceNode)
payload.overrideKey = svc payload.overrideKey = svc
e.Payload = payload e.Payload = payload
@ -2238,7 +2236,7 @@ func evRenameService(e *stream.Event) error {
taggedAddr.Address = "240.0.0.2" taggedAddr.Address = "240.0.0.2"
csn.Service.TaggedAddresses[structs.TaggedAddressVirtualIP] = taggedAddr csn.Service.TaggedAddresses[structs.TaggedAddressVirtualIP] = taggedAddr
if e.Topic == topicServiceHealthConnect { if e.Topic == EventTopicServiceHealthConnect {
payload := e.Payload.(EventPayloadCheckServiceNode) payload := e.Payload.(EventPayloadCheckServiceNode)
payload.overrideKey = csn.Service.Proxy.DestinationServiceName payload.overrideKey = csn.Service.Proxy.DestinationServiceName
e.Payload = payload e.Payload = payload
@ -2350,7 +2348,7 @@ func newTestEventServiceHealthRegister(index uint64, nodeNum int, svc string) st
addr := fmt.Sprintf("10.10.%d.%d", nodeNum/256, nodeNum%256) addr := fmt.Sprintf("10.10.%d.%d", nodeNum/256, nodeNum%256)
return stream.Event{ return stream.Event{
Topic: topicServiceHealth, Topic: EventTopicServiceHealth,
Index: index, Index: index,
Payload: EventPayloadCheckServiceNode{ Payload: EventPayloadCheckServiceNode{
Op: pbsubscribe.CatalogOp_Register, Op: pbsubscribe.CatalogOp_Register,
@ -2421,7 +2419,7 @@ func newTestEventServiceHealthRegister(index uint64, nodeNum int, svc string) st
// adding too many options to callers. // adding too many options to callers.
func newTestEventServiceHealthDeregister(index uint64, nodeNum int, svc string) stream.Event { func newTestEventServiceHealthDeregister(index uint64, nodeNum int, svc string) stream.Event {
return stream.Event{ return stream.Event{
Topic: topicServiceHealth, Topic: EventTopicServiceHealth,
Index: index, Index: index,
Payload: EventPayloadCheckServiceNode{ Payload: EventPayloadCheckServiceNode{
Op: pbsubscribe.CatalogOp_Deregister, Op: pbsubscribe.CatalogOp_Deregister,

View File

@ -65,9 +65,8 @@ func caRootsChangeEvents(tx ReadTxn, changes Changes) ([]stream.Event, error) {
// caRootsSnapshot returns a stream.SnapshotFunc that provides a snapshot of // caRootsSnapshot returns a stream.SnapshotFunc that provides a snapshot of
// the current active list of CA Roots. // the current active list of CA Roots.
func caRootsSnapshot(db ReadDB) stream.SnapshotFunc { func (s *Store) CARootsSnapshot(_ stream.SubscribeRequest, buf stream.SnapshotAppender) (uint64, error) {
return func(_ stream.SubscribeRequest, buf stream.SnapshotAppender) (uint64, error) { tx := s.db.ReadTxn()
tx := db.ReadTxn()
defer tx.Abort() defer tx.Abort()
idx, roots, err := caRootsTxn(tx, nil) idx, roots, err := caRootsTxn(tx, nil)
@ -83,5 +82,4 @@ func caRootsSnapshot(db ReadDB) stream.SnapshotFunc {
}, },
}) })
return idx, nil return idx, nil
}
} }

View File

@ -51,14 +51,13 @@ func TestCARootsEvents(t *testing.T) {
func TestCARootsSnapshot(t *testing.T) { func TestCARootsSnapshot(t *testing.T) {
store := testStateStore(t) store := testStateStore(t)
fn := caRootsSnapshot((*readDB)(store.db.db))
var req stream.SubscribeRequest var req stream.SubscribeRequest
t.Run("no roots", func(t *testing.T) { t.Run("no roots", func(t *testing.T) {
buf := &snapshotAppender{} buf := &snapshotAppender{}
idx, err := fn(req, buf) idx, err := store.CARootsSnapshot(req, buf)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, uint64(0), idx) require.Equal(t, uint64(0), idx)
@ -77,7 +76,7 @@ func TestCARootsSnapshot(t *testing.T) {
_, err := store.CARootSetCAS(1, 0, structs.CARoots{root}) _, err := store.CARootSetCAS(1, 0, structs.CARoots{root})
require.NoError(t, err) require.NoError(t, err)
idx, err := fn(req, buf) idx, err := store.CARootsSnapshot(req, buf)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, uint64(1), idx) require.Equal(t, uint64(1), idx)

View File

@ -1,7 +1,6 @@
package state package state
import ( import (
"context"
"fmt" "fmt"
"github.com/hashicorp/go-memdb" "github.com/hashicorp/go-memdb"
@ -58,7 +57,7 @@ type changeTrackerDB struct {
type EventPublisher interface { type EventPublisher interface {
Publish([]stream.Event) Publish([]stream.Event)
Run(context.Context) RegisterHandler(stream.Topic, stream.SnapshotFunc) error
Subscribe(*stream.SubscribeRequest) (*stream.Subscription, error) Subscribe(*stream.SubscribeRequest) (*stream.Subscription, error)
} }
@ -179,8 +178,8 @@ func (db *readDB) ReadTxn() AbortTxn {
} }
var ( var (
topicServiceHealth = pbsubscribe.Topic_ServiceHealth EventTopicServiceHealth = pbsubscribe.Topic_ServiceHealth
topicServiceHealthConnect = pbsubscribe.Topic_ServiceHealthConnect EventTopicServiceHealthConnect = pbsubscribe.Topic_ServiceHealthConnect
) )
func processDBChanges(tx ReadTxn, changes Changes) ([]stream.Event, error) { func processDBChanges(tx ReadTxn, changes Changes) ([]stream.Event, error) {
@ -200,11 +199,3 @@ func processDBChanges(tx ReadTxn, changes Changes) ([]stream.Event, error) {
} }
return events, nil return events, nil
} }
func newSnapshotHandlers(db ReadDB) stream.SnapshotHandlers {
return stream.SnapshotHandlers{
topicServiceHealth: serviceHealthSnapshot(db, topicServiceHealth),
topicServiceHealthConnect: serviceHealthSnapshot(db, topicServiceHealthConnect),
EventTopicCARoots: caRootsSnapshot(db),
}
}

View File

@ -1,10 +1,8 @@
package state package state
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"time"
memdb "github.com/hashicorp/go-memdb" memdb "github.com/hashicorp/go-memdb"
@ -109,10 +107,6 @@ type Store struct {
// abandoned (usually during a restore). This is only ever closed. // abandoned (usually during a restore). This is only ever closed.
abandonCh chan struct{} abandonCh chan struct{}
// TODO: refactor abondonCh to use a context so that both can use the same
// cancel mechanism.
stopEventPublisher func()
// kvsGraveyard manages tombstones for the key value store. // kvsGraveyard manages tombstones for the key value store.
kvsGraveyard *Graveyard kvsGraveyard *Graveyard
@ -163,7 +157,6 @@ func NewStateStore(gc *TombstoneGC) *Store {
abandonCh: make(chan struct{}), abandonCh: make(chan struct{}),
kvsGraveyard: NewGraveyard(gc), kvsGraveyard: NewGraveyard(gc),
lockDelay: NewDelay(), lockDelay: NewDelay(),
stopEventPublisher: func() {},
db: &changeTrackerDB{ db: &changeTrackerDB{
db: db, db: db,
publisher: stream.NoOpEventPublisher{}, publisher: stream.NoOpEventPublisher{},
@ -173,24 +166,13 @@ func NewStateStore(gc *TombstoneGC) *Store {
return s return s
} }
func NewStateStoreWithEventPublisher(gc *TombstoneGC) *Store { func NewStateStoreWithEventPublisher(gc *TombstoneGC, publisher EventPublisher) *Store {
store := NewStateStore(gc) store := NewStateStore(gc)
ctx, cancel := context.WithCancel(context.TODO()) store.db.publisher = publisher
store.stopEventPublisher = cancel
pub := stream.NewEventPublisher(newSnapshotHandlers((*readDB)(store.db.db)), 10*time.Second)
store.db.publisher = pub
go pub.Run(ctx)
return store return store
} }
// EventPublisher returns the stream.EventPublisher used by the Store to
// publish events.
func (s *Store) EventPublisher() EventPublisher {
return s.db.publisher
}
// Snapshot is used to create a point-in-time snapshot of the entire db. // Snapshot is used to create a point-in-time snapshot of the entire db.
func (s *Store) Snapshot() *Snapshot { func (s *Store) Snapshot() *Snapshot {
tx := s.db.Txn(false) tx := s.db.Txn(false)
@ -277,11 +259,7 @@ func (s *Store) AbandonCh() <-chan struct{} {
// Abandon is used to signal that the given state store has been abandoned. // Abandon is used to signal that the given state store has been abandoned.
// Calling this more than one time will panic. // Calling this more than one time will panic.
func (s *Store) Abandon() { func (s *Store) Abandon() {
// Note: the order of these operations matters. Subscribers may receive on
// abandonCh to determine whether their subscription was closed because the
// store was abandoned, therefore it's important abandonCh is closed first.
close(s.abandonCh) close(s.abandonCh)
s.stopEventPublisher()
} }
// maxIndex is a helper used to retrieve the highest known index // maxIndex is a helper used to retrieve the highest known index

View File

@ -32,7 +32,8 @@ func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
publisher := stream.NewEventPublisher(newTestSnapshotHandlers(s), 0) publisher := stream.NewEventPublisher(0)
registerTestSnapshotHandlers(t, s, publisher)
go publisher.Run(ctx) go publisher.Run(ctx)
s.db.publisher = publisher s.db.publisher = publisher
sub, err := publisher.Subscribe(subscription) sub, err := publisher.Subscribe(subscription)
@ -119,7 +120,8 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
publisher := stream.NewEventPublisher(newTestSnapshotHandlers(s), 0) publisher := stream.NewEventPublisher(0)
registerTestSnapshotHandlers(t, s, publisher)
go publisher.Run(ctx) go publisher.Run(ctx)
s.db.publisher = publisher s.db.publisher = publisher
sub, err := publisher.Subscribe(subscription) sub, err := publisher.Subscribe(subscription)
@ -240,7 +242,8 @@ func TestStore_IntegrationWithEventPublisher_ACLRoleUpdate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
publisher := stream.NewEventPublisher(newTestSnapshotHandlers(s), 0) publisher := stream.NewEventPublisher(0)
registerTestSnapshotHandlers(t, s, publisher)
go publisher.Run(ctx) go publisher.Run(ctx)
s.db.publisher = publisher s.db.publisher = publisher
sub, err := publisher.Subscribe(subscription) sub, err := publisher.Subscribe(subscription)
@ -393,9 +396,7 @@ func (t topic) String() string {
var topicService topic = "test-topic-service" var topicService topic = "test-topic-service"
func newTestSnapshotHandlers(s *Store) stream.SnapshotHandlers { func (s *Store) topicServiceTestHandler(req stream.SubscribeRequest, snap stream.SnapshotAppender) (uint64, error) {
return stream.SnapshotHandlers{
topicService: func(req stream.SubscribeRequest, snap stream.SnapshotAppender) (uint64, error) {
key := req.Subject.String() key := req.Subject.String()
idx, nodes, err := s.ServiceNodes(nil, key, nil) idx, nodes, err := s.ServiceNodes(nil, key, nil)
@ -412,8 +413,12 @@ func newTestSnapshotHandlers(s *Store) stream.SnapshotHandlers {
snap.Append([]stream.Event{event}) snap.Append([]stream.Event{event})
} }
return idx, nil return idx, nil
}, }
}
func registerTestSnapshotHandlers(t *testing.T, s *Store, publisher EventPublisher) {
t.Helper()
err := publisher.RegisterHandler(topicService, s.topicServiceTestHandler)
require.NoError(t, err)
} }
type nodePayload struct { type nodePayload struct {
@ -460,7 +465,8 @@ func createTokenAndWaitForACLEventPublish(t *testing.T, s *Store) *structs.ACLTo
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
publisher := stream.NewEventPublisher(newTestSnapshotHandlers(s), 0) publisher := stream.NewEventPublisher(0)
registerTestSnapshotHandlers(t, s, publisher)
go publisher.Run(ctx) go publisher.Run(ctx)
s.db.publisher = publisher s.db.publisher = publisher

View File

@ -91,7 +91,7 @@ type SnapshotAppender interface {
// A goroutine is run in the background to publish events to all subscribes. // A goroutine is run in the background to publish events to all subscribes.
// Cancelling the context will shutdown the goroutine, to free resources, // Cancelling the context will shutdown the goroutine, to free resources,
// and stop all publishing. // and stop all publishing.
func NewEventPublisher(handlers SnapshotHandlers, snapCacheTTL time.Duration) *EventPublisher { func NewEventPublisher(snapCacheTTL time.Duration) *EventPublisher {
e := &EventPublisher{ e := &EventPublisher{
snapCacheTTL: snapCacheTTL, snapCacheTTL: snapCacheTTL,
topicBuffers: make(map[topicSubject]*topicBuffer), topicBuffers: make(map[topicSubject]*topicBuffer),
@ -100,12 +100,41 @@ func NewEventPublisher(handlers SnapshotHandlers, snapCacheTTL time.Duration) *E
subscriptions: &subscriptions{ subscriptions: &subscriptions{
byToken: make(map[string]map[*SubscribeRequest]*Subscription), byToken: make(map[string]map[*SubscribeRequest]*Subscription),
}, },
snapshotHandlers: handlers, snapshotHandlers: make(map[Topic]SnapshotFunc),
} }
return e return e
} }
// RegisterHandler will register a new snapshot handler function. The expectation is
// that all handlers get registered prior to the event publisher being Run. Handler
// registration is therefore not concurrency safe and access to handlers is internally
// not synchronized.
func (e *EventPublisher) RegisterHandler(topic Topic, handler SnapshotFunc) error {
if topic.String() == "" {
return fmt.Errorf("the topic cannnot be empty")
}
if _, found := e.snapshotHandlers[topic]; found {
return fmt.Errorf("a handler is already registered for the topic: %s", topic.String())
}
e.snapshotHandlers[topic] = handler
return nil
}
func (e *EventPublisher) RefreshTopic(topic Topic) error {
if _, found := e.snapshotHandlers[topic]; !found {
return fmt.Errorf("topic %s is not registered", topic)
}
e.forceEvictByTopic(topic)
e.subscriptions.closeAllByTopic(topic)
return nil
}
// Publish events to all subscribers of the event Topic. The events will be shared // Publish events to all subscribers of the event Topic. The events will be shared
// with all subscriptions, so the Payload used in Event.Payload must be immutable. // with all subscriptions, so the Payload used in Event.Payload must be immutable.
func (e *EventPublisher) Publish(events []Event) { func (e *EventPublisher) Publish(events []Event) {
@ -196,14 +225,14 @@ func (e *EventPublisher) bufferForPublishing(key topicSubject) *eventBuffer {
// When the caller is finished with the subscription for any reason, it must // When the caller is finished with the subscription for any reason, it must
// call Subscription.Unsubscribe to free ACL tracking resources. // call Subscription.Unsubscribe to free ACL tracking resources.
func (e *EventPublisher) Subscribe(req *SubscribeRequest) (*Subscription, error) { func (e *EventPublisher) Subscribe(req *SubscribeRequest) (*Subscription, error) {
e.lock.Lock()
defer e.lock.Unlock()
handler, ok := e.snapshotHandlers[req.Topic] handler, ok := e.snapshotHandlers[req.Topic]
if !ok || req.Topic == nil { if !ok || req.Topic == nil {
return nil, fmt.Errorf("unknown topic %v", req.Topic) return nil, fmt.Errorf("unknown topic %v", req.Topic)
} }
e.lock.Lock()
defer e.lock.Unlock()
topicBuf := e.bufferForSubscription(req.topicSubject()) topicBuf := e.bufferForSubscription(req.topicSubject())
topicBuf.refs++ topicBuf.refs++
@ -327,6 +356,19 @@ func (s *subscriptions) closeAll() {
} }
} }
func (s *subscriptions) closeAllByTopic(topic Topic) {
s.lock.Lock()
defer s.lock.Unlock()
for _, byRequest := range s.byToken {
for _, sub := range byRequest {
if sub.req.Topic == topic {
sub.forceClose()
}
}
}
}
// EventPublisher.lock must be held to call this method. // EventPublisher.lock must be held to call this method.
func (e *EventPublisher) getCachedSnapshotLocked(req *SubscribeRequest) *eventSnapshot { func (e *EventPublisher) getCachedSnapshotLocked(req *SubscribeRequest) *eventSnapshot {
snap, ok := e.snapCache[req.topicSubject()] snap, ok := e.snapCache[req.topicSubject()]
@ -350,3 +392,15 @@ func (e *EventPublisher) setCachedSnapshotLocked(req *SubscribeRequest, snap *ev
delete(e.snapCache, req.topicSubject()) delete(e.snapCache, req.topicSubject())
}) })
} }
// forceEvictByTopic will remove all entries from the snapshot cache for a given topic.
// This method should be called while holding the publishers lock.
func (e *EventPublisher) forceEvictByTopic(topic Topic) {
e.lock.Lock()
for key := range e.snapCache {
if key.Topic == topic.String() {
delete(e.snapCache, key)
}
}
e.lock.Unlock()
}

View File

@ -27,7 +27,8 @@ func TestEventPublisher_SubscribeWithIndex0(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
publisher := NewEventPublisher(newTestSnapshotHandlers(), 0) publisher := NewEventPublisher(0)
registerTestSnapshotHandlers(t, publisher)
go publisher.Run(ctx) go publisher.Run(ctx)
sub, err := publisher.Subscribe(req) sub, err := publisher.Subscribe(req)
@ -83,16 +84,18 @@ func (p simplePayload) HasReadPermission(acl.Authorizer) bool {
func (p simplePayload) Subject() Subject { return stringer(p.key) } func (p simplePayload) Subject() Subject { return stringer(p.key) }
func newTestSnapshotHandlers() SnapshotHandlers { func registerTestSnapshotHandlers(t *testing.T, publisher *EventPublisher) {
return SnapshotHandlers{ t.Helper()
testTopic: func(req SubscribeRequest, buf SnapshotAppender) (uint64, error) {
testTopicHandler := func(req SubscribeRequest, buf SnapshotAppender) (uint64, error) {
if req.Topic != testTopic { if req.Topic != testTopic {
return 0, fmt.Errorf("unexpected topic: %v", req.Topic) return 0, fmt.Errorf("unexpected topic: %v", req.Topic)
} }
buf.Append([]Event{testSnapshotEvent}) buf.Append([]Event{testSnapshotEvent})
return 1, nil return 1, nil
},
} }
require.NoError(t, publisher.RegisterHandler(testTopic, testTopicHandler))
} }
func runSubscription(ctx context.Context, sub *Subscription) <-chan eventOrErr { func runSubscription(ctx context.Context, sub *Subscription) <-chan eventOrErr {
@ -143,14 +146,14 @@ func TestEventPublisher_ShutdownClosesSubscriptions(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel) t.Cleanup(cancel)
handlers := newTestSnapshotHandlers()
fn := func(req SubscribeRequest, buf SnapshotAppender) (uint64, error) { fn := func(req SubscribeRequest, buf SnapshotAppender) (uint64, error) {
return 0, nil return 0, nil
} }
handlers[intTopic(22)] = fn
handlers[intTopic(33)] = fn
publisher := NewEventPublisher(handlers, time.Second) publisher := NewEventPublisher(time.Second)
registerTestSnapshotHandlers(t, publisher)
publisher.RegisterHandler(intTopic(22), fn)
publisher.RegisterHandler(intTopic(33), fn)
go publisher.Run(ctx) go publisher.Run(ctx)
sub1, err := publisher.Subscribe(&SubscribeRequest{Topic: intTopic(22), Subject: SubjectNone}) sub1, err := publisher.Subscribe(&SubscribeRequest{Topic: intTopic(22), Subject: SubjectNone})
@ -190,7 +193,8 @@ func TestEventPublisher_SubscribeWithIndex0_FromCache(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
publisher := NewEventPublisher(newTestSnapshotHandlers(), time.Second) publisher := NewEventPublisher(time.Second)
registerTestSnapshotHandlers(t, publisher)
go publisher.Run(ctx) go publisher.Run(ctx)
sub, err := publisher.Subscribe(req) sub, err := publisher.Subscribe(req)
@ -235,7 +239,8 @@ func TestEventPublisher_SubscribeWithIndexNotZero_CanResume(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
publisher := NewEventPublisher(newTestSnapshotHandlers(), time.Second) publisher := NewEventPublisher(time.Second)
registerTestSnapshotHandlers(t, publisher)
go publisher.Run(ctx) go publisher.Run(ctx)
simulateExistingSubscriber(t, publisher, req) simulateExistingSubscriber(t, publisher, req)
@ -288,7 +293,8 @@ func TestEventPublisher_SubscribeWithIndexNotZero_NewSnapshot(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
publisher := NewEventPublisher(newTestSnapshotHandlers(), 0) publisher := NewEventPublisher(0)
registerTestSnapshotHandlers(t, publisher)
go publisher.Run(ctx) go publisher.Run(ctx)
// Include the same event in the topicBuffer // Include the same event in the topicBuffer
publisher.publishEvent([]Event{testSnapshotEvent}) publisher.publishEvent([]Event{testSnapshotEvent})
@ -344,7 +350,8 @@ func TestEventPublisher_SubscribeWithIndexNotZero_NewSnapshotFromCache(t *testin
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
publisher := NewEventPublisher(newTestSnapshotHandlers(), time.Second) publisher := NewEventPublisher(time.Second)
registerTestSnapshotHandlers(t, publisher)
go publisher.Run(ctx) go publisher.Run(ctx)
simulateExistingSubscriber(t, publisher, req) simulateExistingSubscriber(t, publisher, req)
@ -417,21 +424,20 @@ func TestEventPublisher_SubscribeWithIndexNotZero_NewSnapshot_WithCache(t *testi
Payload: simplePayload{key: "sub-key", value: "event-3"}, Payload: simplePayload{key: "sub-key", value: "event-3"},
} }
handlers := SnapshotHandlers{ testTopicHandler := func(req SubscribeRequest, buf SnapshotAppender) (uint64, error) {
testTopic: func(req SubscribeRequest, buf SnapshotAppender) (uint64, error) {
if req.Topic != testTopic { if req.Topic != testTopic {
return 0, fmt.Errorf("unexpected topic: %v", req.Topic) return 0, fmt.Errorf("unexpected topic: %v", req.Topic)
} }
buf.Append([]Event{testSnapshotEvent}) buf.Append([]Event{testSnapshotEvent})
buf.Append([]Event{nextEvent}) buf.Append([]Event{nextEvent})
return 3, nil return 3, nil
},
} }
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
publisher := NewEventPublisher(handlers, time.Second) publisher := NewEventPublisher(time.Second)
publisher.RegisterHandler(testTopic, testTopicHandler)
go publisher.Run(ctx) go publisher.Run(ctx)
simulateExistingSubscriber(t, publisher, req) simulateExistingSubscriber(t, publisher, req)
@ -498,7 +504,8 @@ func TestEventPublisher_Unsubscribe_ClosesSubscription(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel() defer cancel()
publisher := NewEventPublisher(newTestSnapshotHandlers(), time.Second) publisher := NewEventPublisher(time.Second)
registerTestSnapshotHandlers(t, publisher)
sub, err := publisher.Subscribe(req) sub, err := publisher.Subscribe(req)
require.NoError(t, err) require.NoError(t, err)
@ -518,7 +525,8 @@ func TestEventPublisher_Unsubscribe_FreesResourcesWhenThereAreNoSubscribers(t *t
Subject: stringer("sub-key"), Subject: stringer("sub-key"),
} }
publisher := NewEventPublisher(newTestSnapshotHandlers(), time.Second) publisher := NewEventPublisher(time.Second)
registerTestSnapshotHandlers(t, publisher)
sub1, err := publisher.Subscribe(req) sub1, err := publisher.Subscribe(req)
require.NoError(t, err) require.NoError(t, err)

View File

@ -9,6 +9,10 @@ type NoOpEventPublisher struct{}
func (NoOpEventPublisher) Publish([]Event) {} func (NoOpEventPublisher) Publish([]Event) {}
func (NoOpEventPublisher) RegisterHandler(Topic, SnapshotFunc) error {
return fmt.Errorf("stream event publisher is disabled")
}
func (NoOpEventPublisher) Run(context.Context) {} func (NoOpEventPublisher) Run(context.Context) {}
func (NoOpEventPublisher) Subscribe(*SubscribeRequest) (*Subscription, error) { func (NoOpEventPublisher) Subscribe(*SubscribeRequest) (*Subscription, error) {

View File

@ -31,5 +31,5 @@ func (s subscribeBackend) Forward(info structs.RPCInfo, f func(*grpc.ClientConn)
} }
func (s subscribeBackend) Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error) { func (s subscribeBackend) Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error) {
return s.srv.fsm.State().EventPublisher().Subscribe(req) return s.srv.publisher.Subscribe(req)
} }

View File

@ -32,8 +32,7 @@ import (
) )
func TestServer_Subscribe_KeyIsRequired(t *testing.T) { func TestServer_Subscribe_KeyIsRequired(t *testing.T) {
backend, err := newTestBackend() backend := newTestBackend(t)
require.NoError(t, err)
addr := runTestServer(t, NewServer(backend, hclog.New(nil))) addr := runTestServer(t, NewServer(backend, hclog.New(nil)))
@ -59,8 +58,7 @@ func TestServer_Subscribe_KeyIsRequired(t *testing.T) {
} }
func TestServer_Subscribe_IntegrationWithBackend(t *testing.T) { func TestServer_Subscribe_IntegrationWithBackend(t *testing.T) {
backend, err := newTestBackend() backend := newTestBackend(t)
require.NoError(t, err)
addr := runTestServer(t, NewServer(backend, hclog.New(nil))) addr := runTestServer(t, NewServer(backend, hclog.New(nil)))
ids := newCounter() ids := newCounter()
@ -312,6 +310,7 @@ func getEvent(t *testing.T, ch chan eventOrError) *pbsubscribe.Event {
} }
type testBackend struct { type testBackend struct {
publisher *stream.EventPublisher
store *state.Store store *state.Store
authorizer func(token string, entMeta *acl.EnterpriseMeta) acl.Authorizer authorizer func(token string, entMeta *acl.EnterpriseMeta) acl.Authorizer
forwardConn *gogrpc.ClientConn forwardConn *gogrpc.ClientConn
@ -333,19 +332,33 @@ func (b testBackend) Forward(_ structs.RPCInfo, fn func(*gogrpc.ClientConn) erro
} }
func (b testBackend) Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error) { func (b testBackend) Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error) {
return b.store.EventPublisher().Subscribe(req) return b.publisher.Subscribe(req)
} }
func newTestBackend() (*testBackend, error) { func newTestBackend(t *testing.T) *testBackend {
t.Helper()
gc, err := state.NewTombstoneGC(time.Second, time.Millisecond) gc, err := state.NewTombstoneGC(time.Second, time.Millisecond)
if err != nil { require.NoError(t, err)
return nil, err
} publisher := stream.NewEventPublisher(10 * time.Second)
store := state.NewStateStoreWithEventPublisher(gc)
store := state.NewStateStoreWithEventPublisher(gc, publisher)
// normally the handlers are registered on the FSM as state stores may come
// and go during snapshot restores. For the purposes of this test backend though we
// just register them directly to
require.NoError(t, publisher.RegisterHandler(state.EventTopicCARoots, store.CARootsSnapshot))
require.NoError(t, publisher.RegisterHandler(state.EventTopicServiceHealth, store.ServiceHealthSnapshot))
require.NoError(t, publisher.RegisterHandler(state.EventTopicServiceHealthConnect, store.ServiceHealthSnapshot))
ctx, cancel := context.WithCancel(context.Background())
go publisher.Run(ctx)
t.Cleanup(cancel)
allowAll := func(string, *acl.EnterpriseMeta) acl.Authorizer { allowAll := func(string, *acl.EnterpriseMeta) acl.Authorizer {
return acl.AllowAll() return acl.AllowAll()
} }
return &testBackend{store: store, authorizer: allowAll}, nil return &testBackend{publisher: publisher, store: store, authorizer: allowAll}
} }
var _ Backend = (*testBackend)(nil) var _ Backend = (*testBackend)(nil)
@ -409,12 +422,10 @@ func raftIndex(ids *counter, created, modified string) *pbcommon.RaftIndex {
} }
func TestServer_Subscribe_IntegrationWithBackend_ForwardToDC(t *testing.T) { func TestServer_Subscribe_IntegrationWithBackend_ForwardToDC(t *testing.T) {
backendLocal, err := newTestBackend() backendLocal := newTestBackend(t)
require.NoError(t, err)
addrLocal := runTestServer(t, NewServer(backendLocal, hclog.New(nil))) addrLocal := runTestServer(t, NewServer(backendLocal, hclog.New(nil)))
backendRemoteDC, err := newTestBackend() backendRemoteDC := newTestBackend(t)
require.NoError(t, err)
srvRemoteDC := NewServer(backendRemoteDC, hclog.New(nil)) srvRemoteDC := NewServer(backendRemoteDC, hclog.New(nil))
addrRemoteDC := runTestServer(t, srvRemoteDC) addrRemoteDC := runTestServer(t, srvRemoteDC)
@ -642,8 +653,7 @@ func TestServer_Subscribe_IntegrationWithBackend_FilterEventsByACLToken(t *testi
t.Skip("too slow for -short run") t.Skip("too slow for -short run")
} }
backend, err := newTestBackend() backend := newTestBackend(t)
require.NoError(t, err)
addr := runTestServer(t, NewServer(backend, hclog.New(nil))) addr := runTestServer(t, NewServer(backend, hclog.New(nil)))
token := "this-token-is-good" token := "this-token-is-good"
@ -839,8 +849,7 @@ node "node1" {
} }
func TestServer_Subscribe_IntegrationWithBackend_ACLUpdate(t *testing.T) { func TestServer_Subscribe_IntegrationWithBackend_ACLUpdate(t *testing.T) {
backend, err := newTestBackend() backend := newTestBackend(t)
require.NoError(t, err)
addr := runTestServer(t, NewServer(backend, hclog.New(nil))) addr := runTestServer(t, NewServer(backend, hclog.New(nil)))
token := "this-token-is-good" token := "this-token-is-good"
@ -1100,12 +1109,12 @@ func newPayloadEvents(items ...stream.Event) *stream.PayloadEvents {
func newEventFromSubscription(t *testing.T, index uint64) stream.Event { func newEventFromSubscription(t *testing.T, index uint64) stream.Event {
t.Helper() t.Helper()
handlers := map[stream.Topic]stream.SnapshotFunc{ serviceHealthConnectHandler := func(stream.SubscribeRequest, stream.SnapshotAppender) (index uint64, err error) {
pbsubscribe.Topic_ServiceHealthConnect: func(stream.SubscribeRequest, stream.SnapshotAppender) (index uint64, err error) {
return 1, nil return 1, nil
},
} }
ep := stream.NewEventPublisher(handlers, 0)
ep := stream.NewEventPublisher(0)
ep.RegisterHandler(pbsubscribe.Topic_ServiceHealthConnect, serviceHealthConnectHandler)
req := &stream.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealthConnect, Subject: stream.SubjectNone, Index: index} req := &stream.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealthConnect, Subject: stream.SubjectNone, Index: index}
sub, err := ep.Subscribe(req) sub, err := ep.Subscribe(req)

View File

@ -7,7 +7,7 @@ import (
"github.com/hashicorp/go-memdb" "github.com/hashicorp/go-memdb"
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto-public/pbconnectca" "github.com/hashicorp/consul/proto-public/pbconnectca"
) )
@ -17,13 +17,17 @@ type Server struct {
} }
type Config struct { type Config struct {
Publisher EventPublisher
GetStore func() StateStore GetStore func() StateStore
Logger hclog.Logger Logger hclog.Logger
ACLResolver ACLResolver ACLResolver ACLResolver
} }
type EventPublisher interface {
Subscribe(*stream.SubscribeRequest) (*stream.Subscription, error)
}
type StateStore interface { type StateStore interface {
EventPublisher() state.EventPublisher
CAConfig(memdb.WatchSet) (uint64, *structs.CAConfiguration, error) CAConfig(memdb.WatchSet) (uint64, *structs.CAConfiguration, error)
AbandonCh() <-chan struct{} AbandonCh() <-chan struct{}
} }

View File

@ -3,6 +3,7 @@ package connectca
import ( import (
"context" "context"
"net" "net"
"sync"
"testing" "testing"
"time" "time"
@ -10,16 +11,66 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/proto-public/pbconnectca" "github.com/hashicorp/consul/proto-public/pbconnectca"
) )
func testStateStore(t *testing.T) *state.Store { func testStateStore(t *testing.T, publisher state.EventPublisher) *state.Store {
t.Helper() t.Helper()
gc, err := state.NewTombstoneGC(time.Second, time.Millisecond) gc, err := state.NewTombstoneGC(time.Second, time.Millisecond)
require.NoError(t, err) require.NoError(t, err)
return state.NewStateStoreWithEventPublisher(gc) return state.NewStateStoreWithEventPublisher(gc, publisher)
}
type FakeFSM struct {
lock sync.Mutex
store *state.Store
publisher *stream.EventPublisher
}
func newFakeFSM(t *testing.T, publisher *stream.EventPublisher) *FakeFSM {
t.Helper()
store := testStateStore(t, publisher)
fsm := FakeFSM{store: store, publisher: publisher}
// register handlers
publisher.RegisterHandler(state.EventTopicCARoots, func(req stream.SubscribeRequest, buf stream.SnapshotAppender) (uint64, error) {
return fsm.GetStore().CARootsSnapshot(req, buf)
})
return &fsm
}
func (f *FakeFSM) GetStore() *state.Store {
f.lock.Lock()
defer f.lock.Unlock()
return f.store
}
func (f *FakeFSM) ReplaceStore(store *state.Store) {
f.lock.Lock()
defer f.lock.Unlock()
oldStore := f.store
f.store = store
oldStore.Abandon()
f.publisher.RefreshTopic(state.EventTopicCARoots)
}
func setupFSMAndPublisher(t *testing.T) (*FakeFSM, state.EventPublisher) {
t.Helper()
publisher := stream.NewEventPublisher(10 * time.Second)
fsm := newFakeFSM(t, publisher)
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
go publisher.Run(ctx)
return fsm, publisher
} }
func testClient(t *testing.T, server *Server) pbconnectca.ConnectCAServiceClient { func testClient(t *testing.T, server *Server) pbconnectca.ConnectCAServiceClient {

View File

@ -68,7 +68,7 @@ func (s *Server) serveRoots(
} }
// Start the subscription. // Start the subscription.
sub, err := store.EventPublisher().Subscribe(&stream.SubscribeRequest{ sub, err := s.Publisher.Subscribe(&stream.SubscribeRequest{
Topic: state.EventTopicCARoots, Topic: state.EventTopicCARoots,
Subject: stream.SubjectNone, Subject: stream.SubjectNone,
Token: token, Token: token,

View File

@ -13,7 +13,6 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/emptypb"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
@ -22,19 +21,20 @@ import (
"github.com/hashicorp/consul/agent/grpc/public/testutils" "github.com/hashicorp/consul/agent/grpc/public/testutils"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto-public/pbconnectca" "github.com/hashicorp/consul/proto-public/pbconnectca"
"github.com/hashicorp/consul/sdk/testutil"
) )
const testACLToken = "acl-token" const testACLToken = "acl-token"
func TestWatchRoots_Success(t *testing.T) { func TestWatchRoots_Success(t *testing.T) {
store := testStateStore(t) fsm, publisher := setupFSMAndPublisher(t)
// Set the initial roots and CA configuration. // Set the initial roots and CA configuration.
rootA := connect.TestCA(t, nil) rootA := connect.TestCA(t, nil)
_, err := store.CARootSetCAS(1, 0, structs.CARoots{rootA}) _, err := fsm.GetStore().CARootSetCAS(1, 0, structs.CARoots{rootA})
require.NoError(t, err) require.NoError(t, err)
err = store.CASetConfig(0, &structs.CAConfiguration{ClusterID: "cluster-id"}) err = fsm.GetStore().CASetConfig(0, &structs.CAConfiguration{ClusterID: "cluster-id"})
require.NoError(t, err) require.NoError(t, err)
// Mock the ACL Resolver to return an authorizer with `service:write`. // Mock the ACL Resolver to return an authorizer with `service:write`.
@ -45,8 +45,9 @@ func TestWatchRoots_Success(t *testing.T) {
ctx := public.ContextWithToken(context.Background(), testACLToken) ctx := public.ContextWithToken(context.Background(), testACLToken)
server := NewServer(Config{ server := NewServer(Config{
GetStore: func() StateStore { return store }, Publisher: publisher,
Logger: hclog.NewNullLogger(), GetStore: func() StateStore { return fsm.GetStore() },
Logger: testutil.Logger(t),
ACLResolver: aclResolver, ACLResolver: aclResolver,
}) })
@ -65,7 +66,7 @@ func TestWatchRoots_Success(t *testing.T) {
// Rotate the roots. // Rotate the roots.
rootB := connect.TestCA(t, nil) rootB := connect.TestCA(t, nil)
_, err = store.CARootSetCAS(2, 1, structs.CARoots{rootB}) _, err = fsm.GetStore().CARootSetCAS(2, 1, structs.CARoots{rootB})
require.NoError(t, err) require.NoError(t, err)
// Expect another event containing the new roots. // Expect another event containing the new roots.
@ -77,10 +78,10 @@ func TestWatchRoots_Success(t *testing.T) {
} }
func TestWatchRoots_InvalidACLToken(t *testing.T) { func TestWatchRoots_InvalidACLToken(t *testing.T) {
store := testStateStore(t) fsm, publisher := setupFSMAndPublisher(t)
// Set the initial CA configuration. // Set the initial CA configuration.
err := store.CASetConfig(0, &structs.CAConfiguration{ClusterID: "cluster-id"}) err := fsm.GetStore().CASetConfig(0, &structs.CAConfiguration{ClusterID: "cluster-id"})
require.NoError(t, err) require.NoError(t, err)
// Mock the ACL resolver to return ErrNotFound. // Mock the ACL resolver to return ErrNotFound.
@ -91,8 +92,9 @@ func TestWatchRoots_InvalidACLToken(t *testing.T) {
ctx := public.ContextWithToken(context.Background(), testACLToken) ctx := public.ContextWithToken(context.Background(), testACLToken)
server := NewServer(Config{ server := NewServer(Config{
GetStore: func() StateStore { return store }, Publisher: publisher,
Logger: hclog.NewNullLogger(), GetStore: func() StateStore { return fsm.GetStore() },
Logger: testutil.Logger(t),
ACLResolver: aclResolver, ACLResolver: aclResolver,
}) })
@ -108,14 +110,14 @@ func TestWatchRoots_InvalidACLToken(t *testing.T) {
} }
func TestWatchRoots_ACLTokenInvalidated(t *testing.T) { func TestWatchRoots_ACLTokenInvalidated(t *testing.T) {
store := testStateStore(t) fsm, publisher := setupFSMAndPublisher(t)
// Set the initial roots and CA configuration. // Set the initial roots and CA configuration.
rootA := connect.TestCA(t, nil) rootA := connect.TestCA(t, nil)
_, err := store.CARootSetCAS(1, 0, structs.CARoots{rootA}) _, err := fsm.GetStore().CARootSetCAS(1, 0, structs.CARoots{rootA})
require.NoError(t, err) require.NoError(t, err)
err = store.CASetConfig(2, &structs.CAConfiguration{ClusterID: "cluster-id"}) err = fsm.GetStore().CASetConfig(2, &structs.CAConfiguration{ClusterID: "cluster-id"})
require.NoError(t, err) require.NoError(t, err)
// Mock the ACL Resolver to return an authorizer with `service:write` the // Mock the ACL Resolver to return an authorizer with `service:write` the
@ -127,8 +129,9 @@ func TestWatchRoots_ACLTokenInvalidated(t *testing.T) {
ctx := public.ContextWithToken(context.Background(), testACLToken) ctx := public.ContextWithToken(context.Background(), testACLToken)
server := NewServer(Config{ server := NewServer(Config{
GetStore: func() StateStore { return store }, Publisher: publisher,
Logger: hclog.NewNullLogger(), GetStore: func() StateStore { return fsm.GetStore() },
Logger: testutil.Logger(t),
ACLResolver: aclResolver, ACLResolver: aclResolver,
}) })
@ -144,7 +147,7 @@ func TestWatchRoots_ACLTokenInvalidated(t *testing.T) {
// Update the ACL token to cause the subscription to be force-closed. // Update the ACL token to cause the subscription to be force-closed.
accessorID, err := uuid.GenerateUUID() accessorID, err := uuid.GenerateUUID()
require.NoError(t, err) require.NoError(t, err)
err = store.ACLTokenSet(1, &structs.ACLToken{ err = fsm.GetStore().ACLTokenSet(1, &structs.ACLToken{
AccessorID: accessorID, AccessorID: accessorID,
SecretID: testACLToken, SecretID: testACLToken,
}) })
@ -152,7 +155,7 @@ func TestWatchRoots_ACLTokenInvalidated(t *testing.T) {
// Update the roots. // Update the roots.
rootB := connect.TestCA(t, nil) rootB := connect.TestCA(t, nil)
_, err = store.CARootSetCAS(3, 1, structs.CARoots{rootB}) _, err = fsm.GetStore().CARootSetCAS(3, 1, structs.CARoots{rootB})
require.NoError(t, err) require.NoError(t, err)
// Expect the stream to remain open and to receive the new roots. // Expect the stream to remain open and to receive the new roots.
@ -163,7 +166,7 @@ func TestWatchRoots_ACLTokenInvalidated(t *testing.T) {
Return(acl.DenyAll(), nil) Return(acl.DenyAll(), nil)
// Update the ACL token to cause the subscription to be force-closed. // Update the ACL token to cause the subscription to be force-closed.
err = store.ACLTokenSet(1, &structs.ACLToken{ err = fsm.GetStore().ACLTokenSet(1, &structs.ACLToken{
AccessorID: accessorID, AccessorID: accessorID,
SecretID: testACLToken, SecretID: testACLToken,
}) })
@ -175,14 +178,14 @@ func TestWatchRoots_ACLTokenInvalidated(t *testing.T) {
} }
func TestWatchRoots_StateStoreAbandoned(t *testing.T) { func TestWatchRoots_StateStoreAbandoned(t *testing.T) {
storeA := testStateStore(t) fsm, publisher := setupFSMAndPublisher(t)
// Set the initial roots and CA configuration. // Set the initial roots and CA configuration.
rootA := connect.TestCA(t, nil) rootA := connect.TestCA(t, nil)
_, err := storeA.CARootSetCAS(1, 0, structs.CARoots{rootA}) _, err := fsm.GetStore().CARootSetCAS(1, 0, structs.CARoots{rootA})
require.NoError(t, err) require.NoError(t, err)
err = storeA.CASetConfig(0, &structs.CAConfiguration{ClusterID: "cluster-a"}) err = fsm.GetStore().CASetConfig(0, &structs.CAConfiguration{ClusterID: "cluster-a"})
require.NoError(t, err) require.NoError(t, err)
// Mock the ACL Resolver to return an authorizer with `service:write`. // Mock the ACL Resolver to return an authorizer with `service:write`.
@ -193,8 +196,9 @@ func TestWatchRoots_StateStoreAbandoned(t *testing.T) {
ctx := public.ContextWithToken(context.Background(), testACLToken) ctx := public.ContextWithToken(context.Background(), testACLToken)
server := NewServer(Config{ server := NewServer(Config{
GetStore: func() StateStore { return storeA }, Publisher: publisher,
Logger: hclog.NewNullLogger(), GetStore: func() StateStore { return fsm.GetStore() },
Logger: testutil.Logger(t),
ACLResolver: aclResolver, ACLResolver: aclResolver,
}) })
@ -208,7 +212,7 @@ func TestWatchRoots_StateStoreAbandoned(t *testing.T) {
mustGetRoots(t, rspCh) mustGetRoots(t, rspCh)
// Simulate a snapshot restore. // Simulate a snapshot restore.
storeB := testStateStore(t) storeB := testStateStore(t, publisher)
rootB := connect.TestCA(t, nil) rootB := connect.TestCA(t, nil)
_, err = storeB.CARootSetCAS(1, 0, structs.CARoots{rootB}) _, err = storeB.CARootSetCAS(1, 0, structs.CARoots{rootB})
@ -217,9 +221,7 @@ func TestWatchRoots_StateStoreAbandoned(t *testing.T) {
err = storeB.CASetConfig(0, &structs.CAConfiguration{ClusterID: "cluster-b"}) err = storeB.CASetConfig(0, &structs.CAConfiguration{ClusterID: "cluster-b"})
require.NoError(t, err) require.NoError(t, err)
server.GetStore = func() StateStore { return storeB } fsm.ReplaceStore(storeB)
storeA.Abandon()
// Expect to get the new store's roots. // Expect to get the new store's roots.
newRoots := mustGetRoots(t, rspCh) newRoots := mustGetRoots(t, rspCh)

View File

@ -43,10 +43,8 @@ func TestStore_IntegrationWithBackend(t *testing.T) {
} }
sh := snapshotHandler{producers: producers} sh := snapshotHandler{producers: producers}
handlers := map[stream.Topic]stream.SnapshotFunc{ pub := stream.NewEventPublisher(10 * time.Millisecond)
pbsubscribe.Topic_ServiceHealth: sh.Snapshot, pub.RegisterHandler(pbsubscribe.Topic_ServiceHealth, sh.Snapshot)
}
pub := stream.NewEventPublisher(handlers, 10*time.Millisecond)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()