diff --git a/agent/proxycfg/mesh_gateway.go b/agent/proxycfg/mesh_gateway.go index 63747a9404..dbfddc64c2 100644 --- a/agent/proxycfg/mesh_gateway.go +++ b/agent/proxycfg/mesh_gateway.go @@ -567,6 +567,7 @@ func (s *handlerMeshGateway) handleUpdate(ctx context.Context, u UpdateEvent, sn Request: &pbpeering.PeeringListRequest{ Partition: acl.WildcardPartitionName, }, + QueryOptions: structs.QueryOptions{Token: s.token}, }, peerServersWatchID, s.ch) if err != nil { meshLogger.Error("failed to register watch for peering list", "error", err) diff --git a/agent/proxycfg/state_test.go b/agent/proxycfg/state_test.go index 2df9fa861e..90eebd968a 100644 --- a/agent/proxycfg/state_test.go +++ b/agent/proxycfg/state_test.go @@ -194,6 +194,7 @@ func genVerifyDCSpecificWatch(expectedDatacenter string) verifyWatchRequest { return func(t testing.TB, request any) { reqReal, ok := request.(*structs.DCSpecificRequest) require.True(t, ok) + require.Equal(t, aclToken, reqReal.Token) require.Equal(t, expectedDatacenter, reqReal.Datacenter) } } @@ -207,6 +208,7 @@ func genVerifyTrustBundleReadWatch(peer string) verifyWatchRequest { return func(t testing.TB, request any) { reqReal, ok := request.(*cachetype.TrustBundleReadRequest) require.True(t, ok) + require.Equal(t, aclToken, reqReal.Token) require.Equal(t, peer, reqReal.Request.Name) } } @@ -215,6 +217,7 @@ func genVerifyLeafWatchWithDNSSANs(expectedService string, expectedDatacenter st return func(t testing.TB, request any) { reqReal, ok := request.(*cachetype.ConnectCALeafRequest) require.True(t, ok) + require.Equal(t, aclToken, reqReal.Token) require.Equal(t, expectedDatacenter, reqReal.Datacenter) require.Equal(t, expectedService, reqReal.Service) require.ElementsMatch(t, expectedDNSSANs, reqReal.DNSSAN) @@ -229,6 +232,7 @@ func genVerifyTrustBundleListWatch(service string) verifyWatchRequest { return func(t testing.TB, request any) { reqReal, ok := request.(*cachetype.TrustBundleListRequest) require.True(t, ok) + require.Equal(t, aclToken, reqReal.Token) require.Equal(t, service, reqReal.Request.ServiceName) } } @@ -239,6 +243,7 @@ func genVerifyTrustBundleListWatchForMeshGateway(partition string) verifyWatchRe require.True(t, ok) require.Equal(t, string(structs.ServiceKindMeshGateway), reqReal.Request.Kind) require.True(t, acl.EqualPartitions(partition, reqReal.Request.Partition), "%q != %q", partition, reqReal.Request.Partition) + require.Equal(t, aclToken, reqReal.Token) require.NotEmpty(t, reqReal.Request.ServiceName) } } @@ -247,6 +252,7 @@ func genVerifyPeeringListWatchForMeshGateway() verifyWatchRequest { return func(t testing.TB, request any) { reqReal, ok := request.(*cachetype.PeeringListRequest) require.True(t, ok) + require.Equal(t, aclToken, reqReal.Token) require.Equal(t, acl.WildcardPartitionName, reqReal.Request.Partition) } } @@ -255,6 +261,7 @@ func genVerifyResolverWatch(expectedService, expectedDatacenter, expectedKind st return func(t testing.TB, request any) { reqReal, ok := request.(*structs.ConfigEntryQuery) require.True(t, ok) + require.Equal(t, aclToken, reqReal.Token) require.Equal(t, expectedDatacenter, reqReal.Datacenter) require.Equal(t, expectedService, reqReal.Name) require.Equal(t, expectedKind, reqReal.Kind) @@ -265,6 +272,7 @@ func genVerifyResolvedConfigWatch(expectedService string, expectedDatacenter str return func(t testing.TB, request any) { reqReal, ok := request.(*structs.ServiceConfigRequest) require.True(t, ok) + require.Equal(t, aclToken, reqReal.Token) require.Equal(t, expectedDatacenter, reqReal.Datacenter) require.Equal(t, expectedService, reqReal.Name) } @@ -274,6 +282,7 @@ func genVerifyIntentionWatch(expectedService string, expectedDatacenter string) return func(t testing.TB, request any) { reqReal, ok := request.(*structs.ServiceSpecificRequest) require.True(t, ok) + require.Equal(t, aclToken, reqReal.Token) require.Equal(t, expectedDatacenter, reqReal.Datacenter) require.Equal(t, expectedService, reqReal.ServiceName) } @@ -283,6 +292,7 @@ func genVerifyPreparedQueryWatch(expectedName string, expectedDatacenter string) return func(t testing.TB, request any) { reqReal, ok := request.(*structs.PreparedQueryExecuteRequest) require.True(t, ok) + require.Equal(t, aclToken, reqReal.Token) require.Equal(t, expectedDatacenter, reqReal.Datacenter) require.Equal(t, expectedName, reqReal.QueryIDOrName) require.Equal(t, true, reqReal.Connect) @@ -293,6 +303,7 @@ func genVerifyDiscoveryChainWatch(expected *structs.DiscoveryChainRequest) verif return func(t testing.TB, request any) { reqReal, ok := request.(*structs.DiscoveryChainRequest) require.True(t, ok) + require.Equal(t, aclToken, reqReal.Token) require.Equal(t, expected, reqReal) } } @@ -301,6 +312,7 @@ func genVerifyMeshConfigWatch(expectedDatacenter string) verifyWatchRequest { return func(t testing.TB, request any) { reqReal, ok := request.(*structs.ConfigEntryQuery) require.True(t, ok) + require.Equal(t, aclToken, reqReal.Token) require.Equal(t, expectedDatacenter, reqReal.Datacenter) require.Equal(t, structs.MeshConfigMesh, reqReal.Name) require.Equal(t, structs.MeshConfig, reqReal.Kind) @@ -311,6 +323,7 @@ func genVerifyGatewayWatch(expectedDatacenter string) verifyWatchRequest { return func(t testing.TB, request any) { reqReal, ok := request.(*structs.ServiceDumpRequest) require.True(t, ok) + require.Equal(t, aclToken, reqReal.Token) require.Equal(t, expectedDatacenter, reqReal.Datacenter) require.True(t, reqReal.UseServiceKind) require.Equal(t, structs.ServiceKindMeshGateway, reqReal.ServiceKind) @@ -326,6 +339,7 @@ func genVerifyServiceSpecificPeeredRequest(expectedService, expectedFilter, expe return func(t testing.TB, request any) { reqReal, ok := request.(*structs.ServiceSpecificRequest) require.True(t, ok) + require.Equal(t, aclToken, reqReal.Token) require.Equal(t, expectedDatacenter, reqReal.Datacenter) require.Equal(t, expectedPeer, reqReal.PeerName) require.Equal(t, expectedService, reqReal.ServiceName) @@ -338,6 +352,7 @@ func genVerifyPartitionSpecificRequest(expectedPartition, expectedDatacenter str return func(t testing.TB, request any) { reqReal, ok := request.(*structs.PartitionSpecificRequest) require.True(t, ok) + require.Equal(t, aclToken, reqReal.Token) require.Equal(t, expectedDatacenter, reqReal.Datacenter) require.Equal(t, expectedPartition, reqReal.PartitionOrDefault()) } @@ -351,6 +366,7 @@ func genVerifyConfigEntryWatch(expectedKind, expectedName, expectedDatacenter st return func(t testing.TB, request any) { reqReal, ok := request.(*structs.ConfigEntryQuery) require.True(t, ok) + require.Equal(t, aclToken, reqReal.Token) require.Equal(t, expectedKind, reqReal.Kind) require.Equal(t, expectedName, reqReal.Name) require.Equal(t, expectedDatacenter, reqReal.Datacenter) @@ -536,6 +552,9 @@ func TestState_WatchesAndUpdates(t *testing.T) { OverrideMeshGateway: structs.MeshGatewayConfig{ Mode: meshGatewayProxyConfigValue, }, + QueryOptions: structs.QueryOptions{ + Token: aclToken, + }, }), fmt.Sprintf("discovery-chain:%s-failover-remote?dc=dc2", apiUID.String()): genVerifyDiscoveryChainWatch(&structs.DiscoveryChainRequest{ Name: "api-failover-remote", @@ -546,6 +565,9 @@ func TestState_WatchesAndUpdates(t *testing.T) { OverrideMeshGateway: structs.MeshGatewayConfig{ Mode: structs.MeshGatewayModeRemote, }, + QueryOptions: structs.QueryOptions{ + Token: aclToken, + }, }), fmt.Sprintf("discovery-chain:%s-failover-local?dc=dc2", apiUID.String()): genVerifyDiscoveryChainWatch(&structs.DiscoveryChainRequest{ Name: "api-failover-local", @@ -556,6 +578,9 @@ func TestState_WatchesAndUpdates(t *testing.T) { OverrideMeshGateway: structs.MeshGatewayConfig{ Mode: structs.MeshGatewayModeLocal, }, + QueryOptions: structs.QueryOptions{ + Token: aclToken, + }, }), fmt.Sprintf("discovery-chain:%s-failover-direct?dc=dc2", apiUID.String()): genVerifyDiscoveryChainWatch(&structs.DiscoveryChainRequest{ Name: "api-failover-direct", @@ -566,6 +591,9 @@ func TestState_WatchesAndUpdates(t *testing.T) { OverrideMeshGateway: structs.MeshGatewayConfig{ Mode: structs.MeshGatewayModeNone, }, + QueryOptions: structs.QueryOptions{ + Token: aclToken, + }, }), fmt.Sprintf("discovery-chain:%s-failover-to-peer", apiUID.String()): genVerifyDiscoveryChainWatch(&structs.DiscoveryChainRequest{ Name: "api-failover-to-peer", @@ -576,6 +604,9 @@ func TestState_WatchesAndUpdates(t *testing.T) { OverrideMeshGateway: structs.MeshGatewayConfig{ Mode: meshGatewayProxyConfigValue, }, + QueryOptions: structs.QueryOptions{ + Token: aclToken, + }, }), fmt.Sprintf("discovery-chain:%s-dc2", apiUID.String()): genVerifyDiscoveryChainWatch(&structs.DiscoveryChainRequest{ Name: "api-dc2", @@ -586,6 +617,9 @@ func TestState_WatchesAndUpdates(t *testing.T) { OverrideMeshGateway: structs.MeshGatewayConfig{ Mode: meshGatewayProxyConfigValue, }, + QueryOptions: structs.QueryOptions{ + Token: aclToken, + }, }), "upstream:" + pqUID.String(): genVerifyPreparedQueryWatch("query", "dc1"), rootsWatchID: genVerifyDCSpecificWatch("dc1"), @@ -1421,6 +1455,9 @@ func TestState_WatchesAndUpdates(t *testing.T) { EvaluateInNamespace: "default", EvaluateInPartition: "default", Datacenter: "dc1", + QueryOptions: structs.QueryOptions{ + Token: aclToken, + }, }), }, events: []UpdateEvent{ @@ -2275,6 +2312,9 @@ func TestState_WatchesAndUpdates(t *testing.T) { Datacenter: "dc1", OverrideConnectTimeout: 6 * time.Second, OverrideMeshGateway: structs.MeshGatewayConfig{Mode: structs.MeshGatewayModeRemote}, + QueryOptions: structs.QueryOptions{ + Token: aclToken, + }, }), }, events: []UpdateEvent{ @@ -2444,6 +2484,9 @@ func TestState_WatchesAndUpdates(t *testing.T) { Datacenter: "dc1", OverrideConnectTimeout: 6 * time.Second, OverrideMeshGateway: structs.MeshGatewayConfig{Mode: structs.MeshGatewayModeRemote}, + QueryOptions: structs.QueryOptions{ + Token: aclToken, + }, }), }, events: []UpdateEvent{ @@ -2901,6 +2944,9 @@ func TestState_WatchesAndUpdates(t *testing.T) { EvaluateInPartition: "default", Datacenter: "dc1", OverrideMeshGateway: structs.MeshGatewayConfig{Mode: structs.MeshGatewayModeLocal}, + QueryOptions: structs.QueryOptions{ + Token: aclToken, + }, }), rootsWatchID: genVerifyDCSpecificWatch("dc1"), leafWatchID: genVerifyLeafWatch("api", "dc1"), @@ -2966,6 +3012,9 @@ func TestState_WatchesAndUpdates(t *testing.T) { EvaluateInPartition: "default", Datacenter: "dc1", OverrideMeshGateway: structs.MeshGatewayConfig{Mode: structs.MeshGatewayModeLocal}, + QueryOptions: structs.QueryOptions{ + Token: aclToken, + }, }), }, events: []UpdateEvent{ @@ -3001,6 +3050,9 @@ func TestState_WatchesAndUpdates(t *testing.T) { EvaluateInPartition: "default", Datacenter: "dc1", OverrideMeshGateway: structs.MeshGatewayConfig{Mode: structs.MeshGatewayModeLocal}, + QueryOptions: structs.QueryOptions{ + Token: aclToken, + }, }), rootsWatchID: genVerifyDCSpecificWatch("dc1"), leafWatchID: genVerifyLeafWatch("api", "dc1"), @@ -3318,6 +3370,9 @@ func TestState_WatchesAndUpdates(t *testing.T) { EvaluateInNamespace: "default", EvaluateInPartition: "default", Datacenter: "dc1", + QueryOptions: structs.QueryOptions{ + Token: aclToken, + }, }), rootsWatchID: genVerifyDCSpecificWatch("dc1"), leafWatchID: genVerifyLeafWatch("web", "dc1"), @@ -3464,7 +3519,7 @@ func TestState_WatchesAndUpdates(t *testing.T) { } wr := recordWatches(&sc) - state, err := newState(proxyID, &tc.ns, testSource, "", sc, rate.NewLimiter(rate.Inf, 0)) + state, err := newState(proxyID, &tc.ns, testSource, aclToken, sc, rate.NewLimiter(rate.Inf, 0)) // verify building the initial state worked require.NoError(t, err) @@ -3637,3 +3692,5 @@ func Test_hostnameEndpoints(t *testing.T) { }) } } + +const aclToken = "foo"