consul/envoyextensions/extensioncommon/envoy_extender_test.go

75 lines
2.3 KiB
Go

package extensioncommon
import (
"fmt"
"testing"
"github.com/hashicorp/consul/api"
"github.com/stretchr/testify/require"
)
func TestUpstreamConfigSourceLimitations(t *testing.T) {
type testCase struct {
extender EnvoyExtender
config *RuntimeConfig
ok bool
errMsg string
}
cases := map[string]testCase{
"upstream extender non-upstream config": {
extender: &UpstreamEnvoyExtender{},
config: &RuntimeConfig{
Kind: api.ServiceKindConnectProxy,
ServiceName: api.CompoundServiceName{Name: "api"},
Upstreams: map[api.CompoundServiceName]*UpstreamData{},
IsSourcedFromUpstream: false,
EnvoyExtension: api.EnvoyExtension{
Name: api.BuiltinAWSLambdaExtension,
},
},
ok: false,
errMsg: fmt.Sprintf("%q extension applied as upstream config but is not sourced from an upstream of the local service", api.BuiltinAWSLambdaExtension),
},
"basic extender upstream config": {
extender: &BasicEnvoyExtender{},
config: &RuntimeConfig{
Kind: api.ServiceKindConnectProxy,
ServiceName: api.CompoundServiceName{Name: "api"},
Upstreams: map[api.CompoundServiceName]*UpstreamData{},
IsSourcedFromUpstream: true,
EnvoyExtension: api.EnvoyExtension{
Name: api.BuiltinLuaExtension,
},
},
ok: false,
errMsg: fmt.Sprintf("%q extension applied as local config but is sourced from an upstream of the local service", api.BuiltinLuaExtension),
},
"list extender upstream config": {
extender: &ListEnvoyExtender{},
config: &RuntimeConfig{
Kind: api.ServiceKindConnectProxy,
ServiceName: api.CompoundServiceName{Name: "api"},
Upstreams: map[api.CompoundServiceName]*UpstreamData{},
IsSourcedFromUpstream: true,
EnvoyExtension: api.EnvoyExtension{
Name: api.BuiltinLuaExtension,
},
},
ok: false,
errMsg: fmt.Sprintf("%q extension applied as local config but is sourced from an upstream of the local service", api.BuiltinLuaExtension),
},
}
for n, tc := range cases {
t.Run(n, func(t *testing.T) {
_, err := tc.extender.Extend(nil, tc.config)
if tc.ok {
require.NoError(t, err)
} else {
require.Error(t, err)
require.ErrorContains(t, err, tc.errMsg)
}
})
}
}