diff --git a/agent/consul/discoverychain/gateway.go b/agent/consul/discoverychain/gateway.go index 9b8c4eb73c..c60c77b028 100644 --- a/agent/consul/discoverychain/gateway.go +++ b/agent/consul/discoverychain/gateway.go @@ -198,9 +198,14 @@ func (l *GatewayChainSynthesizer) consolidateHTTPRoutes() []structs.HTTPRouteCon return consolidateHTTPRoutes(l.matchesByHostname, l.suffix, l.gateway) } -// ReformatHTTPRoute takes in an HTTPRoute and reformats it to match the discovery chains generated by the gateway chain synthesizer -func ReformatHTTPRoute(route *structs.HTTPRouteConfigEntry, listener *structs.APIGatewayListener, gateway *structs.APIGatewayConfigEntry) []structs.HTTPRouteConfigEntry { - matches := initHostMatches(listener.GetHostname(), route, map[string][]hostnameMatch{}) +// ConsolidateHTTPRoutes takes in one or more HTTPRoutes and consolidates them down to the minimum +// set of HTTPRoutes that can represent the same set of rules. This should result in approx. one +// HTTPRoute per hostname. +func ConsolidateHTTPRoutes(gateway *structs.APIGatewayConfigEntry, listener *structs.APIGatewayListener, routes ...*structs.HTTPRouteConfigEntry) []structs.HTTPRouteConfigEntry { + matches := map[string][]hostnameMatch{} + for _, route := range routes { + matches = initHostMatches(listener.GetHostname(), route, matches) + } return consolidateHTTPRoutes(matches, listener.Name, gateway) } diff --git a/agent/xds/jwt_authn_ce.go b/agent/xds/jwt_authn_ce.go index 777352c806..f8cf52957d 100644 --- a/agent/xds/jwt_authn_ce.go +++ b/agent/xds/jwt_authn_ce.go @@ -14,7 +14,7 @@ import ( type GatewayAuthFilterBuilder struct { listener structs.APIGatewayListener - route *structs.HTTPRouteConfigEntry + routes []*structs.HTTPRouteConfigEntry providers map[string]*structs.JWTProviderConfigEntry envoyProviders map[string]*envoy_http_jwt_authn_v3.JwtProvider } diff --git a/agent/xds/listeners_apigateway.go b/agent/xds/listeners_apigateway.go index 0beca6fe6d..a4611895e2 100644 --- a/agent/xds/listeners_apigateway.go +++ b/agent/xds/listeners_apigateway.go @@ -6,12 +6,15 @@ package xds import ( "fmt" + "golang.org/x/exp/maps" + envoy_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" envoy_listener_v3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" envoy_http_jwt_authn_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/jwt_authn/v3" envoy_http_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" envoy_tls_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + "github.com/hashicorp/consul/agent/consul/discoverychain" "github.com/hashicorp/consul/agent/xds/naming" "google.golang.org/protobuf/proto" @@ -107,15 +110,21 @@ func (s *ResourceGenerator) makeAPIGatewayListeners(address string, cfgSnap *pro if isAPIGatewayWithTLS { // construct SNI filter chains - l.FilterChains, err = makeInlineOverrideFilterChains(cfgSnap, cfgSnap.APIGateway.TLSConfig, listenerKey.Protocol, listenerFilterOpts{ - useRDS: useRDS, - protocol: listenerKey.Protocol, - routeName: listenerKey.RouteName(), - cluster: clusterName, - statPrefix: "ingress_upstream_", - accessLogs: &cfgSnap.Proxy.AccessLogs, - logger: s.Logger, - }, certs) + l.FilterChains, err = makeInlineOverrideFilterChains( + cfgSnap, + cfgSnap.APIGateway.TLSConfig, + listenerKey.Protocol, + listenerFilterOpts{ + useRDS: useRDS, + protocol: listenerKey.Protocol, + routeName: listenerKey.RouteName(), + cluster: clusterName, + statPrefix: "ingress_upstream_", + accessLogs: &cfgSnap.Proxy.AccessLogs, + logger: s.Logger, + }, + certs, + ) if err != nil { return nil, err } @@ -141,36 +150,47 @@ func (s *ResourceGenerator) makeAPIGatewayListeners(address string, cfgSnap *pro } listener := makeListener(listenerOpts) - route, _ := cfgSnap.APIGateway.HTTPRoutes.Get(readyListener.routeReference) - foundJWT := false - if listenerCfg.Override != nil && listenerCfg.Override.JWT != nil { - foundJWT = true + routes := make([]*structs.HTTPRouteConfigEntry, 0, len(readyListener.routeReferences)) + for _, routeRef := range maps.Keys(readyListener.routeReferences) { + route, _ := cfgSnap.APIGateway.HTTPRoutes.Get(routeRef) + routes = append(routes, route) } + consolidatedRoutes := discoverychain.ConsolidateHTTPRoutes(cfgSnap.APIGateway.GatewayConfig, &readyListener.listenerCfg, routes...) + routesWithJWT := []*structs.HTTPRouteConfigEntry{} + for _, routeCfgEntry := range consolidatedRoutes { + routeCfgEntry := routeCfgEntry + route := &routeCfgEntry - if !foundJWT && listenerCfg.Default != nil && listenerCfg.Default.JWT != nil { - foundJWT = true - } + if listenerCfg.Override != nil && listenerCfg.Override.JWT != nil { + routesWithJWT = append(routesWithJWT, route) + continue + } + + if listenerCfg.Default != nil && listenerCfg.Default.JWT != nil { + routesWithJWT = append(routesWithJWT, route) + continue + } - if !foundJWT { for _, rule := range route.Rules { if rule.Filters.JWT != nil { - foundJWT = true - break + routesWithJWT = append(routesWithJWT, route) + continue } for _, svc := range rule.Services { if svc.Filters.JWT != nil { - foundJWT = true - break + routesWithJWT = append(routesWithJWT, route) + continue } } } + } var authFilters []*envoy_http_v3.HttpFilter - if foundJWT { + if len(routesWithJWT) > 0 { builder := &GatewayAuthFilterBuilder{ listener: listenerCfg, - route: route, + routes: routesWithJWT, providers: cfgSnap.JWTProviders, envoyProviders: make(map[string]*envoy_http_jwt_authn_v3.JwtProvider, len(cfgSnap.JWTProviders)), } @@ -179,6 +199,7 @@ func (s *ResourceGenerator) makeAPIGatewayListeners(address string, cfgSnap *pro return nil, err } } + filterOpts := listenerFilterOpts{ useRDS: true, protocol: listenerKey.Protocol, @@ -246,7 +267,7 @@ type readyListener struct { listenerKey proxycfg.APIGatewayListenerKey listenerCfg structs.APIGatewayListener boundListenerCfg structs.BoundAPIGatewayListener - routeReference structs.ResourceReference + routeReferences map[structs.ResourceReference]struct{} upstreams []structs.Upstream } @@ -285,10 +306,11 @@ func getReadyListeners(cfgSnap *proxycfg.ConfigSnapshot) map[string]readyListene r = readyListener{ listenerKey: listenerKey, listenerCfg: l, + routeReferences: map[structs.ResourceReference]struct{}{}, boundListenerCfg: boundListener, - routeReference: routeRef, } } + r.routeReferences[routeRef] = struct{}{} r.upstreams = append(r.upstreams, upstream) ready[routeKey] = r } @@ -297,7 +319,10 @@ func getReadyListeners(cfgSnap *proxycfg.ConfigSnapshot) map[string]readyListene return ready } -func makeDownstreamTLSContextFromSnapshotAPIListenerConfig(cfgSnap *proxycfg.ConfigSnapshot, listenerCfg structs.APIGatewayListener) (*envoy_tls_v3.DownstreamTlsContext, error) { +func makeDownstreamTLSContextFromSnapshotAPIListenerConfig( + cfgSnap *proxycfg.ConfigSnapshot, + listenerCfg structs.APIGatewayListener, +) (*envoy_tls_v3.DownstreamTlsContext, error) { var downstreamContext *envoy_tls_v3.DownstreamTlsContext tlsContext, err := makeCommonTLSContextFromSnapshotAPIGatewayListenerConfig(cfgSnap, listenerCfg) @@ -318,7 +343,10 @@ func makeDownstreamTLSContextFromSnapshotAPIListenerConfig(cfgSnap *proxycfg.Con return downstreamContext, nil } -func makeCommonTLSContextFromSnapshotAPIGatewayListenerConfig(cfgSnap *proxycfg.ConfigSnapshot, listenerCfg structs.APIGatewayListener) (*envoy_tls_v3.CommonTlsContext, error) { +func makeCommonTLSContextFromSnapshotAPIGatewayListenerConfig( + cfgSnap *proxycfg.ConfigSnapshot, + listenerCfg structs.APIGatewayListener, +) (*envoy_tls_v3.CommonTlsContext, error) { var tlsContext *envoy_tls_v3.CommonTlsContext // API Gateway TLS config is per listener diff --git a/agent/xds/routes.go b/agent/xds/routes.go index 8c1ed9d4d2..bed3c02366 100644 --- a/agent/xds/routes.go +++ b/agent/xds/routes.go @@ -14,6 +14,8 @@ import ( envoy_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" envoy_route_v3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" envoy_matcher_v3 "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" + "golang.org/x/exp/maps" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/wrapperspb" @@ -433,65 +435,57 @@ func (s *ResourceGenerator) routesForIngressGateway(cfgSnap *proxycfg.ConfigSnap func (s *ResourceGenerator) routesForAPIGateway(cfgSnap *proxycfg.ConfigSnapshot) ([]proto.Message, error) { var result []proto.Message - readyUpstreamsList := getReadyListeners(cfgSnap) + readyListeners := getReadyListeners(cfgSnap) - for _, readyUpstreams := range readyUpstreamsList { - readyUpstreams := readyUpstreams - listenerCfg := readyUpstreams.listenerCfg + for _, readyListener := range readyListeners { // Do not create any route configuration for TCP listeners - if listenerCfg.Protocol != structs.ListenerProtocolHTTP { + if readyListener.listenerCfg.Protocol != structs.ListenerProtocolHTTP { continue } - routeRef := readyUpstreams.routeReference - listenerKey := readyUpstreams.listenerKey - - defaultRoute := &envoy_route_v3.RouteConfiguration{ - Name: listenerKey.RouteName(), + listenerRoute := &envoy_route_v3.RouteConfiguration{ + Name: readyListener.listenerKey.RouteName(), // ValidateClusters defaults to true when defined statically and false // when done via RDS. Re-set the reasonable value of true to prevent // null-routing traffic. ValidateClusters: response.MakeBoolValue(true), } - route, ok := cfgSnap.APIGateway.HTTPRoutes.Get(routeRef) - if !ok { - return nil, fmt.Errorf("missing route for route reference %s:%s", routeRef.Name, routeRef.Kind) + // Consolidate all routes for this listener into the minimum possible set based on hostname matching. + allRoutesForListener := []*structs.HTTPRouteConfigEntry{} + for _, routeRef := range maps.Keys(readyListener.routeReferences) { + route, ok := cfgSnap.APIGateway.HTTPRoutes.Get(routeRef) + if !ok { + return nil, fmt.Errorf("missing route for route routeRef %s:%s", routeRef.Name, routeRef.Kind) + } + allRoutesForListener = append(allRoutesForListener, route) } + consolidatedRoutes := discoverychain.ConsolidateHTTPRoutes(cfgSnap.APIGateway.GatewayConfig, &readyListener.listenerCfg, allRoutesForListener...) - // Reformat the route here since discovery chains were indexed earlier using the - // specific naming convention in discoverychain.consolidateHTTPRoutes. If we don't - // convert our route to use the same naming convention, we won't find any chains below. - reformatedRoutes := discoverychain.ReformatHTTPRoute(route, &listenerCfg, cfgSnap.APIGateway.GatewayConfig) - filterBuilder := perRouteFilterBuilder{providerMap: cfgSnap.JWTProviders, listener: &listenerCfg, route: route} - for _, reformatedRoute := range reformatedRoutes { - reformatedRoute := reformatedRoute - - upstream := buildHTTPRouteUpstream(reformatedRoute, listenerCfg) + // Produce one virtual host per hostname. If no hostname is specified for a set of + // Gateway + HTTPRoutes, then the virtual host will be "*". + for _, consolidatedRoute := range consolidatedRoutes { + upstream := buildHTTPRouteUpstream(consolidatedRoute, readyListener.listenerCfg) uid := proxycfg.NewUpstreamID(&upstream) - chain := cfgSnap.APIGateway.DiscoveryChain[uid] if chain == nil { - // Note that if we continue here we must also do this in the cluster generation s.Logger.Debug("Discovery chain not found for flattened route", "discovery chain ID", uid) continue } - domains := generateUpstreamAPIsDomains(listenerKey, upstream, reformatedRoute.Hostnames) + domains := generateUpstreamAPIsDomains(readyListener.listenerKey, upstream, consolidatedRoute.Hostnames) + filterBuilder := perRouteFilterBuilder{providerMap: cfgSnap.JWTProviders, listener: &readyListener.listenerCfg, route: &consolidatedRoute} virtualHost, err := s.makeUpstreamRouteForDiscoveryChain(cfgSnap, uid, chain, domains, false, filterBuilder) if err != nil { return nil, err } - if virtualHost == nil { - continue - } - defaultRoute.VirtualHosts = append(defaultRoute.VirtualHosts, virtualHost) + listenerRoute.VirtualHosts = append(listenerRoute.VirtualHosts, virtualHost) } - if len(defaultRoute.VirtualHosts) > 0 { - result = append(result, defaultRoute) + if len(listenerRoute.VirtualHosts) > 0 { + result = append(result, listenerRoute) } }