diff --git a/agent/proxycfg/snapshot.go b/agent/proxycfg/snapshot.go
index a12e1f126f..4345fda453 100644
--- a/agent/proxycfg/snapshot.go
+++ b/agent/proxycfg/snapshot.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"fmt"
 	"sort"
+	"strings"
 
 	"github.com/mitchellh/copystructure"
 
@@ -38,11 +39,11 @@ type ConfigSnapshotUpstreams struct {
 	WatchedUpstreamEndpoints map[string]map[string]structs.CheckServiceNodes
 
 	// WatchedGateways is a map of upstream.Identifier() -> (map of
-	// TargetID -> CancelFunc) in order to cancel watches for mesh gateways
+	// GatewayKey.String() -> CancelFunc) in order to cancel watches for mesh gateways
 	WatchedGateways map[string]map[string]context.CancelFunc
 
 	// WatchedGatewayEndpoints is a map of upstream.Identifier() -> (map of
-	// TargetID -> CheckServiceNodes) and is used to determine the backing
+	// GatewayKey.String() -> CheckServiceNodes) and is used to determine the backing
 	// endpoints of a mesh gateway.
 	WatchedGatewayEndpoints map[string]map[string]structs.CheckServiceNodes
 
@@ -53,6 +54,41 @@ type ConfigSnapshotUpstreams struct {
 	PassthroughUpstreams map[string]ServicePassthroughAddrs
 }
 
+// GatewayKey names a partition within a datacenter. Its String form is the
+// inner map key of ConfigSnapshotUpstreams.WatchedGateways and
+// ConfigSnapshotUpstreams.WatchedGatewayEndpoints.
+type GatewayKey struct {
+	Datacenter string
+	Partition  string
+}
+
+// String returns the key formatted as "<partition>.<datacenter>".
+func (k GatewayKey) String() string {
+	return k.Partition + "." + k.Datacenter
+}
+
+// IsEmpty reports whether both the partition and the datacenter are unset.
+func (k GatewayKey) IsEmpty() bool {
+	return k.Partition == "" && k.Datacenter == ""
+}
+
+// gatewayKeyFromString parses a key produced by GatewayKey.String.
+//
+// Only the first "." separates the partition from the datacenter, so a
+// datacenter name that itself contains "." survives a round trip. A string
+// with no "." at all is returned as a bare datacenter name instead of
+// indexing past the end of the split result and panicking.
+func gatewayKeyFromString(s string) GatewayKey {
+	split := strings.SplitN(s, ".", 2)
+	if len(split) < 2 {
+		return GatewayKey{Datacenter: s}
+	}
+	return GatewayKey{
+		Partition:  split[0],
+		Datacenter: split[1],
+	}
+}
+
 // ServicePassthroughAddrs contains the LAN addrs
 type ServicePassthroughAddrs struct {
 	// SNI is the Service SNI of the upstream.
diff --git a/agent/proxycfg/state.go b/agent/proxycfg/state.go
index d393a5844f..2c443f6be6 100644
--- a/agent/proxycfg/state.go
+++ b/agent/proxycfg/state.go
@@ -3,6 +3,7 @@ package proxycfg
 import (
 	"context"
 	"errors"
+	"fmt"
 	"net"
 	"reflect"
 	"time"
@@ -426,3 +427,38 @@ func hostnameEndpoints(logger hclog.Logger, localDC string, nodes structs.CheckS
 	}
 	return resp
 }
+
+// gatewayWatchOpts bundles the inputs watchMeshGateway needs to register a
+// mesh gateway watch.
+type gatewayWatchOpts struct {
+	// notifier registers the watch; notifyCh is the channel handed to it.
+	notifier CacheNotifier
+	notifyCh chan cache.UpdateEvent
+
+	// source and token are copied into the service dump request.
+	source structs.QuerySource
+	token  string
+
+	// key supplies the datacenter and partition of the request.
+	key GatewayKey
+
+	// upstreamID is the trailing segment of the watch's correlation ID.
+	upstreamID string
+}
+
+// watchMeshGateway registers a service dump watch, filtered to the mesh
+// gateway service kind, for the datacenter and partition in opts.key. The
+// correlation ID of the watch is "mesh-gateway:<key>:<upstreamID>".
+func watchMeshGateway(ctx context.Context, opts gatewayWatchOpts) error {
+	req := &structs.ServiceDumpRequest{
+		Datacenter:     opts.key.Datacenter,
+		QueryOptions:   structs.QueryOptions{Token: opts.token},
+		ServiceKind:    structs.ServiceKindMeshGateway,
+		UseServiceKind: true,
+		Source:         opts.source,
+		EnterpriseMeta: *structs.DefaultEnterpriseMetaInPartition(opts.key.Partition),
+	}
+	correlationID := fmt.Sprintf("mesh-gateway:%s:%s", opts.key.String(), opts.upstreamID)
+
+	return opts.notifier.Notify(ctx, cachetype.InternalServiceDumpName, req, correlationID, opts.notifyCh)
+}
diff --git a/agent/proxycfg/state_test.go b/agent/proxycfg/state_test.go
index be22dbec87..81671ef3a0 100644
--- a/agent/proxycfg/state_test.go
+++ b/agent/proxycfg/state_test.go
@@ -649,8 +649,8 @@ func TestState_WatchesAndUpdates(t *testing.T) {
 						"upstream-target:api-failover-remote.default.default.dc2:api-failover-remote?dc=dc2": genVerifyServiceWatch("api-failover-remote", "", "dc2", true),
 						"upstream-target:api-failover-local.default.default.dc2:api-failover-local?dc=dc2": genVerifyServiceWatch("api-failover-local", "", "dc2", true),
 						"upstream-target:api-failover-direct.default.default.dc2:api-failover-direct?dc=dc2": genVerifyServiceWatch("api-failover-direct", "", "dc2", true),
-						"mesh-gateway:dc2:api-failover-remote?dc=dc2": genVerifyGatewayWatch("dc2"),
-						"mesh-gateway:dc1:api-failover-local?dc=dc2": genVerifyGatewayWatch("dc1"),
+						"mesh-gateway:default.dc2:api-failover-remote?dc=dc2": 
genVerifyGatewayWatch("dc2"), + "mesh-gateway:default.dc1:api-failover-local?dc=dc2": genVerifyGatewayWatch("dc1"), }, verifySnapshot: func(t testing.TB, snap *ConfigSnapshot) { require.True(t, snap.Valid()) @@ -673,7 +673,7 @@ func TestState_WatchesAndUpdates(t *testing.T) { } if meshGatewayProxyConfigValue == structs.MeshGatewayModeLocal { - stage1.requiredWatches["mesh-gateway:dc1:api-dc2"] = genVerifyGatewayWatch("dc1") + stage1.requiredWatches["mesh-gateway:default.dc1:api-dc2"] = genVerifyGatewayWatch("dc1") } return testCase{ diff --git a/agent/proxycfg/testing.go b/agent/proxycfg/testing.go index f156179afe..3fcf4c2685 100644 --- a/agent/proxycfg/testing.go +++ b/agent/proxycfg/testing.go @@ -1429,7 +1429,7 @@ func setupTestVariationConfigEntriesAndSnapshot( TestUpstreamNodesDC2(t) snap.WatchedGatewayEndpoints = map[string]map[string]structs.CheckServiceNodes{ "db": { - "dc2": TestGatewayNodesDC2(t), + "default.dc2": TestGatewayNodesDC2(t), }, } case "failover-through-double-remote-gateway-triggered": @@ -1442,8 +1442,8 @@ func setupTestVariationConfigEntriesAndSnapshot( snap.WatchedUpstreamEndpoints["db"]["db.default.default.dc3"] = TestUpstreamNodesDC2(t) snap.WatchedGatewayEndpoints = map[string]map[string]structs.CheckServiceNodes{ "db": { - "dc2": TestGatewayNodesDC2(t), - "dc3": TestGatewayNodesDC3(t), + "default.dc2": TestGatewayNodesDC2(t), + "default.dc3": TestGatewayNodesDC3(t), }, } case "failover-through-local-gateway-triggered": @@ -1455,7 +1455,7 @@ func setupTestVariationConfigEntriesAndSnapshot( TestUpstreamNodesDC2(t) snap.WatchedGatewayEndpoints = map[string]map[string]structs.CheckServiceNodes{ "db": { - "dc1": TestGatewayNodesDC1(t), + "default.dc1": TestGatewayNodesDC1(t), }, } case "failover-through-double-local-gateway-triggered": @@ -1468,7 +1468,7 @@ func setupTestVariationConfigEntriesAndSnapshot( snap.WatchedUpstreamEndpoints["db"]["db.default.default.dc3"] = TestUpstreamNodesDC2(t) snap.WatchedGatewayEndpoints = 
map[string]map[string]structs.CheckServiceNodes{ "db": { - "dc1": TestGatewayNodesDC1(t), + "default.dc1": TestGatewayNodesDC1(t), }, } case "splitter-with-resolver-redirect-multidc": @@ -1737,9 +1737,10 @@ func testConfigSnapshotIngressGateway( {protocol, 9191}: { { // We rely on this one having default type in a few tests... - DestinationName: "db", - LocalBindPort: 9191, - LocalBindAddress: "2.3.4.5", + DestinationName: "db", + DestinationPartition: "default", + LocalBindPort: 9191, + LocalBindAddress: "2.3.4.5", }, }, }, diff --git a/agent/proxycfg/upstreams.go b/agent/proxycfg/upstreams.go index e11c1a48ba..1f5060b158 100644 --- a/agent/proxycfg/upstreams.go +++ b/agent/proxycfg/upstreams.go @@ -118,14 +118,15 @@ func (s *handlerUpstreams) handleUpdateUpstreams(ctx context.Context, u cache.Up return fmt.Errorf("invalid type for response: %T", u.Result) } correlationID := strings.TrimPrefix(u.CorrelationID, "mesh-gateway:") - dc, svc, ok := removeColonPrefix(correlationID) + key, svc, ok := removeColonPrefix(correlationID) if !ok { return fmt.Errorf("invalid correlation id %q", u.CorrelationID) } if _, ok = upstreamsSnapshot.WatchedGatewayEndpoints[svc]; !ok { upstreamsSnapshot.WatchedGatewayEndpoints[svc] = make(map[string]structs.CheckServiceNodes) } - upstreamsSnapshot.WatchedGatewayEndpoints[svc][dc] = resp.Nodes + upstreamsSnapshot.WatchedGatewayEndpoints[svc][key] = resp.Nodes + default: return fmt.Errorf("unknown correlation ID: %s", u.CorrelationID) } @@ -207,11 +208,22 @@ func (s *handlerUpstreams) resetWatchesFromChain( // We'll get endpoints from the gateway query, but the health still has // to come from the backing service query. 
+ var gk GatewayKey + switch target.MeshGateway.Mode { case structs.MeshGatewayModeRemote: - needGateways[target.Datacenter] = struct{}{} + gk = GatewayKey{ + Partition: target.Partition, + Datacenter: target.Datacenter, + } case structs.MeshGatewayModeLocal: - needGateways[s.source.Datacenter] = struct{}{} + gk = GatewayKey{ + Partition: s.source.NodePartitionOrDefault(), + Datacenter: s.source.Datacenter, + } + } + if s.source.Datacenter != target.Datacenter || s.proxyID.PartitionOrDefault() != target.Partition { + needGateways[gk.String()] = struct{}{} } } @@ -240,38 +252,51 @@ func (s *handlerUpstreams) resetWatchesFromChain( } } - for dc := range needGateways { - if _, ok := snap.WatchedGateways[id][dc]; ok { + for key := range needGateways { + if _, ok := snap.WatchedGateways[id][key]; ok { continue } + gwKey := gatewayKeyFromString(key) - s.logger.Trace("initializing watch of mesh gateway in datacenter", + s.logger.Trace("initializing watch of mesh gateway", "upstream", id, "chain", chain.ServiceName, - "datacenter", dc, + "datacenter", gwKey.Datacenter, + "partition", gwKey.Partition, ) ctx, cancel := context.WithCancel(ctx) - err := s.watchMeshGateway(ctx, dc, id) + opts := gatewayWatchOpts{ + notifier: s.cache, + notifyCh: s.ch, + source: *s.source, + token: s.token, + key: gwKey, + upstreamID: id, + } + err := watchMeshGateway(ctx, opts) if err != nil { cancel() return err } - snap.WatchedGateways[id][dc] = cancel + snap.WatchedGateways[id][key] = cancel } - for dc, cancelFn := range snap.WatchedGateways[id] { - if _, ok := needGateways[dc]; ok { + for key, cancelFn := range snap.WatchedGateways[id] { + if _, ok := needGateways[key]; ok { continue } - s.logger.Trace("stopping watch of mesh gateway in datacenter", + gwKey := gatewayKeyFromString(key) + + s.logger.Trace("stopping watch of mesh gateway", "upstream", id, "chain", chain.ServiceName, - "datacenter", dc, + "datacenter", gwKey.Datacenter, + "partition", gwKey.Partition, ) - 
delete(snap.WatchedGateways[id], dc) - delete(snap.WatchedGatewayEndpoints[id], dc) + delete(snap.WatchedGateways[id], key) + delete(snap.WatchedGatewayEndpoints[id], key) cancelFn() } @@ -287,17 +312,6 @@ type targetWatchOpts struct { entMeta *structs.EnterpriseMeta } -func (s *handlerUpstreams) watchMeshGateway(ctx context.Context, dc string, upstreamID string) error { - return s.cache.Notify(ctx, cachetype.InternalServiceDumpName, &structs.ServiceDumpRequest{ - Datacenter: dc, - QueryOptions: structs.QueryOptions{Token: s.token}, - ServiceKind: structs.ServiceKindMeshGateway, - UseServiceKind: true, - Source: *s.source, - EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), - }, "mesh-gateway:"+dc+":"+upstreamID, s.ch) -} - func (s *handlerUpstreams) watchUpstreamTarget(ctx context.Context, snap *ConfigSnapshotUpstreams, opts targetWatchOpts) error { s.logger.Trace("initializing watch of target", "upstream", opts.upstreamID,