consul/connect/resolver_test.go
R.B. Boyer b60d89e7ef bulk rewrite using this script
set -euo pipefail

    unset CDPATH

    cd "$(dirname "$0")"

    for f in $(git grep '\brequire := require\.New(' | cut -d':' -f1 | sort -u); do
        echo "=== require: $f ==="
        sed -i '/require := require.New(t)/d' $f
        # require.XXX(blah) but not require.XXX(tblah) or require.XXX(rblah)
        sed -i 's/\brequire\.\([a-zA-Z0-9_]*\)(\([^tr]\)/require.\1(t,\2/g' $f
        # require.XXX(tblah) but not require.XXX(t, blah)
        sed -i 's/\brequire\.\([a-zA-Z0-9_]*\)(\(t[^,]\)/require.\1(t,\2/g' $f
        # require.XXX(rblah) but not require.XXX(r, blah)
        sed -i 's/\brequire\.\([a-zA-Z0-9_]*\)(\(r[^,]\)/require.\1(t,\2/g' $f
        gofmt -s -w $f
    done

    for f in $(git grep '\bassert := assert\.New(' | cut -d':' -f1 | sort -u); do
        echo "=== assert: $f ==="
        sed -i '/assert := assert.New(t)/d' $f
        # assert.XXX(blah) but not assert.XXX(tblah) or assert.XXX(rblah)
        sed -i 's/\bassert\.\([a-zA-Z0-9_]*\)(\([^tr]\)/assert.\1(t,\2/g' $f
        # assert.XXX(tblah) but not assert.XXX(t, blah)
        sed -i 's/\bassert\.\([a-zA-Z0-9_]*\)(\(t[^,]\)/assert.\1(t,\2/g' $f
        # assert.XXX(rblah) but not assert.XXX(r, blah)
        sed -i 's/\bassert\.\([a-zA-Z0-9_]*\)(\(r[^,]\)/assert.\1(t,\2/g' $f
        gofmt -s -w $f
    done
2022-01-20 10:46:23 -06:00

337 lines
7.9 KiB
Go

package connect
import (
"context"
"testing"
"time"
"github.com/hashicorp/consul/agent"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/api"
"github.com/stretchr/testify/require"
)
func TestStaticResolver_Resolve(t *testing.T) {
type fields struct {
Addr string
CertURI connect.CertURI
}
tests := []struct {
name string
fields fields
}{
{
name: "simples",
fields: fields{"1.2.3.4:80", connect.TestSpiffeIDService(t, "foo")},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sr := StaticResolver{
Addr: tt.fields.Addr,
CertURI: tt.fields.CertURI,
}
addr, certURI, err := sr.Resolve(context.Background())
require.Nil(t, err)
require.Equal(t, sr.Addr, addr)
require.Equal(t, sr.CertURI, certURI)
})
}
}
func TestConsulResolver_Resolve(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
// Setup a local test agent to query
agent := agent.StartTestAgent(t, agent.TestAgent{Name: "test-consul"})
defer agent.Shutdown()
cfg := api.DefaultConfig()
cfg.Address = agent.HTTPAddr()
client, err := api.NewClient(cfg)
require.Nil(t, err)
// Setup a service with a connect proxy instance
regSrv := &api.AgentServiceRegistration{
Name: "web",
Port: 8080,
}
err = client.Agent().ServiceRegister(regSrv)
require.Nil(t, err)
regProxy := &api.AgentServiceRegistration{
Kind: "connect-proxy",
Name: "web-proxy",
Port: 9090,
Proxy: &api.AgentServiceConnectProxyConfig{
DestinationServiceName: "web",
},
}
err = client.Agent().ServiceRegister(regProxy)
require.Nil(t, err)
// And another proxy so we can test handling with multiple endpoints returned
regProxy.Port = 9091
regProxy.ID = "web-proxy-2"
err = client.Agent().ServiceRegister(regProxy)
require.Nil(t, err)
// Add a native service
{
regSrv := &api.AgentServiceRegistration{
Name: "db",
Port: 8080,
Connect: &api.AgentServiceConnect{
Native: true,
},
}
require.NoError(t, client.Agent().ServiceRegister(regSrv))
}
// Add a prepared query
queryId, _, err := client.PreparedQuery().Create(&api.PreparedQueryDefinition{
Name: "test-query",
Service: api.ServiceQuery{
Service: "web",
Connect: true,
},
}, nil)
require.NoError(t, err)
proxyAddrs := []string{
agent.Config.AdvertiseAddrLAN.String() + ":9090",
agent.Config.AdvertiseAddrLAN.String() + ":9091",
}
type fields struct {
Namespace string
Name string
Type int
Datacenter string
}
tests := []struct {
name string
fields fields
timeout time.Duration
wantAddr string
wantCertURI connect.CertURI
wantErr bool
addrs []string
}{
{
name: "basic service discovery",
fields: fields{
Namespace: "default",
Name: "web",
Type: ConsulResolverTypeService,
},
// Want empty host since we don't enforce trust domain outside of TLS and
// don't need to load the current one this way.
wantCertURI: connect.TestSpiffeIDServiceWithHost(t, "web", ""),
wantErr: false,
addrs: proxyAddrs,
},
{
name: "basic service with native service",
fields: fields{
Namespace: "default",
Name: "db",
Type: ConsulResolverTypeService,
},
// Want empty host since we don't enforce trust domain outside of TLS and
// don't need to load the current one this way.
wantCertURI: connect.TestSpiffeIDServiceWithHost(t, "db", ""),
wantErr: false,
},
{
name: "Bad Type errors",
fields: fields{
Namespace: "default",
Name: "web",
Type: 123,
},
wantErr: true,
},
{
name: "Non-existent service errors",
fields: fields{
Namespace: "default",
Name: "foo",
Type: ConsulResolverTypeService,
},
wantErr: true,
},
{
name: "timeout errors",
fields: fields{
Namespace: "default",
Name: "web",
Type: ConsulResolverTypeService,
},
timeout: 1 * time.Nanosecond,
wantErr: true,
},
{
name: "prepared query by id",
fields: fields{
Name: queryId,
Type: ConsulResolverTypePreparedQuery,
},
// Want empty host since we don't enforce trust domain outside of TLS and
// don't need to load the current one this way.
wantCertURI: connect.TestSpiffeIDServiceWithHost(t, "web", ""),
wantErr: false,
addrs: proxyAddrs,
},
{
name: "prepared query by name",
fields: fields{
Name: "test-query",
Type: ConsulResolverTypePreparedQuery,
},
// Want empty host since we don't enforce trust domain outside of TLS and
// don't need to load the current one this way.
wantCertURI: connect.TestSpiffeIDServiceWithHost(t, "web", ""),
wantErr: false,
addrs: proxyAddrs,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cr := &ConsulResolver{
Client: client,
Namespace: tt.fields.Namespace,
Name: tt.fields.Name,
Type: tt.fields.Type,
Datacenter: tt.fields.Datacenter,
}
// WithCancel just to have a cancel func in scope to assign in the if
// clause.
ctx, cancel := context.WithCancel(context.Background())
if tt.timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, tt.timeout)
}
defer cancel()
gotAddr, gotCertURI, err := cr.Resolve(ctx)
if tt.wantErr {
require.NotNil(t, err)
return
}
require.Nil(t, err)
require.Equal(t, tt.wantCertURI, gotCertURI)
if len(tt.addrs) > 0 {
require.Contains(t, tt.addrs, gotAddr)
}
})
}
}
func TestConsulResolverFromAddrFunc(t *testing.T) {
// Don't need an actual instance since we don't do the service discovery but
// we do want to assert the client is pass through correctly.
client, err := api.NewClient(api.DefaultConfig())
require.NoError(t, err)
tests := []struct {
name string
addr string
want Resolver
wantErr string
}{
{
name: "service",
addr: "foo.service.consul",
want: &ConsulResolver{
Client: client,
Namespace: "default",
Name: "foo",
Type: ConsulResolverTypeService,
},
},
{
name: "query",
addr: "foo.query.consul",
want: &ConsulResolver{
Client: client,
Namespace: "default",
Name: "foo",
Type: ConsulResolverTypePreparedQuery,
},
},
{
name: "service with dc",
addr: "foo.service.dc2.consul",
want: &ConsulResolver{
Client: client,
Datacenter: "dc2",
Namespace: "default",
Name: "foo",
Type: ConsulResolverTypeService,
},
},
{
name: "query with dc",
addr: "foo.query.dc2.consul",
want: &ConsulResolver{
Client: client,
Datacenter: "dc2",
Namespace: "default",
Name: "foo",
Type: ConsulResolverTypePreparedQuery,
},
},
{
name: "invalid host:port",
addr: "%%%",
wantErr: "invalid Consul DNS domain",
},
{
name: "custom domain",
addr: "foo.service.my-consul.com",
wantErr: "invalid Consul DNS domain",
},
{
name: "unsupported query type",
addr: "foo.connect.consul",
wantErr: "unsupported Consul DNS domain",
},
{
name: "unsupported query type and datacenter",
addr: "foo.connect.dc1.consul",
wantErr: "unsupported Consul DNS domain",
},
{
name: "unsupported query type and datacenter",
addr: "foo.connect.dc1.consul",
wantErr: "unsupported Consul DNS domain",
},
{
name: "unsupported tag filter",
addr: "tag1.foo.service.consul",
wantErr: "unsupported Consul DNS domain",
},
{
name: "unsupported tag filter with DC",
addr: "tag1.foo.service.dc1.consul",
wantErr: "unsupported Consul DNS domain",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fn := ConsulResolverFromAddrFunc(client)
got, gotErr := fn(tt.addr)
if tt.wantErr != "" {
require.Error(t, gotErr)
require.Contains(t, gotErr.Error(), tt.wantErr)
} else {
require.NoError(t, gotErr)
require.Equal(t, tt.want, got)
}
})
}
}