proxycfg: split state into kind-specific types

This commit extracts all the kind-specific logic into handler types and
keeps the generic parts on the state struct. This change should make it
easier to add new kinds and to see each kind's implementation more
clearly.
Daniel Nephin 2020-12-23 18:03:30 -05:00
parent cd05df7157
commit 32c15d9a88
3 changed files with 205 additions and 203 deletions
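The shape of the change is easiest to see in miniature. The sketch below is an editor's illustration, not code from this commit: the names kindHandler, handlerConnectProxy, and state mirror the diff that follows, but the fields, signatures, and behaviour are trimmed-down placeholders. It shows the pattern the commit introduces: each service kind implements initialize and handleUpdate behind a small interface, and the generic state only delegates to whichever handler newState picked for the registration.

package main

import (
	"context"
	"fmt"
)

// Stand-ins for the real proxycfg types (simplified placeholders).
type ConfigSnapshot struct{ Kind string }
type UpdateEvent struct{ CorrelationID string }

// kindHandler mirrors the interface added in this diff: each service kind
// owns its watch setup and its update handling.
type kindHandler interface {
	initialize(ctx context.Context) (ConfigSnapshot, error)
	handleUpdate(ctx context.Context, u UpdateEvent, snap *ConfigSnapshot) error
}

// handlerConnectProxy is one kind-specific implementation.
type handlerConnectProxy struct{}

func (h *handlerConnectProxy) initialize(ctx context.Context) (ConfigSnapshot, error) {
	// The real handler registers cache watches here; the sketch just
	// returns an empty snapshot for its kind.
	return ConfigSnapshot{Kind: "connect-proxy"}, nil
}

func (h *handlerConnectProxy) handleUpdate(ctx context.Context, u UpdateEvent, snap *ConfigSnapshot) error {
	// Kind-specific logic for folding one watch event into the snapshot.
	return nil
}

// state keeps only the generic plumbing and delegates everything
// kind-specific, as state.Watch and state.run do after this change.
type state struct{ handler kindHandler }

func main() {
	s := &state{handler: &handlerConnectProxy{}}
	snap, _ := s.handler.initialize(context.Background())
	_ = s.handler.handleUpdate(context.Background(), UpdateEvent{CorrelationID: "roots"}, &snap)
	fmt.Println("initialized snapshot for kind:", snap.Kind)
}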


@@ -189,14 +189,8 @@ func (m *Manager) ensureProxyServiceLocked(ns *structs.NodeService, token string
 		state.Close()
 	}
 
-	var err error
-	state, err = newState(ns, token)
-	if err != nil {
-		return err
-	}
-
 	// TODO: move to a function that translates ManagerConfig->stateConfig
-	state.stateConfig = stateConfig{
+	stateConfig := stateConfig{
 		logger: m.Logger.With("service_id", sid.String()),
 		cache:  m.Cache,
 		health: m.Health,
@@ -205,7 +199,13 @@ func (m *Manager) ensureProxyServiceLocked(ns *structs.NodeService, token string
 		intentionDefaultAllow: m.IntentionDefaultAllow,
 	}
 	if m.TLSConfigurator != nil {
-		state.serverSNIFn = m.TLSConfigurator.ServerSNI
+		stateConfig.serverSNIFn = m.TLSConfigurator.ServerSNI
+	}
+
+	var err error
+	state, err = newState(ns, token, stateConfig)
+	if err != nil {
+		return err
 	}
 
 	ch, err := state.Watch()


@@ -67,10 +67,9 @@ type stateConfig struct {
 // connect-proxy service. When a proxy registration is changed, the entire state
 // is discarded and a new one created.
 type state struct {
-	// TODO: un-embedd once refactor is complete
-	stateConfig
-	// TODO: un-embed once refactor is complete
-	serviceInstance
+	logger          hclog.Logger
+	serviceInstance serviceInstance
+	handler         kindHandler
 
 	// cancel is set by Watch and called by Close to stop the goroutine started
 	// in Watch.
@@ -136,34 +135,44 @@ func copyProxyConfig(ns *structs.NodeService) (structs.ConnectProxyConfig, error
 //
 // The returned state needs its required dependencies to be set before Watch
 // can be called.
-func newState(ns *structs.NodeService, token string) (*state, error) {
-	switch ns.Kind {
-	case structs.ServiceKindConnectProxy:
-	case structs.ServiceKindTerminatingGateway:
-	case structs.ServiceKindMeshGateway:
-	case structs.ServiceKindIngressGateway:
-	default:
-		return nil, errors.New("not a connect-proxy, terminating-gateway, mesh-gateway, or ingress-gateway")
-	}
+func newState(ns *structs.NodeService, token string, config stateConfig) (*state, error) {
+	// 10 is fairly arbitrary here but allow for the 3 mandatory and a
+	// reasonable number of upstream watches to all deliver their initial
+	// messages in parallel without blocking the cache.Notify loops. It's not a
+	// huge deal if we do for a short period so we don't need to be more
+	// conservative to handle larger numbers of upstreams correctly but gives
+	// some head room for normal operation to be non-blocking in most typical
+	// cases.
+	ch := make(chan cache.UpdateEvent, 10)
 
 	s, err := newServiceInstanceFromNodeService(ns, token)
 	if err != nil {
 		return nil, err
 	}
 
+	var handler kindHandler
+	switch ns.Kind {
+	case structs.ServiceKindConnectProxy:
+		handler = &handlerConnectProxy{stateConfig: config, serviceInstance: s, ch: ch}
+	case structs.ServiceKindTerminatingGateway:
+		config.logger = config.logger.Named(logging.TerminatingGateway)
+		handler = &handlerTerminatingGateway{stateConfig: config, serviceInstance: s, ch: ch}
+	case structs.ServiceKindMeshGateway:
+		config.logger = config.logger.Named(logging.MeshGateway)
+		handler = &handlerMeshGateway{stateConfig: config, serviceInstance: s, ch: ch}
+	case structs.ServiceKindIngressGateway:
+		handler = &handlerIngressGateway{stateConfig: config, serviceInstance: s, ch: ch}
+	default:
+		return nil, errors.New("not a connect-proxy, terminating-gateway, mesh-gateway, or ingress-gateway")
+	}
+
 	return &state{
+		logger:          config.logger.With("proxy", s.proxyID, "kind", s.kind),
 		serviceInstance: s,
-
-		// 10 is fairly arbitrary here but allow for the 3 mandatory and a
-		// reasonable number of upstream watches to all deliver their initial
-		// messages in parallel without blocking the cache.Notify loops. It's not a
-		// huge deal if we do for a short period so we don't need to be more
-		// conservative to handle larger numbers of upstreams correctly but gives
-		// some head room for normal operation to be non-blocking in most typical
-		// cases.
-		ch:     make(chan cache.UpdateEvent, 10),
-		snapCh: make(chan ConfigSnapshot, 1),
-		reqCh:  make(chan chan *ConfigSnapshot, 1),
+		handler:         handler,
+		ch:              ch,
+		snapCh:          make(chan ConfigSnapshot, 1),
+		reqCh:           make(chan chan *ConfigSnapshot, 1),
 	}, nil
 }
@@ -196,6 +205,11 @@ func newServiceInstanceFromNodeService(ns *structs.NodeService, token string) (s
 	}, nil
 }
 
+type kindHandler interface {
+	initialize(ctx context.Context) (ConfigSnapshot, error)
+	handleUpdate(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error
+}
+
 // Watch initialized watches on all necessary cache data for the current proxy
 // registration state and returns a chan to observe updates to the
 // ConfigSnapshot that contains all necessary config state. The chan is closed
@@ -204,8 +218,7 @@ func (s *state) Watch() (<-chan ConfigSnapshot, error) {
 	var ctx context.Context
 	ctx, s.cancel = context.WithCancel(context.Background())
 
-	snap := s.initialConfigSnapshot()
-	err := s.initWatches(ctx, &snap)
+	snap, err := s.handler.initialize(ctx)
 	if err != nil {
 		s.cancel()
 		return nil, err
@@ -224,23 +237,21 @@ func (s *state) Close() error {
 	return nil
 }
 
-// initWatches sets up the watches needed for the particular service
-func (s *state) initWatches(ctx context.Context, snap *ConfigSnapshot) error {
-	switch s.kind {
-	case structs.ServiceKindConnectProxy:
-		return s.initWatchesConnectProxy(ctx, snap)
-	case structs.ServiceKindTerminatingGateway:
-		return s.initWatchesTerminatingGateway(ctx)
-	case structs.ServiceKindMeshGateway:
-		return s.initWatchesMeshGateway(ctx)
-	case structs.ServiceKindIngressGateway:
-		return s.initWatchesIngressGateway(ctx)
-	default:
-		return fmt.Errorf("Unsupported service kind")
-	}
+type handler struct {
+	stateConfig     // TODO: un-embed
+	serviceInstance // TODO: un-embed
+	ch              chan cache.UpdateEvent
 }
 
-func (s *state) watchMeshGateway(ctx context.Context, dc string, upstreamID string) error {
+type handlerMeshGateway handler
+
+type handlerTerminatingGateway handler
+
+type handlerConnectProxy handler
+
+type handlerIngressGateway handler
+
+func (s *handlerUpstreams) watchMeshGateway(ctx context.Context, dc string, upstreamID string) error {
 	return s.cache.Notify(ctx, cachetype.InternalServiceDumpName, &structs.ServiceDumpRequest{
 		Datacenter:   dc,
 		QueryOptions: structs.QueryOptions{Token: s.token},
@@ -251,9 +262,40 @@ func (s *state) watchMeshGateway(ctx context.Context, dc string, upstreamID stri
 	}, "mesh-gateway:"+dc+":"+upstreamID, s.ch)
 }
 
-// initWatchesConnectProxy sets up the watches needed based on current proxy registration
+type handlerUpstreams handler
+
+func (s *handlerUpstreams) watchConnectProxyService(ctx context.Context, correlationId string, target *structs.DiscoveryTarget) error {
+	return s.stateConfig.cache.Notify(ctx, cachetype.HealthServicesName, &structs.ServiceSpecificRequest{
+		Datacenter: target.Datacenter,
+		QueryOptions: structs.QueryOptions{
+			Token:  s.serviceInstance.token,
+			Filter: target.Subset.Filter,
+		},
+		ServiceName: target.Service,
+		Connect:     true,
+		// Note that Identifier doesn't type-prefix for service any more as it's
+		// the default and makes metrics and other things much cleaner. It's
+		// simpler for us if we have the type to make things unambiguous.
+		Source:         *s.stateConfig.source,
+		EnterpriseMeta: *target.GetEnterpriseMetadata(),
+	}, correlationId, s.ch)
+}
+
+// initialize sets up the watches needed based on current proxy registration
 // state.
-func (s *state) initWatchesConnectProxy(ctx context.Context, snap *ConfigSnapshot) error {
+func (s *handlerConnectProxy) initialize(ctx context.Context) (ConfigSnapshot, error) {
+	snap := newConfigSnapshotFromServiceInstance(s.serviceInstance, s.stateConfig)
+	snap.ConnectProxy.DiscoveryChain = make(map[string]*structs.CompiledDiscoveryChain)
+	snap.ConnectProxy.WatchedDiscoveryChains = make(map[string]context.CancelFunc)
+	snap.ConnectProxy.WatchedUpstreams = make(map[string]map[string]context.CancelFunc)
+	snap.ConnectProxy.WatchedUpstreamEndpoints = make(map[string]map[string]structs.CheckServiceNodes)
+	snap.ConnectProxy.WatchedGateways = make(map[string]map[string]context.CancelFunc)
+	snap.ConnectProxy.WatchedGatewayEndpoints = make(map[string]map[string]structs.CheckServiceNodes)
+	snap.ConnectProxy.WatchedServiceChecks = make(map[structs.ServiceID][]structs.CheckType)
+	snap.ConnectProxy.PreparedQueryEndpoints = make(map[string]structs.CheckServiceNodes)
+	snap.ConnectProxy.UpstreamConfig = make(map[string]*structs.Upstream)
+	snap.ConnectProxy.PassthroughUpstreams = make(map[string]ServicePassthroughAddrs)
+
 	// Watch for root changes
 	err := s.cache.Notify(ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{
 		Datacenter: s.source.Datacenter,
@@ -261,7 +303,7 @@ func (s *state) initWatchesConnectProxy(ctx context.Context, snap *ConfigSnapsho
 		Source:     *s.source,
 	}, rootsWatchID, s.ch)
 	if err != nil {
-		return err
+		return snap, err
 	}
 
 	// Watch the leaf cert
@@ -272,7 +314,7 @@ func (s *state) initWatchesConnectProxy(ctx context.Context, snap *ConfigSnapsho
 		EnterpriseMeta: s.proxyID.EnterpriseMeta,
 	}, leafWatchID, s.ch)
 	if err != nil {
-		return err
+		return snap, err
 	}
 
 	// Watch for intention updates
@@ -290,7 +332,7 @@ func (s *state) initWatchesConnectProxy(ctx context.Context, snap *ConfigSnapsho
 		},
 	}, intentionsWatchID, s.ch)
 	if err != nil {
-		return err
+		return snap, err
 	}
 
 	// Watch for service check updates
@@ -299,7 +341,7 @@ func (s *state) initWatchesConnectProxy(ctx context.Context, snap *ConfigSnapsho
 		EnterpriseMeta: s.proxyID.EnterpriseMeta,
 	}, svcChecksWatchIDPrefix+structs.ServiceIDString(s.proxyCfg.DestinationServiceID, &s.proxyID.EnterpriseMeta), s.ch)
 	if err != nil {
-		return err
+		return snap, err
 	}
 
 	// default the namespace to the namespace of this proxy service
@@ -314,7 +356,7 @@ func (s *state) initWatchesConnectProxy(ctx context.Context, snap *ConfigSnapsho
 			EnterpriseMeta: structs.NewEnterpriseMeta(s.proxyID.NamespaceOrEmpty()),
 		}, intentionUpstreamsID, s.ch)
 		if err != nil {
-			return err
+			return snap, err
 		}
 
 		err = s.cache.Notify(ctx, cachetype.ConfigEntryName, &structs.ConfigEntryQuery{
@@ -325,7 +367,7 @@ func (s *state) initWatchesConnectProxy(ctx context.Context, snap *ConfigSnapsho
 			EnterpriseMeta: *structs.DefaultEnterpriseMeta(),
 		}, meshConfigEntryID, s.ch)
 		if err != nil {
-			return err
+			return snap, err
 		}
 	}
@@ -381,7 +423,7 @@ func (s *state) initWatchesConnectProxy(ctx context.Context, snap *ConfigSnapsho
 				Source: *s.source,
 			}, "upstream:"+u.Identifier(), s.ch)
 			if err != nil {
-				return err
+				return snap, err
 			}
 
 		case structs.UpstreamDestTypeService:
@@ -399,17 +441,18 @@ func (s *state) initWatchesConnectProxy(ctx context.Context, snap *ConfigSnapsho
 				OverrideConnectTimeout: cfg.ConnectTimeout(),
 			}, "discovery-chain:"+u.Identifier(), s.ch)
 			if err != nil {
-				return fmt.Errorf("failed to watch discovery chain for %s: %v", u.Identifier(), err)
+				return snap, fmt.Errorf("failed to watch discovery chain for %s: %v", u.Identifier(), err)
 			}
 
 		default:
-			return fmt.Errorf("unknown upstream type: %q", u.DestinationType)
+			return snap, fmt.Errorf("unknown upstream type: %q", u.DestinationType)
 		}
 	}
 
-	return nil
+	return snap, nil
 }
 
-// reducedProxyConfig represents the basic opaque config values that are now
+// reducedUpstreamConfig represents the basic opaque config values that are now
 // managed with the discovery chain but for backwards compatibility reasons
 // should still affect how the proxy is configured.
 //
@@ -430,7 +473,8 @@ func parseReducedUpstreamConfig(m map[string]interface{}) (reducedUpstreamConfig
 }
 
 // initWatchesTerminatingGateway sets up the initial watches needed based on the terminating-gateway registration
-func (s *state) initWatchesTerminatingGateway(ctx context.Context) error {
+func (s *handlerTerminatingGateway) initialize(ctx context.Context) (ConfigSnapshot, error) {
+	snap := newConfigSnapshotFromServiceInstance(s.serviceInstance, s.stateConfig)
 	// Watch for root changes
 	err := s.cache.Notify(ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{
 		Datacenter: s.source.Datacenter,
@@ -438,9 +482,8 @@ func (s *state) initWatchesTerminatingGateway(ctx context.Context) error {
 		Source:     *s.source,
 	}, rootsWatchID, s.ch)
 	if err != nil {
-		s.logger.Named(logging.TerminatingGateway).
-			Error("failed to register watch for root changes", "error", err)
-		return err
+		s.logger.Error("failed to register watch for root changes", "error", err)
+		return snap, err
 	}
 
 	// Watch for the terminating-gateway's linked services
@@ -451,16 +494,29 @@ func (s *state) initWatchesTerminatingGateway(ctx context.Context) error {
 		EnterpriseMeta: s.proxyID.EnterpriseMeta,
 	}, gatewayServicesWatchID, s.ch)
 	if err != nil {
-		s.logger.Named(logging.TerminatingGateway).
-			Error("failed to register watch for linked services", "error", err)
-		return err
+		s.logger.Error("failed to register watch for linked services", "error", err)
+		return snap, err
 	}
 
-	return nil
+	snap.TerminatingGateway.WatchedServices = make(map[structs.ServiceName]context.CancelFunc)
+	snap.TerminatingGateway.WatchedIntentions = make(map[structs.ServiceName]context.CancelFunc)
+	snap.TerminatingGateway.Intentions = make(map[structs.ServiceName]structs.Intentions)
+	snap.TerminatingGateway.WatchedLeaves = make(map[structs.ServiceName]context.CancelFunc)
+	snap.TerminatingGateway.ServiceLeaves = make(map[structs.ServiceName]*structs.IssuedCert)
+	snap.TerminatingGateway.WatchedConfigs = make(map[structs.ServiceName]context.CancelFunc)
+	snap.TerminatingGateway.ServiceConfigs = make(map[structs.ServiceName]*structs.ServiceConfigResponse)
+	snap.TerminatingGateway.WatchedResolvers = make(map[structs.ServiceName]context.CancelFunc)
+	snap.TerminatingGateway.ServiceResolvers = make(map[structs.ServiceName]*structs.ServiceResolverConfigEntry)
+	snap.TerminatingGateway.ServiceResolversSet = make(map[structs.ServiceName]bool)
+	snap.TerminatingGateway.ServiceGroups = make(map[structs.ServiceName]structs.CheckServiceNodes)
+	snap.TerminatingGateway.GatewayServices = make(map[structs.ServiceName]structs.GatewayService)
+	snap.TerminatingGateway.HostnameServices = make(map[structs.ServiceName]structs.CheckServiceNodes)
+	return snap, nil
 }
 
 // initWatchesMeshGateway sets up the watches needed based on the current mesh gateway registration
-func (s *state) initWatchesMeshGateway(ctx context.Context) error {
+func (s *handlerMeshGateway) initialize(ctx context.Context) (ConfigSnapshot, error) {
+	snap := newConfigSnapshotFromServiceInstance(s.serviceInstance, s.stateConfig)
 	// Watch for root changes
 	err := s.cache.Notify(ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{
 		Datacenter: s.source.Datacenter,
@@ -468,7 +524,7 @@ func (s *state) initWatchesMeshGateway(ctx context.Context) error {
 		Source:     *s.source,
 	}, rootsWatchID, s.ch)
 	if err != nil {
-		return err
+		return snap, err
 	}
 
 	// Watch for all services
@@ -480,7 +536,7 @@ func (s *state) initWatchesMeshGateway(ctx context.Context) error {
 	}, serviceListWatchID, s.ch)
 
 	if err != nil {
-		return err
+		return snap, err
 	}
 
 	if s.meta[structs.MetaWANFederationKey] == "1" {
@@ -493,7 +549,7 @@ func (s *state) initWatchesMeshGateway(ctx context.Context) error {
 			Source:       *s.source,
 		}, federationStateListGatewaysWatchID, s.ch)
 		if err != nil {
-			return err
+			return snap, err
 		}
 
 		err = s.health.Notify(ctx, structs.ServiceSpecificRequest{
@@ -502,7 +558,7 @@ func (s *state) initWatchesMeshGateway(ctx context.Context) error {
 			ServiceName:  structs.ConsulServiceName,
 		}, consulServerListWatchID, s.ch)
 		if err != nil {
-			return err
+			return snap, err
 		}
 	}
@@ -515,7 +571,7 @@ func (s *state) initWatchesMeshGateway(ctx context.Context) error {
 		QueryOptions: structs.QueryOptions{Token: s.token, MaxAge: 30 * time.Second},
 	}, datacentersWatchID, s.ch)
 	if err != nil {
-		return err
+		return snap, err
 	}
 
 	// Once we start getting notified about the datacenters we will setup watches on the
@@ -529,17 +585,25 @@ func (s *state) initWatchesMeshGateway(ctx context.Context) error {
 		Kind:           structs.ServiceResolver,
 		EnterpriseMeta: *structs.WildcardEnterpriseMeta(),
 	}, serviceResolversWatchID, s.ch)
 	if err != nil {
 		s.logger.Named(logging.MeshGateway).
 			Error("failed to register watch for service-resolver config entries", "error", err)
-		return err
+		return snap, err
 	}
 
-	return err
+	snap.MeshGateway.WatchedServices = make(map[structs.ServiceName]context.CancelFunc)
+	snap.MeshGateway.WatchedDatacenters = make(map[string]context.CancelFunc)
+	snap.MeshGateway.ServiceGroups = make(map[structs.ServiceName]structs.CheckServiceNodes)
+	snap.MeshGateway.GatewayGroups = make(map[string]structs.CheckServiceNodes)
+	snap.MeshGateway.ServiceResolvers = make(map[structs.ServiceName]*structs.ServiceResolverConfigEntry)
+	snap.MeshGateway.HostnameDatacenters = make(map[string]structs.CheckServiceNodes)
+	// there is no need to initialize the map of service resolvers as we
+	// fully rebuild it every time we get updates
+	return snap, err
 }
 
-func (s *state) initWatchesIngressGateway(ctx context.Context) error {
+func (s *handlerIngressGateway) initialize(ctx context.Context) (ConfigSnapshot, error) {
+	snap := newConfigSnapshotFromServiceInstance(s.serviceInstance, s.stateConfig)
 	// Watch for root changes
 	err := s.cache.Notify(ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{
 		Datacenter: s.source.Datacenter,
@@ -547,7 +611,7 @@ func (s *state) initWatchesIngressGateway(ctx context.Context) error {
 		Source:     *s.source,
 	}, rootsWatchID, s.ch)
 	if err != nil {
-		return err
+		return snap, err
 	}
 
 	// Watch this ingress gateway's config entry
@@ -559,7 +623,7 @@ func (s *state) initWatchesIngressGateway(ctx context.Context) error {
 		EnterpriseMeta: s.proxyID.EnterpriseMeta,
 	}, gatewayConfigWatchID, s.ch)
 	if err != nil {
-		return err
+		return snap, err
 	}
 
 	// Watch the ingress-gateway's list of upstreams
@@ -570,14 +634,21 @@ func (s *state) initWatchesIngressGateway(ctx context.Context) error {
 		EnterpriseMeta: s.proxyID.EnterpriseMeta,
 	}, gatewayServicesWatchID, s.ch)
 	if err != nil {
-		return err
+		return snap, err
 	}
 
-	return nil
+	snap.IngressGateway.WatchedDiscoveryChains = make(map[string]context.CancelFunc)
+	snap.IngressGateway.DiscoveryChain = make(map[string]*structs.CompiledDiscoveryChain)
+	snap.IngressGateway.WatchedUpstreams = make(map[string]map[string]context.CancelFunc)
+	snap.IngressGateway.WatchedUpstreamEndpoints = make(map[string]map[string]structs.CheckServiceNodes)
+	snap.IngressGateway.WatchedGateways = make(map[string]map[string]context.CancelFunc)
+	snap.IngressGateway.WatchedGatewayEndpoints = make(map[string]map[string]structs.CheckServiceNodes)
+	return snap, nil
 }
 
-func (s *state) initialConfigSnapshot() ConfigSnapshot {
-	snap := ConfigSnapshot{
+func newConfigSnapshotFromServiceInstance(s serviceInstance, config stateConfig) ConfigSnapshot {
+	// TODO: use serviceInstance type in ConfigSnapshot
+	return ConfigSnapshot{
 		Kind:    s.kind,
 		Service: s.service,
 		ProxyID: s.proxyID,
@@ -586,56 +657,10 @@ func (s *state) initialConfigSnapshot() ConfigSnapshot {
 		ServiceMeta:     s.meta,
 		TaggedAddresses: s.taggedAddresses,
 		Proxy:           s.proxyCfg,
-		Datacenter:            s.source.Datacenter,
-		ServerSNIFn:           s.serverSNIFn,
-		IntentionDefaultAllow: s.intentionDefaultAllow,
+		Datacenter:            config.source.Datacenter,
+		ServerSNIFn:           config.serverSNIFn,
+		IntentionDefaultAllow: config.intentionDefaultAllow,
 	}
-
-	switch s.kind {
-	case structs.ServiceKindConnectProxy:
-		snap.ConnectProxy.DiscoveryChain = make(map[string]*structs.CompiledDiscoveryChain)
-		snap.ConnectProxy.WatchedDiscoveryChains = make(map[string]context.CancelFunc)
-		snap.ConnectProxy.WatchedUpstreams = make(map[string]map[string]context.CancelFunc)
-		snap.ConnectProxy.WatchedUpstreamEndpoints = make(map[string]map[string]structs.CheckServiceNodes)
-		snap.ConnectProxy.WatchedGateways = make(map[string]map[string]context.CancelFunc)
-		snap.ConnectProxy.WatchedGatewayEndpoints = make(map[string]map[string]structs.CheckServiceNodes)
-		snap.ConnectProxy.WatchedServiceChecks = make(map[structs.ServiceID][]structs.CheckType)
-		snap.ConnectProxy.PreparedQueryEndpoints = make(map[string]structs.CheckServiceNodes)
-		snap.ConnectProxy.UpstreamConfig = make(map[string]*structs.Upstream)
-		snap.ConnectProxy.PassthroughUpstreams = make(map[string]ServicePassthroughAddrs)
-	case structs.ServiceKindTerminatingGateway:
-		snap.TerminatingGateway.WatchedServices = make(map[structs.ServiceName]context.CancelFunc)
-		snap.TerminatingGateway.WatchedIntentions = make(map[structs.ServiceName]context.CancelFunc)
-		snap.TerminatingGateway.Intentions = make(map[structs.ServiceName]structs.Intentions)
-		snap.TerminatingGateway.WatchedLeaves = make(map[structs.ServiceName]context.CancelFunc)
-		snap.TerminatingGateway.ServiceLeaves = make(map[structs.ServiceName]*structs.IssuedCert)
-		snap.TerminatingGateway.WatchedConfigs = make(map[structs.ServiceName]context.CancelFunc)
-		snap.TerminatingGateway.ServiceConfigs = make(map[structs.ServiceName]*structs.ServiceConfigResponse)
-		snap.TerminatingGateway.WatchedResolvers = make(map[structs.ServiceName]context.CancelFunc)
-		snap.TerminatingGateway.ServiceResolvers = make(map[structs.ServiceName]*structs.ServiceResolverConfigEntry)
-		snap.TerminatingGateway.ServiceResolversSet = make(map[structs.ServiceName]bool)
-		snap.TerminatingGateway.ServiceGroups = make(map[structs.ServiceName]structs.CheckServiceNodes)
-		snap.TerminatingGateway.GatewayServices = make(map[structs.ServiceName]structs.GatewayService)
-		snap.TerminatingGateway.HostnameServices = make(map[structs.ServiceName]structs.CheckServiceNodes)
-	case structs.ServiceKindMeshGateway:
-		snap.MeshGateway.WatchedServices = make(map[structs.ServiceName]context.CancelFunc)
-		snap.MeshGateway.WatchedDatacenters = make(map[string]context.CancelFunc)
-		snap.MeshGateway.ServiceGroups = make(map[structs.ServiceName]structs.CheckServiceNodes)
-		snap.MeshGateway.GatewayGroups = make(map[string]structs.CheckServiceNodes)
-		snap.MeshGateway.ServiceResolvers = make(map[structs.ServiceName]*structs.ServiceResolverConfigEntry)
-		snap.MeshGateway.HostnameDatacenters = make(map[string]structs.CheckServiceNodes)
-		// there is no need to initialize the map of service resolvers as we
-		// fully rebuild it every time we get updates
-	case structs.ServiceKindIngressGateway:
-		snap.IngressGateway.WatchedDiscoveryChains = make(map[string]context.CancelFunc)
-		snap.IngressGateway.DiscoveryChain = make(map[string]*structs.CompiledDiscoveryChain)
-		snap.IngressGateway.WatchedUpstreams = make(map[string]map[string]context.CancelFunc)
-		snap.IngressGateway.WatchedUpstreamEndpoints = make(map[string]map[string]structs.CheckServiceNodes)
-		snap.IngressGateway.WatchedGateways = make(map[string]map[string]context.CancelFunc)
-		snap.IngressGateway.WatchedGatewayEndpoints = make(map[string]map[string]structs.CheckServiceNodes)
-	}
-
-	return snap
 }
@@ -659,7 +684,7 @@ func (s *state) run(ctx context.Context, snap *ConfigSnapshot) {
 		case u := <-s.ch:
 			s.logger.Trace("A blocking query returned; handling snapshot update")
 
-			if err := s.handleUpdate(ctx, u, snap); err != nil {
+			if err := s.handler.handleUpdate(ctx, u, snap); err != nil {
 				s.logger.Error("Failed to handle update from watch",
 					"id", u.CorrelationID, "error", err,
 				)
@@ -671,9 +696,7 @@ func (s *state) run(ctx context.Context, snap *ConfigSnapshot) {
 			// etc on future updates.
 			snapCopy, err := snap.Clone()
 			if err != nil {
-				s.logger.Error("Failed to copy config snapshot for proxy",
-					"error", err,
-				)
+				s.logger.Error("Failed to copy config snapshot for proxy", "error", err)
 				continue
 			}
@@ -719,9 +742,7 @@ func (s *state) run(ctx context.Context, snap *ConfigSnapshot) {
 			// etc on future updates.
 			snapCopy, err := snap.Clone()
 			if err != nil {
-				s.logger.Error("Failed to copy config snapshot for proxy",
-					"error", err,
-				)
+				s.logger.Error("Failed to copy config snapshot for proxy", "error", err)
 				continue
 			}
 			replyCh <- snapCopy
@@ -748,22 +769,7 @@ func (s *state) run(ctx context.Context, snap *ConfigSnapshot) {
 	}
 }
 
-func (s *state) handleUpdate(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error {
-	switch s.kind {
-	case structs.ServiceKindConnectProxy:
-		return s.handleUpdateConnectProxy(ctx, u, snap)
-	case structs.ServiceKindTerminatingGateway:
-		return s.handleUpdateTerminatingGateway(ctx, u, snap)
-	case structs.ServiceKindMeshGateway:
-		return s.handleUpdateMeshGateway(ctx, u, snap)
-	case structs.ServiceKindIngressGateway:
-		return s.handleUpdateIngressGateway(ctx, u, snap)
-	default:
-		return fmt.Errorf("Unsupported service kind")
-	}
-}
-
-func (s *state) handleUpdateConnectProxy(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error {
+func (s *handlerConnectProxy) handleUpdate(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error {
 	if u.Err != nil {
 		return fmt.Errorf("error filling agent cache: %v", u.Err)
 	}
@@ -838,7 +844,7 @@ func (s *state) handleUpdateConnectProxy(ctx context.Context, u cache.UpdateEven
 				cfg:         cfg,
 				meshGateway: meshGateway,
 			}
-			err = s.watchDiscoveryChain(ctx, snap, watchOpts)
+			err = (*handlerUpstreams)(s).watchDiscoveryChain(ctx, snap, watchOpts)
 			if err != nil {
 				return fmt.Errorf("failed to watch discovery chain for %s: %v", svc.String(), err)
 			}
@@ -927,12 +933,12 @@ func (s *state) handleUpdateConnectProxy(ctx context.Context, u cache.UpdateEven
 		snap.ConnectProxy.MeshConfigSet = true
 
 	default:
-		return s.handleUpdateUpstreams(ctx, u, snap)
+		return (*handlerUpstreams)(s).handleUpdateUpstreams(ctx, u, snap)
 	}
 	return nil
 }
 
-func (s *state) handleUpdateUpstreams(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error {
+func (s *handlerUpstreams) handleUpdateUpstreams(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error {
 	if u.Err != nil {
 		return fmt.Errorf("error filling agent cache: %v", u.Err)
 	}
@@ -1049,7 +1055,7 @@ func removeColonPrefix(s string) (string, string, bool) {
 	return s[0:idx], s[idx+1:], true
 }
 
-func (s *state) resetWatchesFromChain(
+func (s *handlerUpstreams) resetWatchesFromChain(
 	ctx context.Context,
 	id string,
 	chain *structs.CompiledDiscoveryChain,
@@ -1196,7 +1202,7 @@ type targetWatchOpts struct {
 	entMeta *structs.EnterpriseMeta
 }
 
-func (s *state) watchUpstreamTarget(ctx context.Context, snap *ConfigSnapshotUpstreams, opts targetWatchOpts) error {
+func (s *handlerUpstreams) watchUpstreamTarget(ctx context.Context, snap *ConfigSnapshotUpstreams, opts targetWatchOpts) error {
 	s.logger.Trace("initializing watch of target",
 		"upstream", opts.upstreamID,
 		"chain", opts.service,
@@ -1233,11 +1239,11 @@ func (s *state) watchUpstreamTarget(ctx context.Context, snap *ConfigSnapshotUps
 	return nil
 }
 
-func (s *state) handleUpdateTerminatingGateway(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error {
+func (s *handlerTerminatingGateway) handleUpdate(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error {
 	if u.Err != nil {
 		return fmt.Errorf("error filling agent cache: %v", u.Err)
 	}
-	logger := s.logger.Named(logging.TerminatingGateway)
+	logger := s.logger
 
 	switch {
 	case u.CorrelationID == rootsWatchID:
@@ -1461,7 +1467,7 @@ func (s *state) handleUpdateTerminatingGateway(ctx context.Context, u cache.Upda
 		if len(resp.Nodes) > 0 {
 			snap.TerminatingGateway.ServiceGroups[sn] = resp.Nodes
 			snap.TerminatingGateway.HostnameServices[sn] = hostnameEndpoints(
-				s.logger.Named(logging.TerminatingGateway), snap.Datacenter, resp.Nodes)
+				s.logger, snap.Datacenter, resp.Nodes)
 		}
 
 	// Store leaf cert for watched service
@@ -1519,7 +1525,7 @@ func (s *state) handleUpdateTerminatingGateway(ctx context.Context, u cache.Upda
 	return nil
 }
 
-func (s *state) handleUpdateMeshGateway(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error {
+func (s *handlerMeshGateway) handleUpdate(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error {
 	if u.Err != nil {
 		return fmt.Errorf("error filling agent cache: %v", u.Err)
 	}
@@ -1722,7 +1728,7 @@ func (s *state) handleUpdateMeshGateway(ctx context.Context, u cache.UpdateEvent
 	return nil
 }
 
-func (s *state) handleUpdateIngressGateway(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error {
+func (s *handlerIngressGateway) handleUpdate(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error {
 	if u.Err != nil {
 		return fmt.Errorf("error filling agent cache: %v", u.Err)
 	}
@@ -1770,7 +1776,7 @@ func (s *state) handleUpdateIngressGateway(ctx context.Context, u cache.UpdateEv
 			namespace:  u.DestinationNamespace,
 			datacenter: s.source.Datacenter,
 		}
-		err := s.watchDiscoveryChain(ctx, snap, watchOpts)
+		err := (*handlerUpstreams)(s).watchDiscoveryChain(ctx, snap, watchOpts)
 		if err != nil {
 			return fmt.Errorf("failed to watch discovery chain for %s: %v", u.Identifier(), err)
 		}
@@ -1798,7 +1804,7 @@ func (s *state) handleUpdateIngressGateway(ctx context.Context, u cache.UpdateEv
 		}
 
 	default:
-		return s.handleUpdateUpstreams(ctx, u, snap)
+		return (*handlerUpstreams)(s).handleUpdateUpstreams(ctx, u, snap)
 	}
 
 	return nil
@@ -1831,7 +1837,7 @@ type discoveryChainWatchOpts struct {
 	meshGateway structs.MeshGatewayConfig
 }
 
-func (s *state) watchDiscoveryChain(ctx context.Context, snap *ConfigSnapshot, opts discoveryChainWatchOpts) error {
+func (s *handlerUpstreams) watchDiscoveryChain(ctx context.Context, snap *ConfigSnapshot, opts discoveryChainWatchOpts) error {
 	if _, ok := snap.ConnectProxy.WatchedDiscoveryChains[opts.id]; ok {
 		return nil
 	}
@@ -1865,7 +1871,7 @@ func (s *state) watchDiscoveryChain(ctx context.Context, snap *ConfigSnapshot, o
 	return nil
 }
 
-func (s *state) generateIngressDNSSANs(snap *ConfigSnapshot) []string {
+func (s *handlerIngressGateway) generateIngressDNSSANs(snap *ConfigSnapshot) []string {
 	// Update our leaf cert watch with wildcard entries for our DNS domains as well as any
 	// configured custom hostnames from the service.
 	if !snap.IngressGateway.TLSEnabled {
@@ -1902,7 +1908,7 @@ func (s *state) generateIngressDNSSANs(snap *ConfigSnapshot) []string {
 	return dnsNames
 }
 
-func (s *state) watchIngressLeafCert(ctx context.Context, snap *ConfigSnapshot) error {
+func (s *handlerIngressGateway) watchIngressLeafCert(ctx context.Context, snap *ConfigSnapshot) error {
 	if !snap.IngressGateway.TLSSet || !snap.IngressGateway.HostsSet {
 		return nil
 	}
@@ -1951,12 +1957,13 @@ func (s *state) Changed(ns *structs.NodeService, token string) bool {
 		s.logger.Warn("Failed to parse proxy config and will treat the new service as unchanged")
 	}
 
-	return ns.Kind != s.kind ||
-		s.proxyID != ns.CompoundServiceID() ||
-		s.address != ns.Address ||
-		s.port != ns.Port ||
-		!reflect.DeepEqual(s.proxyCfg, proxyCfg) ||
-		s.token != token
+	i := s.serviceInstance
+	return ns.Kind != i.kind ||
+		i.proxyID != ns.CompoundServiceID() ||
+		i.address != ns.Address ||
+		i.port != ns.Port ||
+		!reflect.DeepEqual(i.proxyCfg, proxyCfg) ||
+		i.token != token
 }
 
 // hostnameEndpoints returns all CheckServiceNodes that have hostnames instead of IPs as the address.


@@ -7,6 +7,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/hashicorp/go-hclog"
 	"github.com/stretchr/testify/require"
 
 	"github.com/hashicorp/consul/agent/cache"
@@ -115,7 +116,7 @@ func TestStateChanged(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			require := require.New(t)
-			state, err := newState(tt.ns, tt.token)
+			state, err := newState(tt.ns, tt.token, stateConfig{logger: hclog.New(nil)})
 			require.NoError(err)
 			otherNS, otherToken := tt.mutate(*tt.ns, tt.token)
 			require.Equal(tt.want, state.Changed(otherNS, otherToken))
@@ -2125,7 +2126,19 @@ func TestState_WatchesAndUpdates(t *testing.T) {
 	for name, tc := range cases {
 		t.Run(name, func(t *testing.T) {
-			state, err := newState(&tc.ns, "")
+			cn := newTestCacheNotifier()
+			state, err := newState(&tc.ns, "", stateConfig{
+				logger: testutil.Logger(t),
+				cache:  cn,
+				health: &health.Client{Cache: cn, CacheName: cachetype.HealthServicesName},
+				source: &structs.QuerySource{
+					Datacenter: tc.sourceDC,
+				},
+				dnsConfig: DNSConfig{
+					Domain:    "consul.",
+					AltDomain: "alt.consul.",
+				},
+			})
 
 			// verify building the initial state worked
 			require.NoError(t, err)
@@ -2134,30 +2147,12 @@ func TestState_WatchesAndUpdates(t *testing.T) {
 			// setup the test logger to use the t.Log
 			state.logger = testutil.Logger(t)
 
-			// setup a new testing cache notifier
-			cn := newTestCacheNotifier()
-			state.cache = cn
-			state.health = &health.Client{Cache: cn, CacheName: cachetype.HealthServicesName}
-
-			// setup the local datacenter information
-			state.source = &structs.QuerySource{
-				Datacenter: tc.sourceDC,
-			}
-
-			state.dnsConfig = DNSConfig{
-				Domain:    "consul.",
-				AltDomain: "alt.consul.",
-			}
-
 			// setup the ctx as initWatches expects this to be there
 			var ctx context.Context
 			ctx, state.cancel = context.WithCancel(context.Background())
 
-			// get the initial configuration snapshot
-			snap := state.initialConfigSnapshot()
-
-			// ensure the initial watch setup did not error
-			require.NoError(t, state.initWatches(ctx, &snap))
+			snap, err := state.handler.initialize(ctx)
+			require.NoError(t, err)
 
 			//--------------------------------------------------------------------
 			//
@@ -2184,7 +2179,7 @@ func TestState_WatchesAndUpdates(t *testing.T) {
 			// therefore we just tell it about the updates
 			for eveIdx, event := range stage.events {
 				require.True(t, t.Run(fmt.Sprintf("update-%d", eveIdx), func(t *testing.T) {
-					require.NoError(t, state.handleUpdate(ctx, event, &snap))
+					require.NoError(t, state.handler.handleUpdate(ctx, event, &snap))
 				}))
 			}