[NET-5457] Fix CE code for jwt multiple virtual hosts bug (#19123)

* Fix CE code for jwt multiple virtual hosts bug

* Fix struct definition

* fix bug with always appending route to jwt config

* Update comment to be correct

* Update comment
This commit is contained in:
John Maguire 2023-10-10 16:25:36 -04:00 committed by GitHub
parent 830c4ea81c
commit 8bebfc147d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 89 additions and 62 deletions

View File

@ -198,9 +198,14 @@ func (l *GatewayChainSynthesizer) consolidateHTTPRoutes() []structs.HTTPRouteCon
return consolidateHTTPRoutes(l.matchesByHostname, l.suffix, l.gateway) return consolidateHTTPRoutes(l.matchesByHostname, l.suffix, l.gateway)
} }
// ReformatHTTPRoute takes in an HTTPRoute and reformats it to match the discovery chains generated by the gateway chain synthesizer // ConsolidateHTTPRoutes takes in one or more HTTPRoutes and consolidates them down to the minimum
func ReformatHTTPRoute(route *structs.HTTPRouteConfigEntry, listener *structs.APIGatewayListener, gateway *structs.APIGatewayConfigEntry) []structs.HTTPRouteConfigEntry { // set of HTTPRoutes that can represent the same set of rules. This should result in approx. one
matches := initHostMatches(listener.GetHostname(), route, map[string][]hostnameMatch{}) // HTTPRoute per hostname.
func ConsolidateHTTPRoutes(gateway *structs.APIGatewayConfigEntry, listener *structs.APIGatewayListener, routes ...*structs.HTTPRouteConfigEntry) []structs.HTTPRouteConfigEntry {
matches := map[string][]hostnameMatch{}
for _, route := range routes {
matches = initHostMatches(listener.GetHostname(), route, matches)
}
return consolidateHTTPRoutes(matches, listener.Name, gateway) return consolidateHTTPRoutes(matches, listener.Name, gateway)
} }

View File

@ -14,7 +14,7 @@ import (
type GatewayAuthFilterBuilder struct { type GatewayAuthFilterBuilder struct {
listener structs.APIGatewayListener listener structs.APIGatewayListener
route *structs.HTTPRouteConfigEntry routes []*structs.HTTPRouteConfigEntry
providers map[string]*structs.JWTProviderConfigEntry providers map[string]*structs.JWTProviderConfigEntry
envoyProviders map[string]*envoy_http_jwt_authn_v3.JwtProvider envoyProviders map[string]*envoy_http_jwt_authn_v3.JwtProvider
} }

View File

@ -6,12 +6,15 @@ package xds
import ( import (
"fmt" "fmt"
"golang.org/x/exp/maps"
envoy_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" envoy_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
envoy_listener_v3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" envoy_listener_v3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3"
envoy_http_jwt_authn_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/jwt_authn/v3" envoy_http_jwt_authn_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/jwt_authn/v3"
envoy_http_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" envoy_http_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3"
envoy_tls_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" envoy_tls_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3"
"github.com/hashicorp/consul/agent/consul/discoverychain"
"github.com/hashicorp/consul/agent/xds/naming" "github.com/hashicorp/consul/agent/xds/naming"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
@ -107,15 +110,21 @@ func (s *ResourceGenerator) makeAPIGatewayListeners(address string, cfgSnap *pro
if isAPIGatewayWithTLS { if isAPIGatewayWithTLS {
// construct SNI filter chains // construct SNI filter chains
l.FilterChains, err = makeInlineOverrideFilterChains(cfgSnap, cfgSnap.APIGateway.TLSConfig, listenerKey.Protocol, listenerFilterOpts{ l.FilterChains, err = makeInlineOverrideFilterChains(
useRDS: useRDS, cfgSnap,
protocol: listenerKey.Protocol, cfgSnap.APIGateway.TLSConfig,
routeName: listenerKey.RouteName(), listenerKey.Protocol,
cluster: clusterName, listenerFilterOpts{
statPrefix: "ingress_upstream_", useRDS: useRDS,
accessLogs: &cfgSnap.Proxy.AccessLogs, protocol: listenerKey.Protocol,
logger: s.Logger, routeName: listenerKey.RouteName(),
}, certs) cluster: clusterName,
statPrefix: "ingress_upstream_",
accessLogs: &cfgSnap.Proxy.AccessLogs,
logger: s.Logger,
},
certs,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -141,36 +150,47 @@ func (s *ResourceGenerator) makeAPIGatewayListeners(address string, cfgSnap *pro
} }
listener := makeListener(listenerOpts) listener := makeListener(listenerOpts)
route, _ := cfgSnap.APIGateway.HTTPRoutes.Get(readyListener.routeReference) routes := make([]*structs.HTTPRouteConfigEntry, 0, len(readyListener.routeReferences))
foundJWT := false for _, routeRef := range maps.Keys(readyListener.routeReferences) {
if listenerCfg.Override != nil && listenerCfg.Override.JWT != nil { route, _ := cfgSnap.APIGateway.HTTPRoutes.Get(routeRef)
foundJWT = true routes = append(routes, route)
} }
consolidatedRoutes := discoverychain.ConsolidateHTTPRoutes(cfgSnap.APIGateway.GatewayConfig, &readyListener.listenerCfg, routes...)
routesWithJWT := []*structs.HTTPRouteConfigEntry{}
for _, routeCfgEntry := range consolidatedRoutes {
routeCfgEntry := routeCfgEntry
route := &routeCfgEntry
if !foundJWT && listenerCfg.Default != nil && listenerCfg.Default.JWT != nil { if listenerCfg.Override != nil && listenerCfg.Override.JWT != nil {
foundJWT = true routesWithJWT = append(routesWithJWT, route)
} continue
}
if listenerCfg.Default != nil && listenerCfg.Default.JWT != nil {
routesWithJWT = append(routesWithJWT, route)
continue
}
if !foundJWT {
for _, rule := range route.Rules { for _, rule := range route.Rules {
if rule.Filters.JWT != nil { if rule.Filters.JWT != nil {
foundJWT = true routesWithJWT = append(routesWithJWT, route)
break continue
} }
for _, svc := range rule.Services { for _, svc := range rule.Services {
if svc.Filters.JWT != nil { if svc.Filters.JWT != nil {
foundJWT = true routesWithJWT = append(routesWithJWT, route)
break continue
} }
} }
} }
} }
var authFilters []*envoy_http_v3.HttpFilter var authFilters []*envoy_http_v3.HttpFilter
if foundJWT { if len(routesWithJWT) > 0 {
builder := &GatewayAuthFilterBuilder{ builder := &GatewayAuthFilterBuilder{
listener: listenerCfg, listener: listenerCfg,
route: route, routes: routesWithJWT,
providers: cfgSnap.JWTProviders, providers: cfgSnap.JWTProviders,
envoyProviders: make(map[string]*envoy_http_jwt_authn_v3.JwtProvider, len(cfgSnap.JWTProviders)), envoyProviders: make(map[string]*envoy_http_jwt_authn_v3.JwtProvider, len(cfgSnap.JWTProviders)),
} }
@ -179,6 +199,7 @@ func (s *ResourceGenerator) makeAPIGatewayListeners(address string, cfgSnap *pro
return nil, err return nil, err
} }
} }
filterOpts := listenerFilterOpts{ filterOpts := listenerFilterOpts{
useRDS: true, useRDS: true,
protocol: listenerKey.Protocol, protocol: listenerKey.Protocol,
@ -246,7 +267,7 @@ type readyListener struct {
listenerKey proxycfg.APIGatewayListenerKey listenerKey proxycfg.APIGatewayListenerKey
listenerCfg structs.APIGatewayListener listenerCfg structs.APIGatewayListener
boundListenerCfg structs.BoundAPIGatewayListener boundListenerCfg structs.BoundAPIGatewayListener
routeReference structs.ResourceReference routeReferences map[structs.ResourceReference]struct{}
upstreams []structs.Upstream upstreams []structs.Upstream
} }
@ -285,10 +306,11 @@ func getReadyListeners(cfgSnap *proxycfg.ConfigSnapshot) map[string]readyListene
r = readyListener{ r = readyListener{
listenerKey: listenerKey, listenerKey: listenerKey,
listenerCfg: l, listenerCfg: l,
routeReferences: map[structs.ResourceReference]struct{}{},
boundListenerCfg: boundListener, boundListenerCfg: boundListener,
routeReference: routeRef,
} }
} }
r.routeReferences[routeRef] = struct{}{}
r.upstreams = append(r.upstreams, upstream) r.upstreams = append(r.upstreams, upstream)
ready[routeKey] = r ready[routeKey] = r
} }
@ -297,7 +319,10 @@ func getReadyListeners(cfgSnap *proxycfg.ConfigSnapshot) map[string]readyListene
return ready return ready
} }
func makeDownstreamTLSContextFromSnapshotAPIListenerConfig(cfgSnap *proxycfg.ConfigSnapshot, listenerCfg structs.APIGatewayListener) (*envoy_tls_v3.DownstreamTlsContext, error) { func makeDownstreamTLSContextFromSnapshotAPIListenerConfig(
cfgSnap *proxycfg.ConfigSnapshot,
listenerCfg structs.APIGatewayListener,
) (*envoy_tls_v3.DownstreamTlsContext, error) {
var downstreamContext *envoy_tls_v3.DownstreamTlsContext var downstreamContext *envoy_tls_v3.DownstreamTlsContext
tlsContext, err := makeCommonTLSContextFromSnapshotAPIGatewayListenerConfig(cfgSnap, listenerCfg) tlsContext, err := makeCommonTLSContextFromSnapshotAPIGatewayListenerConfig(cfgSnap, listenerCfg)
@ -318,7 +343,10 @@ func makeDownstreamTLSContextFromSnapshotAPIListenerConfig(cfgSnap *proxycfg.Con
return downstreamContext, nil return downstreamContext, nil
} }
func makeCommonTLSContextFromSnapshotAPIGatewayListenerConfig(cfgSnap *proxycfg.ConfigSnapshot, listenerCfg structs.APIGatewayListener) (*envoy_tls_v3.CommonTlsContext, error) { func makeCommonTLSContextFromSnapshotAPIGatewayListenerConfig(
cfgSnap *proxycfg.ConfigSnapshot,
listenerCfg structs.APIGatewayListener,
) (*envoy_tls_v3.CommonTlsContext, error) {
var tlsContext *envoy_tls_v3.CommonTlsContext var tlsContext *envoy_tls_v3.CommonTlsContext
// API Gateway TLS config is per listener // API Gateway TLS config is per listener

View File

@ -14,6 +14,8 @@ import (
envoy_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" envoy_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
envoy_route_v3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" envoy_route_v3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3"
envoy_matcher_v3 "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" envoy_matcher_v3 "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3"
"golang.org/x/exp/maps"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/wrapperspb" "google.golang.org/protobuf/types/known/wrapperspb"
@ -433,65 +435,57 @@ func (s *ResourceGenerator) routesForIngressGateway(cfgSnap *proxycfg.ConfigSnap
func (s *ResourceGenerator) routesForAPIGateway(cfgSnap *proxycfg.ConfigSnapshot) ([]proto.Message, error) { func (s *ResourceGenerator) routesForAPIGateway(cfgSnap *proxycfg.ConfigSnapshot) ([]proto.Message, error) {
var result []proto.Message var result []proto.Message
readyUpstreamsList := getReadyListeners(cfgSnap) readyListeners := getReadyListeners(cfgSnap)
for _, readyUpstreams := range readyUpstreamsList { for _, readyListener := range readyListeners {
readyUpstreams := readyUpstreams
listenerCfg := readyUpstreams.listenerCfg
// Do not create any route configuration for TCP listeners // Do not create any route configuration for TCP listeners
if listenerCfg.Protocol != structs.ListenerProtocolHTTP { if readyListener.listenerCfg.Protocol != structs.ListenerProtocolHTTP {
continue continue
} }
routeRef := readyUpstreams.routeReference listenerRoute := &envoy_route_v3.RouteConfiguration{
listenerKey := readyUpstreams.listenerKey Name: readyListener.listenerKey.RouteName(),
defaultRoute := &envoy_route_v3.RouteConfiguration{
Name: listenerKey.RouteName(),
// ValidateClusters defaults to true when defined statically and false // ValidateClusters defaults to true when defined statically and false
// when done via RDS. Re-set the reasonable value of true to prevent // when done via RDS. Re-set the reasonable value of true to prevent
// null-routing traffic. // null-routing traffic.
ValidateClusters: response.MakeBoolValue(true), ValidateClusters: response.MakeBoolValue(true),
} }
route, ok := cfgSnap.APIGateway.HTTPRoutes.Get(routeRef) // Consolidate all routes for this listener into the minimum possible set based on hostname matching.
if !ok { allRoutesForListener := []*structs.HTTPRouteConfigEntry{}
return nil, fmt.Errorf("missing route for route reference %s:%s", routeRef.Name, routeRef.Kind) for _, routeRef := range maps.Keys(readyListener.routeReferences) {
route, ok := cfgSnap.APIGateway.HTTPRoutes.Get(routeRef)
if !ok {
return nil, fmt.Errorf("missing route for route routeRef %s:%s", routeRef.Name, routeRef.Kind)
}
allRoutesForListener = append(allRoutesForListener, route)
} }
consolidatedRoutes := discoverychain.ConsolidateHTTPRoutes(cfgSnap.APIGateway.GatewayConfig, &readyListener.listenerCfg, allRoutesForListener...)
// Reformat the route here since discovery chains were indexed earlier using the // Produce one virtual host per hostname. If no hostname is specified for a set of
// specific naming convention in discoverychain.consolidateHTTPRoutes. If we don't // Gateway + HTTPRoutes, then the virtual host will be "*".
// convert our route to use the same naming convention, we won't find any chains below. for _, consolidatedRoute := range consolidatedRoutes {
reformatedRoutes := discoverychain.ReformatHTTPRoute(route, &listenerCfg, cfgSnap.APIGateway.GatewayConfig) upstream := buildHTTPRouteUpstream(consolidatedRoute, readyListener.listenerCfg)
filterBuilder := perRouteFilterBuilder{providerMap: cfgSnap.JWTProviders, listener: &listenerCfg, route: route}
for _, reformatedRoute := range reformatedRoutes {
reformatedRoute := reformatedRoute
upstream := buildHTTPRouteUpstream(reformatedRoute, listenerCfg)
uid := proxycfg.NewUpstreamID(&upstream) uid := proxycfg.NewUpstreamID(&upstream)
chain := cfgSnap.APIGateway.DiscoveryChain[uid] chain := cfgSnap.APIGateway.DiscoveryChain[uid]
if chain == nil { if chain == nil {
// Note that if we continue here we must also do this in the cluster generation
s.Logger.Debug("Discovery chain not found for flattened route", "discovery chain ID", uid) s.Logger.Debug("Discovery chain not found for flattened route", "discovery chain ID", uid)
continue continue
} }
domains := generateUpstreamAPIsDomains(listenerKey, upstream, reformatedRoute.Hostnames) domains := generateUpstreamAPIsDomains(readyListener.listenerKey, upstream, consolidatedRoute.Hostnames)
filterBuilder := perRouteFilterBuilder{providerMap: cfgSnap.JWTProviders, listener: &readyListener.listenerCfg, route: &consolidatedRoute}
virtualHost, err := s.makeUpstreamRouteForDiscoveryChain(cfgSnap, uid, chain, domains, false, filterBuilder) virtualHost, err := s.makeUpstreamRouteForDiscoveryChain(cfgSnap, uid, chain, domains, false, filterBuilder)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if virtualHost == nil {
continue
}
defaultRoute.VirtualHosts = append(defaultRoute.VirtualHosts, virtualHost) listenerRoute.VirtualHosts = append(listenerRoute.VirtualHosts, virtualHost)
} }
if len(defaultRoute.VirtualHosts) > 0 { if len(listenerRoute.VirtualHosts) > 0 {
result = append(result, defaultRoute) result = append(result, listenerRoute)
} }
} }