Refactor resolveListenerSDSConfig to pass in whole config

This commit is contained in:
Paul Banks 2021-10-08 12:33:20 +01:00
parent d779a4fc2c
commit 6faf85bccd
2 changed files with 8 additions and 14 deletions

View File

@ -28,7 +28,7 @@ func (s *ResourceGenerator) makeIngressGatewayListeners(address string, cfgSnap
connectTLSEnabled := cfgSnap.IngressGateway.TLSConfig.Enabled || connectTLSEnabled := cfgSnap.IngressGateway.TLSConfig.Enabled ||
(listenerCfg.TLS != nil && listenerCfg.TLS.Enabled) (listenerCfg.TLS != nil && listenerCfg.TLS.Enabled)
sdsCfg, err := resolveListenerSDSConfig(cfgSnap, listenerKey) sdsCfg, err := resolveListenerSDSConfig(cfgSnap, listenerCfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -127,7 +127,7 @@ func (s *ResourceGenerator) makeIngressGatewayListeners(address string, cfgSnap
return resources, nil return resources, nil
} }
func resolveListenerSDSConfig(cfgSnap *proxycfg.ConfigSnapshot, listenerKey proxycfg.IngressListenerKey) (*structs.GatewayTLSSDSConfig, error) { func resolveListenerSDSConfig(cfgSnap *proxycfg.ConfigSnapshot, listenerCfg structs.IngressListener) (*structs.GatewayTLSSDSConfig, error) {
var mergedCfg structs.GatewayTLSSDSConfig var mergedCfg structs.GatewayTLSSDSConfig
gwSDS := cfgSnap.IngressGateway.TLSConfig.SDS gwSDS := cfgSnap.IngressGateway.TLSConfig.SDS
@ -136,11 +136,6 @@ func resolveListenerSDSConfig(cfgSnap *proxycfg.ConfigSnapshot, listenerKey prox
mergedCfg.CertResource = gwSDS.CertResource mergedCfg.CertResource = gwSDS.CertResource
} }
listenerCfg, ok := cfgSnap.IngressGateway.Listeners[listenerKey]
if !ok {
return nil, fmt.Errorf("no listener config found for listener on port %d", listenerKey.Port)
}
if listenerCfg.TLS != nil && listenerCfg.TLS.SDS != nil { if listenerCfg.TLS != nil && listenerCfg.TLS.SDS != nil {
if listenerCfg.TLS.SDS.ClusterName != "" { if listenerCfg.TLS.SDS.ClusterName != "" {
mergedCfg.ClusterName = listenerCfg.TLS.SDS.ClusterName mergedCfg.ClusterName = listenerCfg.TLS.SDS.ClusterName
@ -161,10 +156,10 @@ func resolveListenerSDSConfig(cfgSnap *proxycfg.ConfigSnapshot, listenerKey prox
return &mergedCfg, nil return &mergedCfg, nil
case mergedCfg.ClusterName == "" && mergedCfg.CertResource != "": case mergedCfg.ClusterName == "" && mergedCfg.CertResource != "":
return nil, fmt.Errorf("missing SDS cluster name for listener on port %d", listenerKey.Port) return nil, fmt.Errorf("missing SDS cluster name for listener on port %d", listenerCfg.Port)
case mergedCfg.ClusterName != "" && mergedCfg.CertResource == "": case mergedCfg.ClusterName != "" && mergedCfg.CertResource == "":
return nil, fmt.Errorf("missing SDS cert resource for listener on port %d", listenerKey.Port) return nil, fmt.Errorf("missing SDS cert resource for listener on port %d", listenerCfg.Port)
} }
return &mergedCfg, nil return &mergedCfg, nil

View File

@ -1172,7 +1172,7 @@ func TestResolveListenerSDSConfig(t *testing.T) {
snap := proxycfg.TestConfigSnapshotIngressWithGatewaySDS(t) snap := proxycfg.TestConfigSnapshotIngressWithGatewaySDS(t)
// Override TLS configs // Override TLS configs
snap.IngressGateway.TLSConfig.SDS = tc.gwSDS snap.IngressGateway.TLSConfig.SDS = tc.gwSDS
var key proxycfg.IngressListenerKey var listenerCfg structs.IngressListener
for k, lisCfg := range snap.IngressGateway.Listeners { for k, lisCfg := range snap.IngressGateway.Listeners {
if tc.lisSDS == nil { if tc.lisSDS == nil {
lisCfg.TLS = nil lisCfg.TLS = nil
@ -1183,12 +1183,11 @@ func TestResolveListenerSDSConfig(t *testing.T) {
} }
// Override listener cfg in map // Override listener cfg in map
snap.IngressGateway.Listeners[k] = lisCfg snap.IngressGateway.Listeners[k] = lisCfg
// Save the last key doesn't matter which as we set same listener config // Save the last cfg doesn't matter which as we set same for all.
// for all. listenerCfg = lisCfg
key = k
} }
got, err := resolveListenerSDSConfig(snap, key) got, err := resolveListenerSDSConfig(snap, listenerCfg)
if tc.wantErr != "" { if tc.wantErr != "" {
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), tc.wantErr) require.Contains(t, err.Error(), tc.wantErr)