consul/agent/grpc-internal/client_test.go
Dan Upton b9e525d689
grpc: rename public/private directories to external/internal (#13721)
Previously, public referred to gRPC services that are both exposed on
the dedicated gRPC port and have their definitions in the proto-public
directory (so were considered usable by 3rd parties). Whereas private
referred to services on the multiplexed server port that are only usable
by agents and other servers.

Now, we're splitting these definitions, such that external/internal
refers to the port and public/private refers to whether they can be used
by 3rd parties.

This is necessary because the peering replication API needs to be
exposed on the dedicated port, but is not (yet) suitable for use by 3rd
parties.
2022-07-13 16:33:48 +01:00

440 lines
13 KiB
Go

package internal
import (
"context"
"fmt"
"net"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/google/tcpproxy"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/grpc-internal/internal/testservice"
"github.com/hashicorp/consul/agent/grpc-internal/resolver"
"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/ipaddr"
"github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/consul/types"
)
// useTLSForDcAlwaysTrue tell GRPC to always return the TLS is enabled
func useTLSForDcAlwaysTrue(_ string) bool {
return true
}
func TestNewDialer_WithTLSWrapper(t *testing.T) {
lis, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(logError(t, lis.Close))
builder := resolver.NewServerResolverBuilder(newConfig(t))
builder.AddServer(types.AreaWAN, &metadata.Server{
Name: "server-1",
ID: "ID1",
Datacenter: "dc1",
Addr: lis.Addr(),
UseTLS: true,
})
var called bool
wrapper := func(_ string, conn net.Conn) (net.Conn, error) {
called = true
return conn, nil
}
dial := newDialer(
ClientConnPoolConfig{
Servers: builder,
TLSWrapper: wrapper,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
},
&gatewayResolverDep{},
)
ctx := context.Background()
conn, err := dial(ctx, resolver.DCPrefix("dc1", lis.Addr().String()))
require.NoError(t, err)
require.NoError(t, conn.Close())
require.True(t, called, "expected TLSWrapper to be called")
}
func TestNewDialer_WithALPNWrapper(t *testing.T) {
lis1, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(logError(t, lis1.Close))
lis2, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(logError(t, lis2.Close))
// Send all of the traffic to dc2's server
var p tcpproxy.Proxy
gwAddr := ipaddr.FormatAddressPort("127.0.0.1", freeport.GetOne(t))
p.AddRoute(gwAddr, tcpproxy.To(lis2.Addr().String()))
p.AddStopACMESearch(gwAddr)
require.NoError(t, p.Start())
defer func() {
p.Close()
p.Wait()
}()
builder := resolver.NewServerResolverBuilder(newConfig(t))
builder.AddServer(types.AreaWAN, &metadata.Server{
Name: "server-1",
ID: "ID1",
Datacenter: "dc1",
Addr: lis1.Addr(),
UseTLS: true,
})
builder.AddServer(types.AreaWAN, &metadata.Server{
Name: "server-2",
ID: "ID2",
Datacenter: "dc2",
Addr: lis2.Addr(),
UseTLS: true,
})
var calledTLS bool
wrapperTLS := func(_ string, conn net.Conn) (net.Conn, error) {
calledTLS = true
return conn, nil
}
var calledALPN bool
wrapperALPN := func(_, _, _ string, conn net.Conn) (net.Conn, error) {
calledALPN = true
return conn, nil
}
gwResolverDep := &gatewayResolverDep{
GatewayResolver: func(addr string) string {
return gwAddr
},
}
dial := newDialer(
ClientConnPoolConfig{
Servers: builder,
TLSWrapper: wrapperTLS,
ALPNWrapper: wrapperALPN,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
},
gwResolverDep,
)
ctx := context.Background()
conn, err := dial(ctx, resolver.DCPrefix("dc2", lis2.Addr().String()))
require.NoError(t, err)
require.NoError(t, conn.Close())
assert.False(t, calledTLS, "expected TLSWrapper not to be called")
assert.True(t, calledALPN, "expected ALPNWrapper to be called")
}
func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
// if this test is failing because of expired certificates
// use the procedure in test/CA-GENERATION.md
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
InternalRPC: tlsutil.ProtocolConfig{
VerifyIncoming: true,
CAFile: "../../test/hostname/CertAuth.crt",
CertFile: "../../test/hostname/Alice.crt",
KeyFile: "../../test/hostname/Alice.key",
VerifyOutgoing: true,
},
}, hclog.New(nil))
require.NoError(t, err)
srv := newSimpleTestServer(t, "server-1", "dc1", tlsConf)
md := srv.Metadata()
res.AddServer(types.AreaWAN, md)
t.Cleanup(srv.shutdown)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
TLSWrapper: TLSWrapper(tlsConf.OutgoingRPCWrapper()),
UseTLSForDC: tlsConf.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
client := testservice.NewSimpleClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
t.Cleanup(cancel)
resp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
require.Equal(t, "server-1", resp.ServerName)
require.True(t, atomic.LoadInt32(&srv.rpc.tlsConnEstablished) > 0)
require.True(t, atomic.LoadInt32(&srv.rpc.alpnConnEstablished) == 0)
}
func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T) {
// if this test is failing because of expired certificates
// use the procedure in test/CA-GENERATION.md
gwAddr := ipaddr.FormatAddressPort("127.0.0.1", freeport.GetOne(t))
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
InternalRPC: tlsutil.ProtocolConfig{
VerifyIncoming: true,
CAFile: "../../test/hostname/CertAuth.crt",
CertFile: "../../test/hostname/Bob.crt",
KeyFile: "../../test/hostname/Bob.key",
VerifyOutgoing: true,
VerifyServerHostname: true,
},
Domain: "consul",
NodeName: "bob",
}, hclog.New(nil))
require.NoError(t, err)
srv := newSimpleTestServer(t, "bob", "dc1", tlsConf)
// Send all of the traffic to dc1's server
var p tcpproxy.Proxy
p.AddRoute(gwAddr, tcpproxy.To(srv.addr.String()))
p.AddStopACMESearch(gwAddr)
require.NoError(t, p.Start())
defer func() {
p.Close()
p.Wait()
}()
md := srv.Metadata()
res.AddServer(types.AreaWAN, md)
t.Cleanup(srv.shutdown)
clientTLSConf, err := tlsutil.NewConfigurator(tlsutil.Config{
InternalRPC: tlsutil.ProtocolConfig{
VerifyIncoming: true,
CAFile: "../../test/hostname/CertAuth.crt",
CertFile: "../../test/hostname/Betty.crt",
KeyFile: "../../test/hostname/Betty.key",
VerifyOutgoing: true,
VerifyServerHostname: true,
},
Domain: "consul",
NodeName: "betty",
}, hclog.New(nil))
require.NoError(t, err)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
TLSWrapper: TLSWrapper(clientTLSConf.OutgoingRPCWrapper()),
ALPNWrapper: ALPNWrapper(clientTLSConf.OutgoingALPNRPCWrapper()),
UseTLSForDC: tlsConf.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc2",
})
pool.SetGatewayResolver(func(addr string) string {
return gwAddr
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
client := testservice.NewSimpleClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
t.Cleanup(cancel)
resp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
require.Equal(t, "bob", resp.ServerName)
require.True(t, atomic.LoadInt32(&srv.rpc.tlsConnEstablished) == 0)
require.True(t, atomic.LoadInt32(&srv.rpc.alpnConnEstablished) > 0)
}
func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
count := 4
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
for i := 0; i < count; i++ {
name := fmt.Sprintf("server-%d", i)
srv := newSimpleTestServer(t, name, "dc1", nil)
res.AddServer(types.AreaWAN, srv.Metadata())
t.Cleanup(srv.shutdown)
}
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
client := testservice.NewSimpleClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
t.Cleanup(cancel)
first, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
res.RemoveServer(types.AreaWAN, &metadata.Server{ID: first.ServerName, Datacenter: "dc1"})
resp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
require.NotEqual(t, resp.ServerName, first.ServerName)
}
func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) {
count := 3
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
var servers []testServer
for i := 0; i < count; i++ {
name := fmt.Sprintf("server-%d", i)
srv := newSimpleTestServer(t, name, "dc1", nil)
res.AddServer(types.AreaWAN, srv.Metadata())
servers = append(servers, srv)
t.Cleanup(srv.shutdown)
}
// Set the leader address to the first server.
srv0 := servers[0].Metadata()
res.UpdateLeaderAddr(srv0.Datacenter, srv0.Addr.String())
conn, err := pool.ClientConnLeader()
require.NoError(t, err)
client := testservice.NewSimpleClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
t.Cleanup(cancel)
first, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
require.Equal(t, first.ServerName, servers[0].name)
// Update the leader address and make another request.
srv1 := servers[1].Metadata()
res.UpdateLeaderAddr(srv1.Datacenter, srv1.Addr.String())
resp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
require.Equal(t, resp.ServerName, servers[1].name)
}
func newConfig(t *testing.T) resolver.Config {
n := t.Name()
s := strings.Replace(n, "/", "", -1)
s = strings.Replace(s, "_", "", -1)
return resolver.Config{Authority: strings.ToLower(s)}
}
func TestClientConnPool_IntegrationWithGRPCResolver_Rebalance(t *testing.T) {
count := 5
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
for i := 0; i < count; i++ {
name := fmt.Sprintf("server-%d", i)
srv := newSimpleTestServer(t, name, "dc1", nil)
res.AddServer(types.AreaWAN, srv.Metadata())
t.Cleanup(srv.shutdown)
}
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
client := testservice.NewSimpleClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
t.Cleanup(cancel)
first, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
t.Run("rebalance a different DC, does nothing", func(t *testing.T) {
res.NewRebalancer("dc-other")()
resp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
require.Equal(t, resp.ServerName, first.ServerName)
})
t.Run("rebalance the dc", func(t *testing.T) {
// Rebalance is random, but if we repeat it a few times it should give us a
// new server.
attempts := 100
for i := 0; i < attempts; i++ {
res.NewRebalancer("dc1")()
resp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
if resp.ServerName != first.ServerName {
return
}
}
t.Fatalf("server was not rebalanced after %v attempts", attempts)
})
}
func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {
dcs := []string{"dc1", "dc2", "dc3"}
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
for _, dc := range dcs {
name := "server-0-" + dc
srv := newSimpleTestServer(t, name, dc, nil)
res.AddServer(types.AreaWAN, srv.Metadata())
t.Cleanup(srv.shutdown)
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
t.Cleanup(cancel)
for _, dc := range dcs {
conn, err := pool.ClientConn(dc)
require.NoError(t, err)
client := testservice.NewSimpleClient(conn)
resp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
require.Equal(t, resp.Datacenter, dc)
}
}
func registerWithGRPC(t *testing.T, b *resolver.ServerResolverBuilder) {
resolver.Register(b)
t.Cleanup(func() {
resolver.Deregister(b.Authority())
})
}