diff --git a/agent/proxycfg/manager.go b/agent/proxycfg/manager.go
index 4858d1f336..d0ac677f1a 100644
--- a/agent/proxycfg/manager.go
+++ b/agent/proxycfg/manager.go
@@ -195,13 +195,15 @@ func (m *Manager) ensureProxyServiceLocked(ns *structs.NodeService, token string
 		return err
 	}
 
-	// Set the necessary dependencies
-	state.logger = m.Logger.With("service_id", sid.String())
-	state.cache = m.Cache
-	state.health = m.Health
-	state.source = m.Source
-	state.dnsConfig = m.DNSConfig
-	state.intentionDefaultAllow = m.IntentionDefaultAllow
+	// TODO: move to a function that translates ManagerConfig->stateConfig
+	state.stateConfig = stateConfig{
+		logger:                m.Logger.With("service_id", sid.String()),
+		cache:                 m.Cache,
+		health:                m.Health,
+		source:                m.Source,
+		dnsConfig:             m.DNSConfig,
+		intentionDefaultAllow: m.IntentionDefaultAllow,
+	}
 	if m.TLSConfigurator != nil {
 		state.serverSNIFn = m.TLSConfigurator.ServerSNI
 	}
diff --git a/agent/proxycfg/state.go b/agent/proxycfg/state.go
index 64bfc424e3..65fe1fb78f 100644
--- a/agent/proxycfg/state.go
+++ b/agent/proxycfg/state.go
@@ -49,16 +49,11 @@ const (
 	intentionUpstreamsID             = "intention-upstreams"
 	meshConfigEntryID                = "mesh"
 	svcChecksWatchIDPrefix           = cachetype.ServiceHTTPChecksName + ":"
-	serviceIDPrefix                  = string(structs.UpstreamDestTypeService) + ":"
 	preparedQueryIDPrefix            = string(structs.UpstreamDestTypePreparedQuery) + ":"
 	defaultPreparedQueryPollInterval = 30 * time.Second
 )
 
-// state holds all the state needed to maintain the config for a registered
-// connect-proxy service. When a proxy registration is changed, the entire state
-// is discarded and a new one created.
-type state struct {
-	// logger, source and cache are required to be set before calling Watch.
+type stateConfig struct {
 	logger                hclog.Logger
 	source                *structs.QuerySource
 	cache                 CacheNotifier
@@ -66,21 +61,21 @@ type state struct {
 	dnsConfig             DNSConfig
 	serverSNIFn           ServerSNIFunc
 	intentionDefaultAllow bool
+}
 
-	// ctx and cancel store the context created during initWatches call
-	ctx    context.Context
+// state holds all the state needed to maintain the config for a registered
+// connect-proxy service. When a proxy registration is changed, the entire state
+// is discarded and a new one created.
+type state struct {
+	// TODO: un-embed once refactor is complete
+	stateConfig
+	// TODO: un-embed once refactor is complete
+	serviceInstance
+
+	// cancel is set by Watch and called by Close to stop the goroutine started
+	// in Watch.
 	cancel func()
 
-	kind            structs.ServiceKind
-	service         string
-	proxyID         structs.ServiceID
-	address         string
-	port            int
-	meta            map[string]string
-	taggedAddresses map[string]structs.ServiceAddress
-	proxyCfg        structs.ConnectProxyConfig
-	token           string
-
 	ch     chan cache.UpdateEvent
 	snapCh chan ConfigSnapshot
 	reqCh  chan chan *ConfigSnapshot
@@ -93,6 +88,18 @@ type DNSConfig struct {
 
 type ServerSNIFunc func(dc, nodeName string) string
 
+type serviceInstance struct {
+	kind            structs.ServiceKind
+	service         string
+	proxyID         structs.ServiceID
+	address         string
+	port            int
+	meta            map[string]string
+	taggedAddresses map[string]structs.ServiceAddress
+	proxyCfg        structs.ConnectProxyConfig
+	token           string
+}
+
 func copyProxyConfig(ns *structs.NodeService) (structs.ConnectProxyConfig, error) {
 	if ns == nil {
 		return structs.ConnectProxyConfig{}, nil
@@ -139,31 +146,13 @@ func newState(ns *structs.NodeService, token string) (*state, error) {
 		return nil, errors.New("not a connect-proxy, terminating-gateway, mesh-gateway, or ingress-gateway")
 	}
 
-	proxyCfg, err := copyProxyConfig(ns)
+	s, err := newServiceInstanceFromNodeService(ns, token)
 	if err != nil {
 		return nil, err
 	}
 
-	taggedAddresses := make(map[string]structs.ServiceAddress)
-	for k, v := range ns.TaggedAddresses {
-		taggedAddresses[k] = v
-	}
-
-	meta := make(map[string]string)
-	for k, v := range ns.Meta {
-		meta[k] = v
-	}
-
 	return &state{
-		kind:            ns.Kind,
-		service:         ns.Service,
-		proxyID:         ns.CompoundServiceID(),
-		address:         ns.Address,
-		port:            ns.Port,
-		meta:            meta,
-		taggedAddresses: taggedAddresses,
-		proxyCfg:        proxyCfg,
-		token:           token,
+		serviceInstance: s,
 
 		// 10 is fairly arbitrary here but allow for the 3 mandatory and a
 		// reasonable number of upstream watches to all deliver their initial
@@ -178,21 +167,51 @@ func newState(ns *structs.NodeService, token string) (*state, error) {
 	}, nil
 }
 
+func newServiceInstanceFromNodeService(ns *structs.NodeService, token string) (serviceInstance, error) {
+	proxyCfg, err := copyProxyConfig(ns)
+	if err != nil {
+		return serviceInstance{}, err
+	}
+
+	taggedAddresses := make(map[string]structs.ServiceAddress)
+	for k, v := range ns.TaggedAddresses {
+		taggedAddresses[k] = v
+	}
+
+	meta := make(map[string]string)
+	for k, v := range ns.Meta {
+		meta[k] = v
+	}
+
+	return serviceInstance{
+		kind:            ns.Kind,
+		service:         ns.Service,
+		proxyID:         ns.CompoundServiceID(),
+		address:         ns.Address,
+		port:            ns.Port,
+		meta:            meta,
+		taggedAddresses: taggedAddresses,
+		proxyCfg:        proxyCfg,
+		token:           token,
+	}, nil
+}
+
 // Watch initialized watches on all necessary cache data for the current proxy
 // registration state and returns a chan to observe updates to the
 // ConfigSnapshot that contains all necessary config state. The chan is closed
 // when the state is Closed.
 func (s *state) Watch() (<-chan ConfigSnapshot, error) {
-	s.ctx, s.cancel = context.WithCancel(context.Background())
+	var ctx context.Context
+	ctx, s.cancel = context.WithCancel(context.Background())
 
 	snap := s.initialConfigSnapshot()
-	err := s.initWatches(&snap)
+	err := s.initWatches(ctx, &snap)
 	if err != nil {
 		s.cancel()
 		return nil, err
 	}
 
-	go s.run(&snap)
+	go s.run(ctx, &snap)
 
 	return s.snapCh, nil
 }
@@ -206,16 +225,16 @@ func (s *state) Close() error {
 }
 
 // initWatches sets up the watches needed for the particular service
-func (s *state) initWatches(snap *ConfigSnapshot) error {
+func (s *state) initWatches(ctx context.Context, snap *ConfigSnapshot) error {
 	switch s.kind {
 	case structs.ServiceKindConnectProxy:
-		return s.initWatchesConnectProxy(snap)
+		return s.initWatchesConnectProxy(ctx, snap)
 	case structs.ServiceKindTerminatingGateway:
-		return s.initWatchesTerminatingGateway()
+		return s.initWatchesTerminatingGateway(ctx)
 	case structs.ServiceKindMeshGateway:
-		return s.initWatchesMeshGateway()
+		return s.initWatchesMeshGateway(ctx)
 	case structs.ServiceKindIngressGateway:
-		return s.initWatchesIngressGateway()
+		return s.initWatchesIngressGateway(ctx)
 	default:
 		return fmt.Errorf("Unsupported service kind")
 	}
@@ -234,9 +253,9 @@ func (s *state) watchMeshGateway(ctx context.Context, dc string, upstreamID stri
 
 // initWatchesConnectProxy sets up the watches needed based on current proxy registration
 // state.
-func (s *state) initWatchesConnectProxy(snap *ConfigSnapshot) error {
+func (s *state) initWatchesConnectProxy(ctx context.Context, snap *ConfigSnapshot) error {
 	// Watch for root changes
-	err := s.cache.Notify(s.ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{
+	err := s.cache.Notify(ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{
 		Datacenter:   s.source.Datacenter,
 		QueryOptions: structs.QueryOptions{Token: s.token},
 		Source:       *s.source,
@@ -246,7 +265,7 @@ func (s *state) initWatchesConnectProxy(snap *ConfigSnapshot) error {
 	}
 
 	// Watch the leaf cert
-	err = s.cache.Notify(s.ctx, cachetype.ConnectCALeafName, &cachetype.ConnectCALeafRequest{
+	err = s.cache.Notify(ctx, cachetype.ConnectCALeafName, &cachetype.ConnectCALeafRequest{
 		Datacenter: s.source.Datacenter,
 		Token:      s.token,
 		Service:    s.proxyCfg.DestinationServiceName,
@@ -257,7 +276,7 @@ func (s *state) initWatchesConnectProxy(snap *ConfigSnapshot) error {
 	}
 
 	// Watch for intention updates
-	err = s.cache.Notify(s.ctx, cachetype.IntentionMatchName, &structs.IntentionQueryRequest{
+	err = s.cache.Notify(ctx, cachetype.IntentionMatchName, &structs.IntentionQueryRequest{
 		Datacenter:   s.source.Datacenter,
 		QueryOptions: structs.QueryOptions{Token: s.token},
 		Match: &structs.IntentionQueryMatch{
@@ -275,7 +294,7 @@ func (s *state) initWatchesConnectProxy(snap *ConfigSnapshot) error {
 	}
 
 	// Watch for service check updates
-	err = s.cache.Notify(s.ctx, cachetype.ServiceHTTPChecksName, &cachetype.ServiceHTTPChecksRequest{
+	err = s.cache.Notify(ctx, cachetype.ServiceHTTPChecksName, &cachetype.ServiceHTTPChecksRequest{
 		ServiceID:      s.proxyCfg.DestinationServiceID,
 		EnterpriseMeta: s.proxyID.EnterpriseMeta,
 	}, svcChecksWatchIDPrefix+structs.ServiceIDString(s.proxyCfg.DestinationServiceID, &s.proxyID.EnterpriseMeta), s.ch)
@@ -288,7 +307,7 @@ func (s *state) initWatchesConnectProxy(snap *ConfigSnapshot) error {
 
 	if s.proxyCfg.Mode == structs.ProxyModeTransparent {
 		// When in transparent proxy we will infer upstreams from intentions with this source
-		err := s.cache.Notify(s.ctx, cachetype.IntentionUpstreamsName, &structs.ServiceSpecificRequest{
+		err := s.cache.Notify(ctx, cachetype.IntentionUpstreamsName, &structs.ServiceSpecificRequest{
 			Datacenter:   s.source.Datacenter,
 			QueryOptions: structs.QueryOptions{Token: s.token},
 			ServiceName:  s.proxyCfg.DestinationServiceName,
@@ -298,7 +317,7 @@ func (s *state) initWatchesConnectProxy(snap *ConfigSnapshot) error {
 			return err
 		}
 
-		err = s.cache.Notify(s.ctx, cachetype.ConfigEntryName, &structs.ConfigEntryQuery{
+		err = s.cache.Notify(ctx, cachetype.ConfigEntryName, &structs.ConfigEntryQuery{
 			Kind:       structs.MeshConfig,
 			Name:       structs.MeshConfigMesh,
 			Datacenter: s.source.Datacenter,
@@ -354,7 +373,7 @@ func (s *state) initWatchesConnectProxy(snap *ConfigSnapshot) error {
 
 		switch u.DestinationType {
 		case structs.UpstreamDestTypePreparedQuery:
-			err = s.cache.Notify(s.ctx, cachetype.PreparedQueryName, &structs.PreparedQueryExecuteRequest{
+			err = s.cache.Notify(ctx, cachetype.PreparedQueryName, &structs.PreparedQueryExecuteRequest{
 				Datacenter:    dc,
 				QueryOptions:  structs.QueryOptions{Token: s.token, MaxAge: defaultPreparedQueryPollInterval},
 				QueryIDOrName: u.DestinationName,
@@ -369,7 +388,7 @@ func (s *state) initWatchesConnectProxy(snap *ConfigSnapshot) error {
 			fallthrough
 
 		case "": // Treat unset as the default Service type
-			err = s.cache.Notify(s.ctx, cachetype.CompiledDiscoveryChainName, &structs.DiscoveryChainRequest{
+			err = s.cache.Notify(ctx, cachetype.CompiledDiscoveryChainName, &structs.DiscoveryChainRequest{
 				Datacenter:   s.source.Datacenter,
 				QueryOptions: structs.QueryOptions{Token: s.token},
 				Name:         u.DestinationName,
@@ -411,9 +430,9 @@ func parseReducedUpstreamConfig(m map[string]interface{}) (reducedUpstreamConfig
 }
 
 // initWatchesTerminatingGateway sets up the initial watches needed based on the terminating-gateway registration
-func (s *state) initWatchesTerminatingGateway() error {
+func (s *state) initWatchesTerminatingGateway(ctx context.Context) error {
 	// Watch for root changes
-	err := s.cache.Notify(s.ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{
+	err := s.cache.Notify(ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{
 		Datacenter:   s.source.Datacenter,
 		QueryOptions: structs.QueryOptions{Token: s.token},
 		Source:       *s.source,
@@ -425,7 +444,7 @@ func (s *state) initWatchesTerminatingGateway() error {
 	}
 
 	// Watch for the terminating-gateway's linked services
-	err = s.cache.Notify(s.ctx, cachetype.GatewayServicesName, &structs.ServiceSpecificRequest{
+	err = s.cache.Notify(ctx, cachetype.GatewayServicesName, &structs.ServiceSpecificRequest{
 		Datacenter:   s.source.Datacenter,
 		QueryOptions: structs.QueryOptions{Token: s.token},
 		ServiceName:  s.service,
@@ -441,9 +460,9 @@ func (s *state) initWatchesTerminatingGateway() error {
 }
 
 // initWatchesMeshGateway sets up the watches needed based on the current mesh gateway registration
-func (s *state) initWatchesMeshGateway() error {
+func (s *state) initWatchesMeshGateway(ctx context.Context) error {
 	// Watch for root changes
-	err := s.cache.Notify(s.ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{
+	err := s.cache.Notify(ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{
 		Datacenter:   s.source.Datacenter,
 		QueryOptions: structs.QueryOptions{Token: s.token},
 		Source:       *s.source,
@@ -453,7 +472,7 @@ func (s *state) initWatchesMeshGateway() error {
 	}
 
 	// Watch for all services
-	err = s.cache.Notify(s.ctx, cachetype.CatalogServiceListName, &structs.DCSpecificRequest{
+	err = s.cache.Notify(ctx, cachetype.CatalogServiceListName, &structs.DCSpecificRequest{
 		Datacenter:   s.source.Datacenter,
 		QueryOptions: structs.QueryOptions{Token: s.token},
 		Source:       *s.source,
@@ -468,7 +487,7 @@ func (s *state) initWatchesMeshGateway() error {
 		// Conveniently we can just use this service meta attribute in one
 		// place here to set the machinery in motion and leave the conditional
 		// behavior out of the rest of the package.
-		err = s.cache.Notify(s.ctx, cachetype.FederationStateListMeshGatewaysName, &structs.DCSpecificRequest{
+		err = s.cache.Notify(ctx, cachetype.FederationStateListMeshGatewaysName, &structs.DCSpecificRequest{
 			Datacenter:   s.source.Datacenter,
 			QueryOptions: structs.QueryOptions{Token: s.token},
 			Source:       *s.source,
@@ -477,7 +496,7 @@ func (s *state) initWatchesMeshGateway() error {
 			return err
 		}
 
-		err = s.health.Notify(s.ctx, structs.ServiceSpecificRequest{
+		err = s.health.Notify(ctx, structs.ServiceSpecificRequest{
 			Datacenter:   s.source.Datacenter,
 			QueryOptions: structs.QueryOptions{Token: s.token},
 			ServiceName:  structs.ConsulServiceName,
@@ -492,7 +511,7 @@ func (s *state) initWatchesMeshGateway() error {
 	// cannot setup those watches until we know what the services are. from the service list
 	// watch above
 
-	err = s.cache.Notify(s.ctx, cachetype.CatalogDatacentersName, &structs.DatacentersRequest{
+	err = s.cache.Notify(ctx, cachetype.CatalogDatacentersName, &structs.DatacentersRequest{
 		QueryOptions: structs.QueryOptions{Token: s.token, MaxAge: 30 * time.Second},
 	}, datacentersWatchID, s.ch)
 	if err != nil {
@@ -504,7 +523,7 @@ func (s *state) initWatchesMeshGateway() error {
 	// know what they are yet.
 
 	// Watch service-resolvers so we can setup service subset clusters
-	err = s.cache.Notify(s.ctx, cachetype.ConfigEntriesName, &structs.ConfigEntryQuery{
+	err = s.cache.Notify(ctx, cachetype.ConfigEntriesName, &structs.ConfigEntryQuery{
 		Datacenter:   s.source.Datacenter,
 		QueryOptions: structs.QueryOptions{Token: s.token},
 		Kind:         structs.ServiceResolver,
@@ -520,9 +539,9 @@ func (s *state) initWatchesMeshGateway() error {
 	return err
 }
 
-func (s *state) initWatchesIngressGateway() error {
+func (s *state) initWatchesIngressGateway(ctx context.Context) error {
 	// Watch for root changes
-	err := s.cache.Notify(s.ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{
+	err := s.cache.Notify(ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{
 		Datacenter:   s.source.Datacenter,
 		QueryOptions: structs.QueryOptions{Token: s.token},
 		Source:       *s.source,
@@ -532,7 +551,7 @@ func (s *state) initWatchesIngressGateway() error {
 	}
 
 	// Watch this ingress gateway's config entry
-	err = s.cache.Notify(s.ctx, cachetype.ConfigEntryName, &structs.ConfigEntryQuery{
+	err = s.cache.Notify(ctx, cachetype.ConfigEntryName, &structs.ConfigEntryQuery{
 		Kind:       structs.IngressGateway,
 		Name:       s.service,
 		Datacenter: s.source.Datacenter,
@@ -544,7 +563,7 @@ func (s *state) initWatchesIngressGateway() error {
 	}
 
 	// Watch the ingress-gateway's list of upstreams
-	err = s.cache.Notify(s.ctx, cachetype.GatewayServicesName, &structs.ServiceSpecificRequest{
+	err = s.cache.Notify(ctx, cachetype.GatewayServicesName, &structs.ServiceSpecificRequest{
 		Datacenter:   s.source.Datacenter,
 		QueryOptions: structs.QueryOptions{Token: s.token},
 		ServiceName:  s.service,
@@ -619,7 +638,7 @@ func (s *state) initialConfigSnapshot() ConfigSnapshot {
 	return snap
 }
 
-func (s *state) run(snap *ConfigSnapshot) {
+func (s *state) run(ctx context.Context, snap *ConfigSnapshot) {
 	// Close the channel we return from Watch when we stop so consumers can stop
 	// watching and clean up their goroutines. It's important we do this here and
 	// not in Close since this routine sends on this chan and so might panic if it
@@ -635,12 +654,12 @@ func (s *state) run(snap *ConfigSnapshot) {
 
 	for {
 		select {
-		case <-s.ctx.Done():
+		case <-ctx.Done():
 			return
 		case u := <-s.ch:
 			s.logger.Trace("A blocking query returned; handling snapshot update")
 
-			if err := s.handleUpdate(u, snap); err != nil {
+			if err := s.handleUpdate(ctx, u, snap); err != nil {
 				s.logger.Error("Failed to handle update from watch",
 					"id", u.CorrelationID, "error", err,
 				)
@@ -729,22 +748,22 @@ func (s *state) run(snap *ConfigSnapshot) {
 	}
 }
 
-func (s *state) handleUpdate(u cache.UpdateEvent, snap *ConfigSnapshot) error {
+func (s *state) handleUpdate(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error {
 	switch s.kind {
 	case structs.ServiceKindConnectProxy:
-		return s.handleUpdateConnectProxy(u, snap)
+		return s.handleUpdateConnectProxy(ctx, u, snap)
 	case structs.ServiceKindTerminatingGateway:
-		return s.handleUpdateTerminatingGateway(u, snap)
+		return s.handleUpdateTerminatingGateway(ctx, u, snap)
 	case structs.ServiceKindMeshGateway:
-		return s.handleUpdateMeshGateway(u, snap)
+		return s.handleUpdateMeshGateway(ctx, u, snap)
 	case structs.ServiceKindIngressGateway:
-		return s.handleUpdateIngressGateway(u, snap)
+		return s.handleUpdateIngressGateway(ctx, u, snap)
 	default:
 		return fmt.Errorf("Unsupported service kind")
 	}
 }
 
-func (s *state) handleUpdateConnectProxy(u cache.UpdateEvent, snap *ConfigSnapshot) error {
+func (s *state) handleUpdateConnectProxy(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error {
 	if u.Err != nil {
 		return fmt.Errorf("error filling agent cache: %v", u.Err)
 	}
@@ -819,7 +838,7 @@ func (s *state) handleUpdateConnectProxy(u cache.UpdateEvent, snap *ConfigSnapsh
 				cfg:         cfg,
 				meshGateway: meshGateway,
 			}
-			err = s.watchDiscoveryChain(snap, watchOpts)
+			err = s.watchDiscoveryChain(ctx, snap, watchOpts)
 			if err != nil {
 				return fmt.Errorf("failed to watch discovery chain for %s: %v", svc.String(), err)
 			}
@@ -908,12 +927,12 @@ func (s *state) handleUpdateConnectProxy(u cache.UpdateEvent, snap *ConfigSnapsh
 		snap.ConnectProxy.MeshConfigSet = true
 
 	default:
-		return s.handleUpdateUpstreams(u, snap)
+		return s.handleUpdateUpstreams(ctx, u, snap)
 	}
 	return nil
 }
 
-func (s *state) handleUpdateUpstreams(u cache.UpdateEvent, snap *ConfigSnapshot) error {
+func (s *state) handleUpdateUpstreams(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error {
 	if u.Err != nil {
 		return fmt.Errorf("error filling agent cache: %v", u.Err)
 	}
@@ -939,7 +958,7 @@ func (s *state) handleUpdateUpstreams(u cache.UpdateEvent, snap *ConfigSnapshot)
 		svc := strings.TrimPrefix(u.CorrelationID, "discovery-chain:")
 		upstreamsSnapshot.DiscoveryChain[svc] = resp.Chain
 
-		if err := s.resetWatchesFromChain(svc, resp.Chain, upstreamsSnapshot); err != nil {
+		if err := s.resetWatchesFromChain(ctx, svc, resp.Chain, upstreamsSnapshot); err != nil {
 			return err
 		}
 
@@ -1031,6 +1050,7 @@ func removeColonPrefix(s string) (string, string, bool) {
 }
 
 func (s *state) resetWatchesFromChain(
+	ctx context.Context,
 	id string,
 	chain *structs.CompiledDiscoveryChain,
 	snap *ConfigSnapshotUpstreams,
@@ -1089,7 +1109,7 @@ func (s *state) resetWatchesFromChain(
 			datacenter: target.Datacenter,
 			entMeta:    target.GetEnterpriseMetadata(),
 		}
-		err := s.watchUpstreamTarget(snap, opts)
+		err := s.watchUpstreamTarget(ctx, snap, opts)
 		if err != nil {
 			return fmt.Errorf("failed to watch target %q for upstream %q", target.ID, id)
 		}
@@ -1123,7 +1143,7 @@ func (s *state) resetWatchesFromChain(
 			datacenter: chain.Datacenter,
 			entMeta:    &chainEntMeta,
 		}
-		err := s.watchUpstreamTarget(snap, opts)
+		err := s.watchUpstreamTarget(ctx, snap, opts)
 		if err != nil {
 			return fmt.Errorf("failed to watch target %q for upstream %q", chainID, id)
 		}
@@ -1140,7 +1160,7 @@ func (s *state) resetWatchesFromChain(
 			"datacenter", dc,
 		)
 
-		ctx, cancel := context.WithCancel(s.ctx)
+		ctx, cancel := context.WithCancel(ctx)
 		err := s.watchMeshGateway(ctx, dc, id)
 		if err != nil {
 			cancel()
@@ -1176,7 +1196,7 @@ type targetWatchOpts struct {
 	entMeta    *structs.EnterpriseMeta
 }
 
-func (s *state) watchUpstreamTarget(snap *ConfigSnapshotUpstreams, opts targetWatchOpts) error {
+func (s *state) watchUpstreamTarget(ctx context.Context, snap *ConfigSnapshotUpstreams, opts targetWatchOpts) error {
 	s.logger.Trace("initializing watch of target",
 		"upstream", opts.upstreamID,
 		"chain", opts.service,
@@ -1188,7 +1208,7 @@ func (s *state) watchUpstreamTarget(snap *ConfigSnapshotUpstreams, opts targetWa
 
 	correlationID := "upstream-target:" + opts.chainID + ":" + opts.upstreamID
 
-	ctx, cancel := context.WithCancel(s.ctx)
+	ctx, cancel := context.WithCancel(ctx)
 	err := s.health.Notify(ctx, structs.ServiceSpecificRequest{
 		Datacenter: opts.datacenter,
 		QueryOptions: structs.QueryOptions{
@@ -1213,7 +1233,7 @@ func (s *state) watchUpstreamTarget(snap *ConfigSnapshotUpstreams, opts targetWa
 	return nil
 }
 
-func (s *state) handleUpdateTerminatingGateway(u cache.UpdateEvent, snap *ConfigSnapshot) error {
+func (s *state) handleUpdateTerminatingGateway(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error {
 	if u.Err != nil {
 		return fmt.Errorf("error filling agent cache: %v", u.Err)
 	}
@@ -1244,7 +1264,7 @@ func (s *state) handleUpdateTerminatingGateway(u cache.UpdateEvent, snap *Config
 
 			// Watch the health endpoint to discover endpoints for the service
 			if _, ok := snap.TerminatingGateway.WatchedServices[svc.Service]; !ok {
-				ctx, cancel := context.WithCancel(s.ctx)
+				ctx, cancel := context.WithCancel(ctx)
 				err := s.health.Notify(ctx, structs.ServiceSpecificRequest{
 					Datacenter:   s.source.Datacenter,
 					QueryOptions: structs.QueryOptions{Token: s.token},
@@ -1269,7 +1289,7 @@ func (s *state) handleUpdateTerminatingGateway(u cache.UpdateEvent, snap *Config
 			// Watch intentions with this service as their destination
 			// The gateway will enforce intentions for connections to the service
 			if _, ok := snap.TerminatingGateway.WatchedIntentions[svc.Service]; !ok {
-				ctx, cancel := context.WithCancel(s.ctx)
+				ctx, cancel := context.WithCancel(ctx)
 				err := s.cache.Notify(ctx, cachetype.IntentionMatchName, &structs.IntentionQueryRequest{
 					Datacenter:   s.source.Datacenter,
 					QueryOptions: structs.QueryOptions{Token: s.token},
@@ -1298,7 +1318,7 @@ func (s *state) handleUpdateTerminatingGateway(u cache.UpdateEvent, snap *Config
 			// Watch leaf certificate for the service
 			// This cert is used to terminate mTLS connections on the service's behalf
 			if _, ok := snap.TerminatingGateway.WatchedLeaves[svc.Service]; !ok {
-				ctx, cancel := context.WithCancel(s.ctx)
+				ctx, cancel := context.WithCancel(ctx)
 				err := s.cache.Notify(ctx, cachetype.ConnectCALeafName, &cachetype.ConnectCALeafRequest{
 					Datacenter: s.source.Datacenter,
 					Token:      s.token,
@@ -1320,7 +1340,7 @@ func (s *state) handleUpdateTerminatingGateway(u cache.UpdateEvent, snap *Config
 
 			// Watch service configs for the service.
 			// These are used to determine the protocol for the target service.
 			if _, ok := snap.TerminatingGateway.WatchedConfigs[svc.Service]; !ok {
-				ctx, cancel := context.WithCancel(s.ctx)
+				ctx, cancel := context.WithCancel(ctx)
 				err := s.cache.Notify(ctx, cachetype.ResolvedServiceConfigName, &structs.ServiceConfigRequest{
 					Datacenter:   s.source.Datacenter,
 					QueryOptions: structs.QueryOptions{Token: s.token},
@@ -1342,7 +1362,7 @@ func (s *state) handleUpdateTerminatingGateway(u cache.UpdateEvent, snap *Config
 			// Watch service resolvers for the service
 			// These are used to create clusters and endpoints for the service subsets
 			if _, ok := snap.TerminatingGateway.WatchedResolvers[svc.Service]; !ok {
-				ctx, cancel := context.WithCancel(s.ctx)
+				ctx, cancel := context.WithCancel(ctx)
 				err := s.cache.Notify(ctx, cachetype.ConfigEntriesName, &structs.ConfigEntryQuery{
 					Datacenter:   s.source.Datacenter,
 					QueryOptions: structs.QueryOptions{Token: s.token},
@@ -1498,7 +1518,7 @@ func (s *state) handleUpdateTerminatingGateway(u cache.UpdateEvent, snap *Config
 	return nil
 }
 
-func (s *state) handleUpdateMeshGateway(u cache.UpdateEvent, snap *ConfigSnapshot) error {
+func (s *state) handleUpdateMeshGateway(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error {
 	if u.Err != nil {
 		return fmt.Errorf("error filling agent cache: %v", u.Err)
 	}
@@ -1542,7 +1562,7 @@ func (s *state) handleUpdateMeshGateway(u cache.UpdateEvent, snap *ConfigSnapsho
 			svcMap[svc] = struct{}{}
 
 			if _, ok := snap.MeshGateway.WatchedServices[svc]; !ok {
-				ctx, cancel := context.WithCancel(s.ctx)
+				ctx, cancel := context.WithCancel(ctx)
 				err := s.health.Notify(ctx, structs.ServiceSpecificRequest{
 					Datacenter:   s.source.Datacenter,
 					QueryOptions: structs.QueryOptions{Token: s.token},
@@ -1591,7 +1611,7 @@ func (s *state) handleUpdateMeshGateway(u cache.UpdateEvent, snap *ConfigSnapsho
 			}
 
 			if _, ok := snap.MeshGateway.WatchedDatacenters[dc]; !ok {
-				ctx, cancel := context.WithCancel(s.ctx)
+				ctx, cancel := context.WithCancel(ctx)
 				err := s.cache.Notify(ctx, cachetype.InternalServiceDumpName, &structs.ServiceDumpRequest{
 					Datacenter:   dc,
 					QueryOptions: structs.QueryOptions{Token: s.token},
@@ -1699,7 +1719,7 @@ func (s *state) handleUpdateMeshGateway(u cache.UpdateEvent, snap *ConfigSnapsho
 	return nil
 }
 
-func (s *state) handleUpdateIngressGateway(u cache.UpdateEvent, snap *ConfigSnapshot) error {
+func (s *state) handleUpdateIngressGateway(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error {
 	if u.Err != nil {
 		return fmt.Errorf("error filling agent cache: %v", u.Err)
 	}
@@ -1724,7 +1744,7 @@ func (s *state) handleUpdateIngressGateway(u cache.UpdateEvent, snap *ConfigSnap
 		snap.IngressGateway.TLSEnabled = gatewayConf.TLS.Enabled
 		snap.IngressGateway.TLSSet = true
 
-		if err := s.watchIngressLeafCert(snap); err != nil {
+		if err := s.watchIngressLeafCert(ctx, snap); err != nil {
 			return err
 		}
 
@@ -1747,7 +1767,7 @@ func (s *state) handleUpdateIngressGateway(u cache.UpdateEvent, snap *ConfigSnap
 				namespace:  u.DestinationNamespace,
 				datacenter: s.source.Datacenter,
 			}
-			err := s.watchDiscoveryChain(snap, watchOpts)
+			err := s.watchDiscoveryChain(ctx, snap, watchOpts)
 			if err != nil {
 				return fmt.Errorf("failed to watch discovery chain for %s: %v", u.Identifier(), err)
 			}
@@ -1770,12 +1790,12 @@ func (s *state) handleUpdateIngressGateway(u cache.UpdateEvent, snap *ConfigSnap
 			}
 		}
 
-		if err := s.watchIngressLeafCert(snap); err != nil {
+		if err := s.watchIngressLeafCert(ctx, snap); err != nil {
 			return err
 		}
 
 	default:
-		return s.handleUpdateUpstreams(u, snap)
+		return s.handleUpdateUpstreams(ctx, u, snap)
 	}
 
 	return nil
@@ -1808,12 +1828,12 @@ type discoveryChainWatchOpts struct {
 	meshGateway structs.MeshGatewayConfig
 }
 
-func (s *state) watchDiscoveryChain(snap *ConfigSnapshot, opts discoveryChainWatchOpts) error {
+func (s *state) watchDiscoveryChain(ctx context.Context, snap *ConfigSnapshot, opts discoveryChainWatchOpts) error {
 	if _, ok := snap.ConnectProxy.WatchedDiscoveryChains[opts.id]; ok {
 		return nil
 	}
 
-	ctx, cancel := context.WithCancel(s.ctx)
+	ctx, cancel := context.WithCancel(ctx)
 	err := s.cache.Notify(ctx, cachetype.CompiledDiscoveryChainName, &structs.DiscoveryChainRequest{
 		Datacenter:   s.source.Datacenter,
 		QueryOptions: structs.QueryOptions{Token: s.token},
@@ -1879,7 +1899,7 @@ func (s *state) generateIngressDNSSANs(snap *ConfigSnapshot) []string {
 	return dnsNames
 }
 
-func (s *state) watchIngressLeafCert(snap *ConfigSnapshot) error {
+func (s *state) watchIngressLeafCert(ctx context.Context, snap *ConfigSnapshot) error {
 	if !snap.IngressGateway.TLSSet || !snap.IngressGateway.HostsSet {
 		return nil
 	}
@@ -1888,7 +1908,7 @@ func (s *state) watchIngressLeafCert(snap *ConfigSnapshot) error {
 	if snap.IngressGateway.LeafCertWatchCancel != nil {
 		snap.IngressGateway.LeafCertWatchCancel()
 	}
-	ctx, cancel := context.WithCancel(s.ctx)
+	ctx, cancel := context.WithCancel(ctx)
 	err := s.cache.Notify(ctx, cachetype.ConnectCALeafName, &cachetype.ConnectCALeafRequest{
 		Datacenter: s.source.Datacenter,
 		Token:      s.token,
diff --git a/agent/proxycfg/state_test.go b/agent/proxycfg/state_test.go
index 37835d54a9..ded4048179 100644
--- a/agent/proxycfg/state_test.go
+++ b/agent/proxycfg/state_test.go
@@ -13,6 +13,7 @@ import (
 
 	"github.com/hashicorp/consul/agent/cache"
 	cachetype "github.com/hashicorp/consul/agent/cache-types"
+	"github.com/hashicorp/consul/agent/connect"
 	"github.com/hashicorp/consul/agent/consul/discoverychain"
 	"github.com/hashicorp/consul/agent/rpcclient/health"
 	"github.com/hashicorp/consul/agent/structs"
@@ -2151,13 +2152,14 @@ func TestState_WatchesAndUpdates(t *testing.T) {
 			}
 
 			// setup the ctx as initWatches expects this to be there
-			state.ctx, state.cancel = context.WithCancel(context.Background())
+			var ctx context.Context
+			ctx, state.cancel = context.WithCancel(context.Background())
 
 			// get the initial configuration snapshot
 			snap := state.initialConfigSnapshot()
 
 			// ensure the initial watch setup did not error
-			require.NoError(t, state.initWatches(&snap))
+			require.NoError(t, state.initWatches(ctx, &snap))
 
 			//--------------------------------------------------------------------
 			//
@@ -2184,7 +2186,7 @@ func TestState_WatchesAndUpdates(t *testing.T) {
 			// therefore we just tell it about the updates
 			for eveIdx, event := range stage.events {
 				require.True(t, t.Run(fmt.Sprintf("update-%d", eveIdx), func(t *testing.T) {
-					require.NoError(t, state.handleUpdate(event, &snap))
+					require.NoError(t, state.handleUpdate(ctx, event, &snap))
 				}))
 			}
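
For orientation, below is a minimal, self-contained Go sketch of the lifecycle pattern this diff moves toward: dependencies and registration identity grouped into embedded stateConfig/serviceInstance value types, with the context created in Watch and threaded explicitly through run rather than stored on the struct. All types and names here are simplified stand-ins for illustration, not the real proxycfg API.

package main

import (
	"context"
	"fmt"
)

// stateConfig holds dependencies shared by every state (stand-in field).
type stateConfig struct {
	logger func(msg string)
}

// serviceInstance holds the identity of one proxy registration (stand-in field).
type serviceInstance struct {
	service string
}

type state struct {
	stateConfig     // embedded, mirroring the diff
	serviceInstance // embedded, mirroring the diff

	// cancel is set by Watch and called by Close; note that no ctx field
	// is kept on the struct anymore.
	cancel func()
	snapCh chan string
}

// Watch creates the context locally and passes it down, instead of
// stashing it in a struct field as the pre-refactor code did.
func (s *state) Watch() (<-chan string, error) {
	var ctx context.Context
	ctx, s.cancel = context.WithCancel(context.Background())
	go s.run(ctx)
	return s.snapCh, nil
}

func (s *state) run(ctx context.Context) {
	// Close the returned channel here, not in Close, so only the sending
	// goroutine ever closes it.
	defer close(s.snapCh)
	<-ctx.Done()
	s.logger("watch stopped for " + s.service)
}

func (s *state) Close() error {
	if s.cancel != nil {
		s.cancel()
	}
	return nil
}

func main() {
	st := &state{
		stateConfig:     stateConfig{logger: func(m string) { fmt.Println(m) }},
		serviceInstance: serviceInstance{service: "web-sidecar-proxy"},
		snapCh:          make(chan string),
	}
	ch, _ := st.Watch()
	st.Close()
	<-ch // unblocks once run observes cancellation and closes the channel
}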