[NET-5457] Fix CE code for jwt multiple virtual hosts bug (#19123)

* Fix CE code for jwt multiple virtual hosts bug

* Fix struct definition

* fix bug with always appending route to jwt config

* Update comment to be correct

* Update comment
This commit is contained in:
John Maguire 2023-10-10 16:25:36 -04:00 committed by GitHub
parent 830c4ea81c
commit 8bebfc147d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 89 additions and 62 deletions

View File

@ -198,9 +198,14 @@ func (l *GatewayChainSynthesizer) consolidateHTTPRoutes() []structs.HTTPRouteCon
return consolidateHTTPRoutes(l.matchesByHostname, l.suffix, l.gateway)
}
// ReformatHTTPRoute takes in an HTTPRoute and reformats it to match the discovery chains generated by the gateway chain synthesizer
func ReformatHTTPRoute(route *structs.HTTPRouteConfigEntry, listener *structs.APIGatewayListener, gateway *structs.APIGatewayConfigEntry) []structs.HTTPRouteConfigEntry {
matches := initHostMatches(listener.GetHostname(), route, map[string][]hostnameMatch{})
// ConsolidateHTTPRoutes takes in one or more HTTPRoutes and consolidates them down to the minimum
// set of HTTPRoutes that can represent the same set of rules. This should result in approx. one
// HTTPRoute per hostname.
func ConsolidateHTTPRoutes(gateway *structs.APIGatewayConfigEntry, listener *structs.APIGatewayListener, routes ...*structs.HTTPRouteConfigEntry) []structs.HTTPRouteConfigEntry {
matches := map[string][]hostnameMatch{}
for _, route := range routes {
matches = initHostMatches(listener.GetHostname(), route, matches)
}
return consolidateHTTPRoutes(matches, listener.Name, gateway)
}

View File

@ -14,7 +14,7 @@ import (
type GatewayAuthFilterBuilder struct {
listener structs.APIGatewayListener
route *structs.HTTPRouteConfigEntry
routes []*structs.HTTPRouteConfigEntry
providers map[string]*structs.JWTProviderConfigEntry
envoyProviders map[string]*envoy_http_jwt_authn_v3.JwtProvider
}

View File

@ -6,12 +6,15 @@ package xds
import (
"fmt"
"golang.org/x/exp/maps"
envoy_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
envoy_listener_v3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3"
envoy_http_jwt_authn_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/jwt_authn/v3"
envoy_http_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3"
envoy_tls_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3"
"github.com/hashicorp/consul/agent/consul/discoverychain"
"github.com/hashicorp/consul/agent/xds/naming"
"google.golang.org/protobuf/proto"
@ -107,15 +110,21 @@ func (s *ResourceGenerator) makeAPIGatewayListeners(address string, cfgSnap *pro
if isAPIGatewayWithTLS {
// construct SNI filter chains
l.FilterChains, err = makeInlineOverrideFilterChains(cfgSnap, cfgSnap.APIGateway.TLSConfig, listenerKey.Protocol, listenerFilterOpts{
useRDS: useRDS,
protocol: listenerKey.Protocol,
routeName: listenerKey.RouteName(),
cluster: clusterName,
statPrefix: "ingress_upstream_",
accessLogs: &cfgSnap.Proxy.AccessLogs,
logger: s.Logger,
}, certs)
l.FilterChains, err = makeInlineOverrideFilterChains(
cfgSnap,
cfgSnap.APIGateway.TLSConfig,
listenerKey.Protocol,
listenerFilterOpts{
useRDS: useRDS,
protocol: listenerKey.Protocol,
routeName: listenerKey.RouteName(),
cluster: clusterName,
statPrefix: "ingress_upstream_",
accessLogs: &cfgSnap.Proxy.AccessLogs,
logger: s.Logger,
},
certs,
)
if err != nil {
return nil, err
}
@ -141,36 +150,47 @@ func (s *ResourceGenerator) makeAPIGatewayListeners(address string, cfgSnap *pro
}
listener := makeListener(listenerOpts)
route, _ := cfgSnap.APIGateway.HTTPRoutes.Get(readyListener.routeReference)
foundJWT := false
if listenerCfg.Override != nil && listenerCfg.Override.JWT != nil {
foundJWT = true
routes := make([]*structs.HTTPRouteConfigEntry, 0, len(readyListener.routeReferences))
for _, routeRef := range maps.Keys(readyListener.routeReferences) {
route, _ := cfgSnap.APIGateway.HTTPRoutes.Get(routeRef)
routes = append(routes, route)
}
consolidatedRoutes := discoverychain.ConsolidateHTTPRoutes(cfgSnap.APIGateway.GatewayConfig, &readyListener.listenerCfg, routes...)
routesWithJWT := []*structs.HTTPRouteConfigEntry{}
for _, routeCfgEntry := range consolidatedRoutes {
routeCfgEntry := routeCfgEntry
route := &routeCfgEntry
if !foundJWT && listenerCfg.Default != nil && listenerCfg.Default.JWT != nil {
foundJWT = true
}
if listenerCfg.Override != nil && listenerCfg.Override.JWT != nil {
routesWithJWT = append(routesWithJWT, route)
continue
}
if listenerCfg.Default != nil && listenerCfg.Default.JWT != nil {
routesWithJWT = append(routesWithJWT, route)
continue
}
if !foundJWT {
for _, rule := range route.Rules {
if rule.Filters.JWT != nil {
foundJWT = true
break
routesWithJWT = append(routesWithJWT, route)
continue
}
for _, svc := range rule.Services {
if svc.Filters.JWT != nil {
foundJWT = true
break
routesWithJWT = append(routesWithJWT, route)
continue
}
}
}
}
var authFilters []*envoy_http_v3.HttpFilter
if foundJWT {
if len(routesWithJWT) > 0 {
builder := &GatewayAuthFilterBuilder{
listener: listenerCfg,
route: route,
routes: routesWithJWT,
providers: cfgSnap.JWTProviders,
envoyProviders: make(map[string]*envoy_http_jwt_authn_v3.JwtProvider, len(cfgSnap.JWTProviders)),
}
@ -179,6 +199,7 @@ func (s *ResourceGenerator) makeAPIGatewayListeners(address string, cfgSnap *pro
return nil, err
}
}
filterOpts := listenerFilterOpts{
useRDS: true,
protocol: listenerKey.Protocol,
@ -246,7 +267,7 @@ type readyListener struct {
listenerKey proxycfg.APIGatewayListenerKey
listenerCfg structs.APIGatewayListener
boundListenerCfg structs.BoundAPIGatewayListener
routeReference structs.ResourceReference
routeReferences map[structs.ResourceReference]struct{}
upstreams []structs.Upstream
}
@ -285,10 +306,11 @@ func getReadyListeners(cfgSnap *proxycfg.ConfigSnapshot) map[string]readyListene
r = readyListener{
listenerKey: listenerKey,
listenerCfg: l,
routeReferences: map[structs.ResourceReference]struct{}{},
boundListenerCfg: boundListener,
routeReference: routeRef,
}
}
r.routeReferences[routeRef] = struct{}{}
r.upstreams = append(r.upstreams, upstream)
ready[routeKey] = r
}
@ -297,7 +319,10 @@ func getReadyListeners(cfgSnap *proxycfg.ConfigSnapshot) map[string]readyListene
return ready
}
func makeDownstreamTLSContextFromSnapshotAPIListenerConfig(cfgSnap *proxycfg.ConfigSnapshot, listenerCfg structs.APIGatewayListener) (*envoy_tls_v3.DownstreamTlsContext, error) {
func makeDownstreamTLSContextFromSnapshotAPIListenerConfig(
cfgSnap *proxycfg.ConfigSnapshot,
listenerCfg structs.APIGatewayListener,
) (*envoy_tls_v3.DownstreamTlsContext, error) {
var downstreamContext *envoy_tls_v3.DownstreamTlsContext
tlsContext, err := makeCommonTLSContextFromSnapshotAPIGatewayListenerConfig(cfgSnap, listenerCfg)
@ -318,7 +343,10 @@ func makeDownstreamTLSContextFromSnapshotAPIListenerConfig(cfgSnap *proxycfg.Con
return downstreamContext, nil
}
func makeCommonTLSContextFromSnapshotAPIGatewayListenerConfig(cfgSnap *proxycfg.ConfigSnapshot, listenerCfg structs.APIGatewayListener) (*envoy_tls_v3.CommonTlsContext, error) {
func makeCommonTLSContextFromSnapshotAPIGatewayListenerConfig(
cfgSnap *proxycfg.ConfigSnapshot,
listenerCfg structs.APIGatewayListener,
) (*envoy_tls_v3.CommonTlsContext, error) {
var tlsContext *envoy_tls_v3.CommonTlsContext
// API Gateway TLS config is per listener

View File

@ -14,6 +14,8 @@ import (
envoy_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
envoy_route_v3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3"
envoy_matcher_v3 "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3"
"golang.org/x/exp/maps"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/wrapperspb"
@ -433,65 +435,57 @@ func (s *ResourceGenerator) routesForIngressGateway(cfgSnap *proxycfg.ConfigSnap
func (s *ResourceGenerator) routesForAPIGateway(cfgSnap *proxycfg.ConfigSnapshot) ([]proto.Message, error) {
var result []proto.Message
readyUpstreamsList := getReadyListeners(cfgSnap)
readyListeners := getReadyListeners(cfgSnap)
for _, readyUpstreams := range readyUpstreamsList {
readyUpstreams := readyUpstreams
listenerCfg := readyUpstreams.listenerCfg
for _, readyListener := range readyListeners {
// Do not create any route configuration for TCP listeners
if listenerCfg.Protocol != structs.ListenerProtocolHTTP {
if readyListener.listenerCfg.Protocol != structs.ListenerProtocolHTTP {
continue
}
routeRef := readyUpstreams.routeReference
listenerKey := readyUpstreams.listenerKey
defaultRoute := &envoy_route_v3.RouteConfiguration{
Name: listenerKey.RouteName(),
listenerRoute := &envoy_route_v3.RouteConfiguration{
Name: readyListener.listenerKey.RouteName(),
// ValidateClusters defaults to true when defined statically and false
// when done via RDS. Re-set the reasonable value of true to prevent
// null-routing traffic.
ValidateClusters: response.MakeBoolValue(true),
}
route, ok := cfgSnap.APIGateway.HTTPRoutes.Get(routeRef)
if !ok {
return nil, fmt.Errorf("missing route for route reference %s:%s", routeRef.Name, routeRef.Kind)
// Consolidate all routes for this listener into the minimum possible set based on hostname matching.
allRoutesForListener := []*structs.HTTPRouteConfigEntry{}
for _, routeRef := range maps.Keys(readyListener.routeReferences) {
route, ok := cfgSnap.APIGateway.HTTPRoutes.Get(routeRef)
if !ok {
return nil, fmt.Errorf("missing route for route routeRef %s:%s", routeRef.Name, routeRef.Kind)
}
allRoutesForListener = append(allRoutesForListener, route)
}
consolidatedRoutes := discoverychain.ConsolidateHTTPRoutes(cfgSnap.APIGateway.GatewayConfig, &readyListener.listenerCfg, allRoutesForListener...)
// Reformat the route here since discovery chains were indexed earlier using the
// specific naming convention in discoverychain.consolidateHTTPRoutes. If we don't
// convert our route to use the same naming convention, we won't find any chains below.
reformatedRoutes := discoverychain.ReformatHTTPRoute(route, &listenerCfg, cfgSnap.APIGateway.GatewayConfig)
filterBuilder := perRouteFilterBuilder{providerMap: cfgSnap.JWTProviders, listener: &listenerCfg, route: route}
for _, reformatedRoute := range reformatedRoutes {
reformatedRoute := reformatedRoute
upstream := buildHTTPRouteUpstream(reformatedRoute, listenerCfg)
// Produce one virtual host per hostname. If no hostname is specified for a set of
// Gateway + HTTPRoutes, then the virtual host will be "*".
for _, consolidatedRoute := range consolidatedRoutes {
upstream := buildHTTPRouteUpstream(consolidatedRoute, readyListener.listenerCfg)
uid := proxycfg.NewUpstreamID(&upstream)
chain := cfgSnap.APIGateway.DiscoveryChain[uid]
if chain == nil {
// Note that if we continue here we must also do this in the cluster generation
s.Logger.Debug("Discovery chain not found for flattened route", "discovery chain ID", uid)
continue
}
domains := generateUpstreamAPIsDomains(listenerKey, upstream, reformatedRoute.Hostnames)
domains := generateUpstreamAPIsDomains(readyListener.listenerKey, upstream, consolidatedRoute.Hostnames)
filterBuilder := perRouteFilterBuilder{providerMap: cfgSnap.JWTProviders, listener: &readyListener.listenerCfg, route: &consolidatedRoute}
virtualHost, err := s.makeUpstreamRouteForDiscoveryChain(cfgSnap, uid, chain, domains, false, filterBuilder)
if err != nil {
return nil, err
}
if virtualHost == nil {
continue
}
defaultRoute.VirtualHosts = append(defaultRoute.VirtualHosts, virtualHost)
listenerRoute.VirtualHosts = append(listenerRoute.VirtualHosts, virtualHost)
}
if len(defaultRoute.VirtualHosts) > 0 {
result = append(result, defaultRoute)
if len(listenerRoute.VirtualHosts) > 0 {
result = append(result, listenerRoute)
}
}