diff --git a/agent/proxycfg/mesh_gateway.go b/agent/proxycfg/mesh_gateway.go index 93fffdc31d..6de49b69e8 100644 --- a/agent/proxycfg/mesh_gateway.go +++ b/agent/proxycfg/mesh_gateway.go @@ -8,6 +8,7 @@ import ( "time" cachetype "github.com/hashicorp/consul/agent/cache-types" + "github.com/hashicorp/consul/agent/proxycfg/internal/watch" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/lib/maps" "github.com/hashicorp/consul/logging" @@ -21,6 +22,8 @@ type handlerMeshGateway struct { // initialize sets up the watches needed based on the current mesh gateway registration func (s *handlerMeshGateway) initialize(ctx context.Context) (ConfigSnapshot, error) { snap := newConfigSnapshotFromServiceInstance(s.serviceInstance, s.stateConfig) + snap.MeshGateway.WatchedConsulServers = watch.NewMap[string, structs.CheckServiceNodes]() + // Watch for root changes err := s.dataSources.CARoots.Notify(ctx, &structs.DCSpecificRequest{ Datacenter: s.source.Datacenter, @@ -76,7 +79,7 @@ func (s *handlerMeshGateway) initialize(ctx context.Context) (ConfigSnapshot, er } if s.proxyID.InDefaultPartition() { - if err := s.initializeCrossDCWatches(ctx); err != nil { + if err := s.initializeCrossDCWatches(ctx, &snap); err != nil { return snap, err } } @@ -123,7 +126,7 @@ func (s *handlerMeshGateway) initialize(ctx context.Context) (ConfigSnapshot, er return snap, err } -func (s *handlerMeshGateway) initializeCrossDCWatches(ctx context.Context) error { +func (s *handlerMeshGateway) initializeCrossDCWatches(ctx context.Context, snap *ConfigSnapshot) error { if s.meta[structs.MetaWANFederationKey] == "1" { // Conveniently we can just use this service meta attribute in one // place here to set the machinery in motion and leave the conditional @@ -145,6 +148,7 @@ func (s *handlerMeshGateway) initializeCrossDCWatches(ctx context.Context) error if err != nil { return err } + snap.MeshGateway.WatchedConsulServers.InitWatch(structs.ConsulServiceName, nil) } err := 
s.dataSources.Datacenters.Notify(ctx, &structs.DatacentersRequest{ @@ -325,7 +329,6 @@ func (s *handlerMeshGateway) handleUpdate(ctx context.Context, u UpdateEvent, sn return fmt.Errorf("invalid type for response: %T", u.Result) } - // Do some initial sanity checks to avoid doing something dumb. for _, csn := range resp.Nodes { if csn.Service.Service != structs.ConsulServiceName { return fmt.Errorf("expected service name %q but got %q", @@ -337,7 +340,7 @@ func (s *handlerMeshGateway) handleUpdate(ctx context.Context, u UpdateEvent, sn } } - snap.MeshGateway.ConsulServers = resp.Nodes + snap.MeshGateway.WatchedConsulServers.Set(structs.ConsulServiceName, resp.Nodes) case exportedServiceListWatchID: exportedServices, ok := u.Result.(*structs.IndexedExportedServiceList) @@ -463,17 +466,55 @@ func (s *handlerMeshGateway) handleUpdate(ctx context.Context, u UpdateEvent, sn return fmt.Errorf("invalid type for response: %T", u.Result) } - if resp.Entry != nil { - meshConf, ok := resp.Entry.(*structs.MeshConfigEntry) - if !ok { - return fmt.Errorf("invalid type for config entry: %T", resp.Entry) - } - snap.MeshGateway.MeshConfig = meshConf - } else { + if resp.Entry == nil { snap.MeshGateway.MeshConfig = nil + + // We avoid managing server watches when WAN federation is enabled since it + // always requires server watches. + if s.meta[structs.MetaWANFederationKey] != "1" { + // If the entry was deleted we cancel watches that may have existed because of + // PeerThroughMeshGateways being set in the past. 
+ snap.MeshGateway.WatchedConsulServers.CancelWatch(structs.ConsulServiceName) + } + + snap.MeshGateway.MeshConfigSet = true + return nil } + + meshConf, ok := resp.Entry.(*structs.MeshConfigEntry) + if !ok { + return fmt.Errorf("invalid type for config entry: %T", resp.Entry) + } + snap.MeshGateway.MeshConfig = meshConf snap.MeshGateway.MeshConfigSet = true + // We avoid managing Consul server watches when WAN federation is enabled since it + // always requires server watches. + if s.meta[structs.MetaWANFederationKey] == "1" { + return nil + } + + if meshConf.Peering == nil || !meshConf.Peering.PeerThroughMeshGateways { + snap.MeshGateway.WatchedConsulServers.CancelWatch(structs.ConsulServiceName) + return nil + } + if snap.MeshGateway.WatchedConsulServers.IsWatched(structs.ConsulServiceName) { + return nil + } + + notifyCtx, cancel := context.WithCancel(ctx) + err := s.dataSources.Health.Notify(notifyCtx, &structs.ServiceSpecificRequest{ + Datacenter: s.source.Datacenter, + QueryOptions: structs.QueryOptions{Token: s.token}, + ServiceName: structs.ConsulServiceName, + }, consulServerListWatchID, s.ch) + if err != nil { + cancel() + return fmt.Errorf("failed to watch local consul servers: %w", err) + } + + snap.MeshGateway.WatchedConsulServers.InitWatch(structs.ConsulServiceName, cancel) + default: switch { case strings.HasPrefix(u.CorrelationID, "connect-service:"): diff --git a/agent/proxycfg/snapshot.go b/agent/proxycfg/snapshot.go index 23cb8a9556..130977a8c6 100644 --- a/agent/proxycfg/snapshot.go +++ b/agent/proxycfg/snapshot.go @@ -375,8 +375,11 @@ type configSnapshotMeshGateway struct { // datacenter. FedStateGateways map[string]structs.CheckServiceNodes - // ConsulServers is the list of consul servers in this datacenter. 
- ConsulServers structs.CheckServiceNodes + // WatchedConsulServers is a map of (structs.ConsulServiceName -> structs.CheckServiceNodes). + // Mesh gateways can spin up watches for local servers both for + // WAN federation and for peering. This map ensures we only have one + // watch at a time. + WatchedConsulServers watch.Map[string, structs.CheckServiceNodes] // HostnameDatacenters is a map of datacenters to mesh gateway instances with a hostname as the address. // If hostnames are configured they must be provided to Envoy via CDS not EDS. @@ -556,8 +559,8 @@ func (c *configSnapshotMeshGateway) isEmpty() bool { len(c.ServiceResolvers) == 0 && len(c.GatewayGroups) == 0 && len(c.FedStateGateways) == 0 && - len(c.ConsulServers) == 0 && len(c.HostnameDatacenters) == 0 && + c.WatchedConsulServers.Len() == 0 && c.isEmptyPeering() } @@ -690,8 +693,11 @@ func (s *ConfigSnapshot) Valid() bool { s.TerminatingGateway.MeshConfigSet case structs.ServiceKindMeshGateway: - if s.ServiceMeta[structs.MetaWANFederationKey] == "1" { - if len(s.MeshGateway.ConsulServers) == 0 { + if s.MeshGateway.WatchedConsulServers.Len() == 0 { + if s.ServiceMeta[structs.MetaWANFederationKey] == "1" { + return false + } + if cfg := s.MeshConfig(); cfg != nil && cfg.Peering != nil && cfg.Peering.PeerThroughMeshGateways { return false } } diff --git a/agent/proxycfg/state_test.go b/agent/proxycfg/state_test.go index f8cf0834ce..3add369a8e 100644 --- a/agent/proxycfg/state_test.go +++ b/agent/proxycfg/state_test.go @@ -779,6 +779,9 @@ func TestState_WatchesAndUpdates(t *testing.T) { Service: "mesh-gateway", Address: "10.0.1.1", Port: 443, + Meta: map[string]string{ + structs.MetaWANFederationKey: "1", + }, }, sourceDC: "dc1", stages: []verificationStage{ @@ -790,6 +793,7 @@ func TestState_WatchesAndUpdates(t *testing.T) { exportedServiceListWatchID: genVerifyDCSpecificWatch("dc1"), meshConfigEntryID: genVerifyMeshConfigWatch("dc1"), peeringTrustBundlesWatchID: 
genVerifyTrustBundleListWatchForMeshGateway(""), + consulServerListWatchID: genVerifyServiceSpecificPeeredRequest(structs.ConsulServiceName, "", "dc1", "", false), }, verifySnapshot: func(t testing.TB, snap *ConfigSnapshot) { require.False(t, snap.Valid(), "gateway without root is not valid") @@ -1015,6 +1019,186 @@ func TestState_WatchesAndUpdates(t *testing.T) { }, }, }, + "mesh-gateway-peering-control-plane": { + ns: structs.NodeService{ + Kind: structs.ServiceKindMeshGateway, + ID: "mesh-gateway", + Service: "mesh-gateway", + Address: "10.0.1.1", + Port: 443, + }, + sourceDC: "dc1", + stages: []verificationStage{ + { + requiredWatches: map[string]verifyWatchRequest{ + datacentersWatchID: verifyDatacentersWatch, + serviceListWatchID: genVerifyDCSpecificWatch("dc1"), + rootsWatchID: genVerifyDCSpecificWatch("dc1"), + exportedServiceListWatchID: genVerifyDCSpecificWatch("dc1"), + meshConfigEntryID: genVerifyMeshConfigWatch("dc1"), + peeringTrustBundlesWatchID: genVerifyTrustBundleListWatchForMeshGateway(""), + }, + verifySnapshot: func(t testing.TB, snap *ConfigSnapshot) { + require.False(t, snap.Valid(), "gateway without root is not valid") + }, + }, + { + events: []UpdateEvent{ + rootWatchEvent(), + { + CorrelationID: meshConfigEntryID, + Result: &structs.ConfigEntryResponse{ + Entry: &structs.MeshConfigEntry{ + Peering: &structs.PeeringMeshConfig{ + PeerThroughMeshGateways: true, + }, + }, + }, + }, + { + CorrelationID: exportedServiceListWatchID, + Result: &structs.IndexedExportedServiceList{ + Services: nil, + }, + }, + { + CorrelationID: serviceListWatchID, + Result: &structs.IndexedServiceList{ + Services: structs.ServiceList{}, + }, + }, + { + CorrelationID: peeringTrustBundlesWatchID, + Result: &pbpeering.TrustBundleListByServiceResponse{ + Bundles: nil, + }, + }, + }, + verifySnapshot: func(t testing.TB, snap *ConfigSnapshot) { + require.Equal(t, indexedRoots, snap.Roots) + require.True(t, snap.MeshGateway.WatchedServicesSet) + require.True(t, 
snap.MeshGateway.PeeringTrustBundlesSet) + require.True(t, snap.MeshGateway.MeshConfigSet) + + require.True(t, snap.Valid(), "gateway without services is valid") + require.True(t, snap.ConnectProxy.isEmpty()) + }, + }, + { + requiredWatches: map[string]verifyWatchRequest{ + consulServerListWatchID: genVerifyServiceSpecificPeeredRequest(structs.ConsulServiceName, "", "dc1", "", false), + }, + events: []UpdateEvent{ + { + CorrelationID: consulServerListWatchID, + Result: &structs.IndexedCheckServiceNodes{ + Nodes: structs.CheckServiceNodes{ + { + Node: &structs.Node{ + Datacenter: "dc1", + Node: "node1", + Address: "127.0.0.1", + }, + Service: &structs.NodeService{ + ID: structs.ConsulServiceID, + Service: structs.ConsulServiceName, + }, + }, + { + Node: &structs.Node{ + Datacenter: "dc1", + Node: "replica1", + Address: "127.0.0.1", + }, + Service: &structs.NodeService{ + ID: structs.ConsulServiceID, + Service: structs.ConsulServiceName, + Meta: map[string]string{"read_replica": "true"}, + }, + }, + }, + }, + Err: nil, + }, + }, + verifySnapshot: func(t testing.TB, snap *ConfigSnapshot) { + require.True(t, snap.Valid()) + + servers, ok := snap.MeshGateway.WatchedConsulServers.Get(structs.ConsulServiceName) + require.True(t, ok) + + expect := structs.CheckServiceNodes{ + { + Node: &structs.Node{ + Datacenter: "dc1", + Node: "node1", + Address: "127.0.0.1", + }, + Service: &structs.NodeService{ + ID: structs.ConsulServiceID, + Service: structs.ConsulServiceName, + }, + }, + { + Node: &structs.Node{ + Datacenter: "dc1", + Node: "replica1", + Address: "127.0.0.1", + }, + Service: &structs.NodeService{ + ID: structs.ConsulServiceID, + Service: structs.ConsulServiceName, + Meta: map[string]string{"read_replica": "true"}, + }, + }, + } + require.Equal(t, expect, servers) + }, + }, + { + events: []UpdateEvent{ + { + CorrelationID: meshConfigEntryID, + Result: &structs.ConfigEntryResponse{ + Entry: &structs.MeshConfigEntry{ + Peering: &structs.PeeringMeshConfig{ + 
PeerThroughMeshGateways: false, + }, + }, + }, + }, + }, + verifySnapshot: func(t testing.TB, snap *ConfigSnapshot) { + require.True(t, snap.Valid()) + require.NotNil(t, snap.MeshConfig()) + + require.False(t, snap.MeshGateway.WatchedConsulServers.IsWatched(structs.ConsulServiceName)) + servers, ok := snap.MeshGateway.WatchedConsulServers.Get(structs.ConsulServiceName) + require.False(t, ok) + require.Empty(t, servers) + }, + }, + { + events: []UpdateEvent{ + { + CorrelationID: meshConfigEntryID, + Result: &structs.ConfigEntryResponse{ + Entry: nil, + }, + }, + }, + verifySnapshot: func(t testing.TB, snap *ConfigSnapshot) { + require.True(t, snap.Valid()) + require.Nil(t, snap.MeshConfig()) + + require.False(t, snap.MeshGateway.WatchedConsulServers.IsWatched(structs.ConsulServiceName)) + servers, ok := snap.MeshGateway.WatchedConsulServers.Get(structs.ConsulServiceName) + require.False(t, ok) + require.Empty(t, servers) + }, + }, + }, + }, "ingress-gateway": { ns: structs.NodeService{ Kind: structs.ServiceKindIngressGateway, diff --git a/agent/xds/clusters.go b/agent/xds/clusters.go index a425f829ee..2889868bb5 100644 --- a/agent/xds/clusters.go +++ b/agent/xds/clusters.go @@ -386,7 +386,8 @@ func (s *ResourceGenerator) clustersFromSnapshotMeshGateway(cfgSnap *proxycfg.Co } // And for the current datacenter, send all flavors appropriately. 
- for _, srv := range cfgSnap.MeshGateway.ConsulServers { + servers, _ := cfgSnap.MeshGateway.WatchedConsulServers.Get(structs.ConsulServiceName) + for _, srv := range servers { opts := clusterOpts{ name: cfgSnap.ServerSNIFn(cfgSnap.Datacenter, srv.Node.Node), } diff --git a/agent/xds/endpoints.go b/agent/xds/endpoints.go index b5588ce649..d3083979b8 100644 --- a/agent/xds/endpoints.go +++ b/agent/xds/endpoints.go @@ -249,7 +249,8 @@ func (s *ResourceGenerator) endpointsFromSnapshotMeshGateway(cfgSnap *proxycfg.C cfgSnap.ServerSNIFn != nil { var allServersLbEndpoints []*envoy_endpoint_v3.LbEndpoint - for _, srv := range cfgSnap.MeshGateway.ConsulServers { + servers, _ := cfgSnap.MeshGateway.WatchedConsulServers.Get(structs.ConsulServiceName) + for _, srv := range servers { clusterName := cfgSnap.ServerSNIFn(cfgSnap.Datacenter, srv.Node.Node) _, addr, port := srv.BestAddress(false /*wan*/) diff --git a/agent/xds/listeners.go b/agent/xds/listeners.go index cfea25cbc1..d74d44ab87 100644 --- a/agent/xds/listeners.go +++ b/agent/xds/listeners.go @@ -274,7 +274,7 @@ func (s *ResourceGenerator) listenersFromSnapshotConnectProxy(cfgSnap *proxycfg. return nil } configuredPorts[svcConfig.Destination.Port] = struct{}{} - const name = "~http" //name used for the shared route name + const name = "~http" // name used for the shared route name routeName := clusterNameForDestination(cfgSnap, name, fmt.Sprintf("%d", svcConfig.Destination.Port), svcConfig.NamespaceOrDefault(), svcConfig.PartitionOrDefault()) filterChain, err := s.makeUpstreamFilterChain(filterChainOpts{ routeName: routeName, @@ -1739,7 +1739,8 @@ func (s *ResourceGenerator) makeMeshGatewayListener(name, addr string, port int, } // Wildcard all flavors to each server. 
- for _, srv := range cfgSnap.MeshGateway.ConsulServers { + servers, _ := cfgSnap.MeshGateway.WatchedConsulServers.Get(structs.ConsulServiceName) + for _, srv := range servers { clusterName := cfgSnap.ServerSNIFn(cfgSnap.Datacenter, srv.Node.Node) filterName := fmt.Sprintf("%s.%s", name, cfgSnap.Datacenter)