2023-08-11 13:12:13 +00:00
|
|
|
// Copyright (c) HashiCorp, Inc.
|
|
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
|
2023-05-23 11:55:06 +00:00
|
|
|
package extensioncommon
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
"testing"
|
|
|
|
|
|
|
|
"github.com/hashicorp/consul/api"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
)
|
|
|
|
|
|
|
|
func TestUpstreamConfigSourceLimitations(t *testing.T) {
|
|
|
|
type testCase struct {
|
|
|
|
extender EnvoyExtender
|
|
|
|
config *RuntimeConfig
|
|
|
|
ok bool
|
|
|
|
errMsg string
|
|
|
|
}
|
|
|
|
cases := map[string]testCase{
|
|
|
|
"upstream extender non-upstream config": {
|
|
|
|
extender: &UpstreamEnvoyExtender{},
|
|
|
|
config: &RuntimeConfig{
|
|
|
|
Kind: api.ServiceKindConnectProxy,
|
|
|
|
ServiceName: api.CompoundServiceName{Name: "api"},
|
|
|
|
Upstreams: map[api.CompoundServiceName]*UpstreamData{},
|
|
|
|
IsSourcedFromUpstream: false,
|
|
|
|
EnvoyExtension: api.EnvoyExtension{
|
|
|
|
Name: api.BuiltinAWSLambdaExtension,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
ok: false,
|
|
|
|
errMsg: fmt.Sprintf("%q extension applied as upstream config but is not sourced from an upstream of the local service", api.BuiltinAWSLambdaExtension),
|
|
|
|
},
|
|
|
|
"basic extender upstream config": {
|
|
|
|
extender: &BasicEnvoyExtender{},
|
|
|
|
config: &RuntimeConfig{
|
|
|
|
Kind: api.ServiceKindConnectProxy,
|
|
|
|
ServiceName: api.CompoundServiceName{Name: "api"},
|
|
|
|
Upstreams: map[api.CompoundServiceName]*UpstreamData{},
|
|
|
|
IsSourcedFromUpstream: true,
|
|
|
|
EnvoyExtension: api.EnvoyExtension{
|
|
|
|
Name: api.BuiltinLuaExtension,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
ok: false,
|
|
|
|
errMsg: fmt.Sprintf("%q extension applied as local config but is sourced from an upstream of the local service", api.BuiltinLuaExtension),
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
for n, tc := range cases {
|
|
|
|
t.Run(n, func(t *testing.T) {
|
|
|
|
_, err := tc.extender.Extend(nil, tc.config)
|
|
|
|
if tc.ok {
|
|
|
|
require.NoError(t, err)
|
|
|
|
} else {
|
|
|
|
require.Error(t, err)
|
|
|
|
require.ErrorContains(t, err, tc.errMsg)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|