mirror of https://github.com/status-im/consul.git
grpc: ensure that streaming gRPC requests work over mesh gateway based wan federation (#10838)
Fixes #10796
This commit is contained in:
parent 4993d877d9
commit 5b6d96d27d

@@ -0,0 +1,3 @@
```release-note:bug
grpc: ensure that streaming gRPC requests work over mesh gateway based wan federation
```

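For orientation before the hunks below: a minimal sketch of how the reworked gRPC connection pool is constructed and then pointed at mesh gateways, assembled from the changes in this commit. The `builder`, `tls`, and `pickGateway` values are placeholders standing in for the agent's real resolver builder, TLS configurator, and gateway locator.

```go
package example // illustrative sketch only

import (
	"github.com/hashicorp/consul/agent/grpc"
	"github.com/hashicorp/consul/agent/grpc/resolver"
	"github.com/hashicorp/consul/tlsutil"
)

// newPool mirrors the new ClientConnPoolConfig usage introduced in this diff.
func newPool(builder *resolver.ServerResolverBuilder, tls *tlsutil.Configurator, pickGateway func(dc string) string) *grpc.ClientConnPool {
	pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
		Servers:               builder,
		TLSWrapper:            grpc.TLSWrapper(tls.OutgoingRPCWrapper()),      // plain TYPE_BYTE+TLS dials
		ALPNWrapper:           grpc.ALPNWrapper(tls.OutgoingALPNRPCWrapper()), // ALPN+TLS dials via mesh gateways
		UseTLSForDC:           tls.UseTLS,
		DialingFromServer:     true,
		DialingFromDatacenter: "dc1",
	})
	// Wired later during server setup, once the gateway locator exists
	// (see the agent/consul server hunks below).
	pool.SetGatewayResolver(pickGateway)
	return pool
}
```
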
@@ -371,6 +371,9 @@ func (f fakeGRPCConnPool) ClientConnLeader() (*grpc.ClientConn, error) {
return nil, nil
}

func (f fakeGRPCConnPool) SetGatewayResolver(_ func(string) string) {
}

func TestAgent_ReconnectConfigWanDisabled(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")

@@ -4524,6 +4527,9 @@ LOOP:
}

// This is a mirror of a similar test in agent/consul/server_test.go
//
// TODO(rb): implement something similar to this as a full containerized test suite with proper
// isolation so requests can't "cheat" and bypass the mesh gateways
func TestAgent_JoinWAN_viaMeshGateway(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")

@@ -4771,6 +4777,9 @@ func TestAgent_JoinWAN_viaMeshGateway(t *testing.T) {
})

// Ensure we can do some trivial RPC in all directions.
//
// NOTE: we explicitly make streaming and non-streaming assertions here to
// verify both rpc and grpc codepaths.
agents := map[string]*TestAgent{"dc1": a1, "dc2": a2, "dc3": a3}
names := map[string]string{"dc1": "bob", "dc2": "betty", "dc3": "bonnie"}
for _, srcDC := range []string{"dc1", "dc2", "dc3"} {

@@ -4780,20 +4789,39 @@ func TestAgent_JoinWAN_viaMeshGateway(t *testing.T) {
continue
}
t.Run(srcDC+" to "+dstDC, func(t *testing.T) {
req, err := http.NewRequest("GET", "/v1/catalog/nodes?dc="+dstDC, nil)
require.NoError(t, err)
t.Run("normal-rpc", func(t *testing.T) {
req, err := http.NewRequest("GET", "/v1/catalog/nodes?dc="+dstDC, nil)
require.NoError(t, err)

resp := httptest.NewRecorder()
obj, err := a.srv.CatalogNodes(resp, req)
require.NoError(t, err)
require.NotNil(t, obj)
resp := httptest.NewRecorder()
obj, err := a.srv.CatalogNodes(resp, req)
require.NoError(t, err)
require.NotNil(t, obj)

nodes, ok := obj.(structs.Nodes)
require.True(t, ok)
require.Len(t, nodes, 1)
node := nodes[0]
require.Equal(t, dstDC, node.Datacenter)
require.Equal(t, names[dstDC], node.Node)
nodes, ok := obj.(structs.Nodes)
require.True(t, ok)
require.Len(t, nodes, 1)
node := nodes[0]
require.Equal(t, dstDC, node.Datacenter)
require.Equal(t, names[dstDC], node.Node)
})
t.Run("streaming-grpc", func(t *testing.T) {
req, err := http.NewRequest("GET", "/v1/health/service/consul?cached&dc="+dstDC, nil)
require.NoError(t, err)

resp := httptest.NewRecorder()
obj, err := a.srv.HealthServiceNodes(resp, req)
require.NoError(t, err)
require.NotNil(t, obj)

csns, ok := obj.(structs.CheckServiceNodes)
require.True(t, ok)
require.Len(t, csns, 1)

csn := csns[0]
require.Equal(t, dstDC, csn.Node.Datacenter)
require.Equal(t, names[dstDC], csn.Node.Node)
})
})
}
}

@@ -5,10 +5,17 @@ import (
"fmt"
"net"
"os"
"strings"
"sync"
"testing"
"time"

"github.com/hashicorp/go-hclog"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/serf/serf"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"

"github.com/hashicorp/consul/agent/grpc"
"github.com/hashicorp/consul/agent/grpc/resolver"
"github.com/hashicorp/consul/agent/pool"

@@ -20,11 +27,6 @@ import (
"github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/go-hclog"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/serf/serf"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
)

func testClientConfig(t *testing.T) (string, *Config) {

@@ -490,6 +492,13 @@ func newClient(t *testing.T, config *Config) *Client {
return client
}

func newTestResolverConfig(t *testing.T, suffix string) resolver.Config {
n := t.Name()
s := strings.Replace(n, "/", "", -1)
s = strings.Replace(s, "_", "", -1)
return resolver.Config{Authority: strings.ToLower(s) + "-" + suffix}
}

func newDefaultDeps(t *testing.T, c *Config) Deps {
t.Helper()

@@ -502,7 +511,7 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
tls, err := tlsutil.NewConfigurator(c.TLSConfig, logger)
require.NoError(t, err, "failed to create tls configuration")

builder := resolver.NewServerResolverBuilder(resolver.Config{Authority: c.NodeName})
builder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, c.NodeName+"-"+c.Datacenter))
r := router.NewRouter(logger, c.Datacenter, fmt.Sprintf("%s.%s", c.NodeName, c.Datacenter), builder)
resolver.Register(builder)

@@ -522,7 +531,13 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
Tokens: new(token.Store),
Router: r,
ConnPool: connPool,
GRPCConnPool: grpc.NewClientConnPool(builder, grpc.TLSWrapper(tls.OutgoingRPCWrapper()), tls.UseTLS),
GRPCConnPool: grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
Servers: builder,
TLSWrapper: grpc.TLSWrapper(tls.OutgoingRPCWrapper()),
UseTLSForDC: tls.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: c.Datacenter,
}),
LeaderForwarder: builder,
EnterpriseDeps: newDefaultDepsEnterprise(t, logger, c),
}

@@ -24,8 +24,10 @@ type Deps struct {
type GRPCClientConner interface {
ClientConn(datacenter string) (*grpc.ClientConn, error)
ClientConnLeader() (*grpc.ClientConn, error)
SetGatewayResolver(func(string) string)
}

type LeaderForwarder interface {
UpdateLeaderAddr(leaderAddr string)
// UpdateLeaderAddr updates the leader address in the local DC's resolver.
UpdateLeaderAddr(datacenter, addr string)
}

@@ -293,7 +293,7 @@ func (s *Server) handleNativeTLS(conn net.Conn) {
s.handleSnapshotConn(tlsConn)

case pool.ALPN_RPCGRPC:
s.grpcHandler.Handle(conn)
s.grpcHandler.Handle(tlsConn)

case pool.ALPN_WANGossipPacket:
if err := s.handleALPN_WANGossipPacketStream(tlsConn); err != nil && err != io.EOF {

@@ -373,7 +373,7 @@ func (s *Server) handleMultiplexV2(conn net.Conn) {
}
sub = peeked
switch first {
case pool.RPCGossip:
case byte(pool.RPCGossip):
buf := make([]byte, 1)
sub.Read(buf)
go func() {

@@ -460,7 +460,7 @@ func TestRPC_TLSHandshakeTimeout(t *testing.T) {

// Write TLS byte to avoid being closed by either the (outer) first byte
// timeout or the fact that server requires TLS
_, err = conn.Write([]byte{pool.RPCTLS})
_, err = conn.Write([]byte{byte(pool.RPCTLS)})
require.NoError(t, err)

// Wait for more than the timeout before we start a TLS handshake. This is

@@ -173,6 +173,9 @@ type Server struct {
// Connection pool to other consul servers
connPool *pool.ConnPool

// Connection pool to other consul servers using gRPC
grpcConnPool GRPCClientConner

// eventChLAN is used to receive events from the
// serf cluster in the datacenter
eventChLAN chan serf.Event

@@ -348,6 +351,7 @@ func NewServer(config *Config, flat Deps) (*Server, error) {
config: config,
tokens: flat.Tokens,
connPool: flat.ConnPool,
grpcConnPool: flat.GRPCConnPool,
eventChLAN: make(chan serf.Event, serfEventChSize),
eventChWAN: make(chan serf.Event, serfEventChSize),
logger: serverLogger,

@@ -377,6 +381,7 @@ func NewServer(config *Config, flat Deps) (*Server, error) {
s.config.PrimaryDatacenter,
)
s.connPool.GatewayResolver = s.gatewayLocator.PickGateway
s.grpcConnPool.SetGatewayResolver(s.gatewayLocator.PickGateway)
}

// Initialize enterprise specific server functionality

@@ -1461,7 +1466,7 @@ func (s *Server) trackLeaderChanges() {
continue
}

s.grpcLeaderForwarder.UpdateLeaderAddr(string(leaderObs.Leader))
s.grpcLeaderForwarder.UpdateLeaderAddr(s.config.Datacenter, string(leaderObs.Leader))
case <-s.shutdownCh:
s.raft.DeregisterObserver(observer)
return

@@ -25,6 +25,8 @@ import (
func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
t.Parallel()

// TODO(rb): add tests for the wanfed/alpn variations

_, conf1 := testServerConfig(t)
conf1.TLSConfig.VerifyIncoming = true
conf1.TLSConfig.VerifyOutgoing = true

@@ -60,7 +62,13 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {

// Start a Subscribe call to our streaming endpoint from the client.
{
pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS)
pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
Servers: builder,
TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()),
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)

@@ -91,8 +99,13 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {

// Start a Subscribe call to our streaming endpoint from the server's loopback client.
{

pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS)
pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
Servers: builder,
TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()),
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)

@@ -166,7 +179,13 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) {
// Subscribe calls should fail initially
joinLAN(t, client, server)

pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS)
pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
Servers: builder,
TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()),
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)

@@ -294,7 +313,13 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
}
}()

pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS)
pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
Servers: builder,
TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()),
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)

@@ -337,7 +362,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
}

func newClientWithGRPCResolver(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) {
builder := resolver.NewServerResolverBuilder(resolver.Config{Authority: t.Name()})
builder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, "client"))
resolver.Register(builder)
t.Cleanup(func() {
resolver.Deregister(builder.Authority())

@@ -1,6 +1,7 @@
package wanfed

import (
"context"
"encoding/binary"
"errors"
"fmt"

@@ -11,7 +12,6 @@ import (
"github.com/hashicorp/memberlist"

"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/tlsutil"
)

@@ -97,13 +97,8 @@ func (t *Transport) WriteToAddress(b []byte, addr memberlist.Address) (time.Time
}

if dc != t.datacenter {
gwAddr := t.gwResolver(dc)
if gwAddr == "" {
return time.Time{}, structs.ErrDCNotAvailable
}

dialFunc := func() (net.Conn, error) {
return t.dial(dc, node, pool.ALPN_WANGossipPacket, gwAddr)
return t.dial(dc, node, pool.ALPN_WANGossipPacket)
}
conn, err := t.pool.AcquireOrDial(addr.Name, dialFunc)
if err != nil {

@@ -136,42 +131,24 @@ func (t *Transport) DialAddressTimeout(addr memberlist.Address, timeout time.Dur
}

if dc != t.datacenter {
gwAddr := t.gwResolver(dc)
if gwAddr == "" {
return nil, structs.ErrDCNotAvailable
}

return t.dial(dc, node, pool.ALPN_WANGossipStream, gwAddr)
return t.dial(dc, node, pool.ALPN_WANGossipStream)
}

return t.IngestionAwareTransport.DialAddressTimeout(addr, timeout)
}

// NOTE: There is a close mirror of this method in agent/pool/pool.go:DialTimeoutWithRPCType
func (t *Transport) dial(dc, nodeName, nextProto, addr string) (net.Conn, error) {
wrapper := t.tlsConfigurator.OutgoingALPNRPCWrapper()
if wrapper == nil {
return nil, fmt.Errorf("wanfed: cannot dial via a mesh gateway when outgoing TLS is disabled")
}

dialer := &net.Dialer{Timeout: 10 * time.Second}

rawConn, err := dialer.Dial("tcp", addr)
if err != nil {
return nil, err
}

if tcp, ok := rawConn.(*net.TCPConn); ok {
_ = tcp.SetKeepAlive(true)
_ = tcp.SetNoDelay(true)
}

tlsConn, err := wrapper(dc, nodeName, nextProto, rawConn)
if err != nil {
return nil, err
}

return tlsConn, nil
func (t *Transport) dial(dc, nodeName, nextProto string) (net.Conn, error) {
conn, _, err := pool.DialRPCViaMeshGateway(
context.Background(),
dc,
nodeName,
nil, // TODO(rb): thread source address through here?
t.tlsConfigurator.OutgoingALPNRPCWrapper(),
nextProto,
true,
t.gwResolver,
)
return conn, err
}

// SplitNodeName splits a node name as it would be represented in

@@ -12,38 +12,93 @@ import (

"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/tlsutil"
)

// ClientConnPool creates and stores a connection for each datacenter.
type ClientConnPool struct {
dialer dialer
servers ServerLocator
conns map[string]*grpc.ClientConn
connsLock sync.Mutex
dialer dialer
servers ServerLocator
gwResolverDep gatewayResolverDep
conns map[string]*grpc.ClientConn
connsLock sync.Mutex
}

type ServerLocator interface {
// ServerForAddr is used to look up server metadata from an address.
ServerForAddr(addr string) (*metadata.Server, error)
// ServerForGlobalAddr returns server metadata for a server with the specified globally unique address.
ServerForGlobalAddr(globalAddr string) (*metadata.Server, error)

// Authority returns the target authority to use to dial the server. This is primarily
// needed for testing multiple agents in parallel, because gRPC requires the
// resolver to be registered globally.
Authority() string
}

// gatewayResolverDep is just a holder for a function pointer that can be
// updated lazily after the structs are instantiated (but before first use)
// and all structs with a reference to this struct will see the same update.
type gatewayResolverDep struct {
// GatewayResolver is a function that returns a suitable random mesh
// gateway address for dialing servers in a given DC. This is only
// needed if wan federation via mesh gateways is enabled.
GatewayResolver func(string) string
}

// TLSWrapper wraps a non-TLS connection and returns a connection with TLS
// enabled.
type TLSWrapper func(dc string, conn net.Conn) (net.Conn, error)

// ALPNWrapper is a function that is used to wrap a non-TLS connection and
// returns an appropriate TLS connection or error. This taks a datacenter and
// node name as argument to configure the desired SNI value and the desired
// next proto for configuring ALPN.
type ALPNWrapper func(dc, nodeName, alpnProto string, conn net.Conn) (net.Conn, error)

type dialer func(context.Context, string) (net.Conn, error)

// NewClientConnPool create new GRPC client pool to connect to servers using GRPC over RPC
func NewClientConnPool(servers ServerLocator, tls TLSWrapper, useTLSForDC func(dc string) bool) *ClientConnPool {
return &ClientConnPool{
dialer: newDialer(servers, tls, useTLSForDC),
servers: servers,
type ClientConnPoolConfig struct {
// Servers is a reference for how to figure out how to dial any server.
Servers ServerLocator

// SrcAddr is the source address for outgoing connections.
SrcAddr *net.TCPAddr

// TLSWrapper is the specifics of wrapping a socket when doing an TYPE_BYTE+TLS
// wrapped RPC request.
TLSWrapper TLSWrapper

// ALPNWrapper is the specifics of wrapping a socket when doing an ALPN+TLS
// wrapped RPC request (typically only for wan federation via mesh
// gateways).
ALPNWrapper ALPNWrapper

// UseTLSForDC is a function to determine if dialing a given datacenter
// should use TLS.
UseTLSForDC func(dc string) bool

// DialingFromServer should be set to true if this connection pool is owned
// by a consul server instance.
DialingFromServer bool

// DialingFromDatacenter is the datacenter of the consul agent using this
// pool.
DialingFromDatacenter string
}

// NewClientConnPool create new GRPC client pool to connect to servers using
// GRPC over RPC.
func NewClientConnPool(cfg ClientConnPoolConfig) *ClientConnPool {
c := &ClientConnPool{
servers: cfg.Servers,
conns: make(map[string]*grpc.ClientConn),
}
c.dialer = newDialer(cfg, &c.gwResolverDep)
return c
}

// SetGatewayResolver is only to be called during setup before the pool is used.
func (c *ClientConnPool) SetGatewayResolver(gatewayResolver func(string) string) {
c.gwResolverDep.GatewayResolver = gatewayResolver
}

// ClientConn returns a grpc.ClientConn for the datacenter. If there are no

@@ -102,22 +157,39 @@ func (c *ClientConnPool) dial(datacenter string, serverType string) (*grpc.Clien

// newDialer returns a gRPC dialer function that conditionally wraps the connection
// with TLS based on the Server.useTLS value.
func newDialer(servers ServerLocator, wrapper TLSWrapper, useTLSForDC func(dc string) bool) func(context.Context, string) (net.Conn, error) {
return func(ctx context.Context, addr string) (net.Conn, error) {
d := net.Dialer{}
conn, err := d.DialContext(ctx, "tcp", addr)
func newDialer(cfg ClientConnPoolConfig, gwResolverDep *gatewayResolverDep) func(context.Context, string) (net.Conn, error) {
return func(ctx context.Context, globalAddr string) (net.Conn, error) {
server, err := cfg.Servers.ServerForGlobalAddr(globalAddr)
if err != nil {
return nil, err
}

server, err := servers.ServerForAddr(addr)
if cfg.DialingFromServer &&
gwResolverDep.GatewayResolver != nil &&
cfg.ALPNWrapper != nil &&
server.Datacenter != cfg.DialingFromDatacenter {
// NOTE: TLS is required on this branch.
conn, _, err := pool.DialRPCViaMeshGateway(
ctx,
server.Datacenter,
server.ShortName,
cfg.SrcAddr,
tlsutil.ALPNWrapper(cfg.ALPNWrapper),
pool.ALPN_RPCGRPC,
cfg.DialingFromServer,
gwResolverDep.GatewayResolver,
)
return conn, err
}

d := net.Dialer{LocalAddr: cfg.SrcAddr, Timeout: pool.DefaultDialTimeout}
conn, err := d.DialContext(ctx, "tcp", server.Addr.String())
if err != nil {
conn.Close()
return nil, err
}

if server.UseTLS && useTLSForDC(server.Datacenter) {
if wrapper == nil {
if server.UseTLS && cfg.UseTLSForDC(server.Datacenter) {
if cfg.TLSWrapper == nil {
conn.Close()
return nil, fmt.Errorf("TLS enabled but got nil TLS wrapper")
}

@@ -129,7 +201,7 @@ func newDialer(servers ServerLocator, wrapper TLSWrapper, useTLSForDC func(dc st
}

// Wrap the connection in a TLS client
tlsConn, err := wrapper(server.Datacenter, conn)
tlsConn, err := cfg.TLSWrapper(server.Datacenter, conn)
if err != nil {
conn.Close()
return nil, err

@@ -137,7 +209,7 @@ func newDialer(servers ServerLocator, wrapper TLSWrapper, useTLSForDC func(dc st
conn = tlsConn
}

_, err = conn.Write([]byte{pool.RPCGRPC})
_, err = conn.Write([]byte{byte(pool.RPCGRPC)})
if err != nil {
conn.Close()
return nil, err

@@ -4,17 +4,22 @@ import (
"context"
"fmt"
"net"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"

"github.com/google/tcpproxy"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/hashicorp/consul/agent/grpc/internal/testservice"
"github.com/hashicorp/consul/agent/grpc/resolver"
"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/ipaddr"
"github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/tlsutil"
)

@@ -24,11 +29,14 @@ func useTLSForDcAlwaysTrue(_ string) bool {
}

func TestNewDialer_WithTLSWrapper(t *testing.T) {
lis, err := net.Listen("tcp", "127.0.0.1:0")
ports := freeport.MustTake(1)
defer freeport.Return(ports)

lis, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(ports[0])))
require.NoError(t, err)
t.Cleanup(logError(t, lis.Close))

builder := resolver.NewServerResolverBuilder(resolver.Config{})
builder := resolver.NewServerResolverBuilder(newConfig(t))
builder.AddServer(&metadata.Server{
Name: "server-1",
ID: "ID1",

@@ -42,19 +50,107 @@ func TestNewDialer_WithTLSWrapper(t *testing.T) {
called = true
return conn, nil
}
dial := newDialer(builder, wrapper, useTLSForDcAlwaysTrue)
dial := newDialer(
ClientConnPoolConfig{
Servers: builder,
TLSWrapper: wrapper,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
},
&gatewayResolverDep{},
)
ctx := context.Background()
conn, err := dial(ctx, lis.Addr().String())
conn, err := dial(ctx, resolver.DCPrefix("dc1", lis.Addr().String()))
require.NoError(t, err)
require.NoError(t, conn.Close())
require.True(t, called, "expected TLSWrapper to be called")
}

func TestNewDialer_WithALPNWrapper(t *testing.T) {
ports := freeport.MustTake(3)
defer freeport.Return(ports)

var (
s1addr = ipaddr.FormatAddressPort("127.0.0.1", ports[0])
s2addr = ipaddr.FormatAddressPort("127.0.0.1", ports[1])
gwAddr = ipaddr.FormatAddressPort("127.0.0.1", ports[2])
)

lis1, err := net.Listen("tcp", s1addr)
require.NoError(t, err)
t.Cleanup(logError(t, lis1.Close))

lis2, err := net.Listen("tcp", s2addr)
require.NoError(t, err)
t.Cleanup(logError(t, lis2.Close))

// Send all of the traffic to dc2's server
var p tcpproxy.Proxy
p.AddRoute(gwAddr, tcpproxy.To(s2addr))
p.AddStopACMESearch(gwAddr)
require.NoError(t, p.Start())
defer func() {
p.Close()
p.Wait()
}()

builder := resolver.NewServerResolverBuilder(newConfig(t))
builder.AddServer(&metadata.Server{
Name: "server-1",
ID: "ID1",
Datacenter: "dc1",
Addr: lis1.Addr(),
UseTLS: true,
})
builder.AddServer(&metadata.Server{
Name: "server-2",
ID: "ID2",
Datacenter: "dc2",
Addr: lis2.Addr(),
UseTLS: true,
})

var calledTLS bool
wrapperTLS := func(_ string, conn net.Conn) (net.Conn, error) {
calledTLS = true
return conn, nil
}
var calledALPN bool
wrapperALPN := func(_, _, _ string, conn net.Conn) (net.Conn, error) {
calledALPN = true
return conn, nil
}
gwResolverDep := &gatewayResolverDep{
GatewayResolver: func(addr string) string {
return gwAddr
},
}
dial := newDialer(
ClientConnPoolConfig{
Servers: builder,
TLSWrapper: wrapperTLS,
ALPNWrapper: wrapperALPN,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
},
gwResolverDep,
)

ctx := context.Background()
conn, err := dial(ctx, resolver.DCPrefix("dc2", lis2.Addr().String()))
require.NoError(t, err)
require.NoError(t, conn.Close())

assert.False(t, calledTLS, "expected TLSWrapper not to be called")
assert.True(t, calledALPN, "expected ALPNWrapper to be called")
}

func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)

srv := newTestServer(t, "server-1", "dc1")
tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
VerifyIncoming: true,
VerifyOutgoing: true,

@@ -63,12 +159,20 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
KeyFile: "../../test/hostname/Alice.key",
}, hclog.New(nil))
require.NoError(t, err)
srv.rpc.tlsConf = tlsConf

res.AddServer(srv.Metadata())
srv := newTestServer(t, "server-1", "dc1", tlsConf)

md := srv.Metadata()
res.AddServer(md)
t.Cleanup(srv.shutdown)

pool := NewClientConnPool(res, TLSWrapper(tlsConf.OutgoingRPCWrapper()), tlsConf.UseTLS)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
TLSWrapper: TLSWrapper(tlsConf.OutgoingRPCWrapper()),
UseTLSForDC: tlsConf.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})

conn, err := pool.ClientConn("dc1")
require.NoError(t, err)

@@ -81,17 +185,98 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "server-1", resp.ServerName)
require.True(t, atomic.LoadInt32(&srv.rpc.tlsConnEstablished) > 0)
require.True(t, atomic.LoadInt32(&srv.rpc.alpnConnEstablished) == 0)
}

func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T) {
ports := freeport.MustTake(1)
defer freeport.Return(ports)

gwAddr := ipaddr.FormatAddressPort("127.0.0.1", ports[0])

res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)

tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
VerifyIncoming: true,
VerifyOutgoing: true,
VerifyServerHostname: true,
CAFile: "../../test/hostname/CertAuth.crt",
CertFile: "../../test/hostname/Bob.crt",
KeyFile: "../../test/hostname/Bob.key",
Domain: "consul",
NodeName: "bob",
}, hclog.New(nil))
require.NoError(t, err)

srv := newTestServer(t, "bob", "dc1", tlsConf)

// Send all of the traffic to dc1's server
var p tcpproxy.Proxy
p.AddRoute(gwAddr, tcpproxy.To(srv.addr.String()))
p.AddStopACMESearch(gwAddr)
require.NoError(t, p.Start())
defer func() {
p.Close()
p.Wait()
}()

md := srv.Metadata()
res.AddServer(md)
t.Cleanup(srv.shutdown)

clientTLSConf, err := tlsutil.NewConfigurator(tlsutil.Config{
VerifyIncoming: true,
VerifyOutgoing: true,
VerifyServerHostname: true,
CAFile: "../../test/hostname/CertAuth.crt",
CertFile: "../../test/hostname/Betty.crt",
KeyFile: "../../test/hostname/Betty.key",
Domain: "consul",
NodeName: "betty",
}, hclog.New(nil))
require.NoError(t, err)

pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
TLSWrapper: TLSWrapper(clientTLSConf.OutgoingRPCWrapper()),
ALPNWrapper: ALPNWrapper(clientTLSConf.OutgoingALPNRPCWrapper()),
UseTLSForDC: tlsConf.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc2",
})
pool.SetGatewayResolver(func(addr string) string {
return gwAddr
})

conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
client := testservice.NewSimpleClient(conn)

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
t.Cleanup(cancel)

resp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
require.Equal(t, "bob", resp.ServerName)
require.True(t, atomic.LoadInt32(&srv.rpc.tlsConnEstablished) == 0)
require.True(t, atomic.LoadInt32(&srv.rpc.alpnConnEstablished) > 0)
}

func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
count := 4
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})

for i := 0; i < count; i++ {
name := fmt.Sprintf("server-%d", i)
srv := newTestServer(t, name, "dc1")
srv := newTestServer(t, name, "dc1", nil)
res.AddServer(srv.Metadata())
t.Cleanup(srv.shutdown)
}

@@ -115,22 +300,27 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {

func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) {
count := 3
conf := newConfig(t)
res := resolver.NewServerResolverBuilder(conf)
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})

var servers []testServer
for i := 0; i < count; i++ {
name := fmt.Sprintf("server-%d", i)
srv := newTestServer(t, name, "dc1")
srv := newTestServer(t, name, "dc1", nil)
res.AddServer(srv.Metadata())
servers = append(servers, srv)
t.Cleanup(srv.shutdown)
}

// Set the leader address to the first server.
res.UpdateLeaderAddr(servers[0].addr.String())
srv0 := servers[0].Metadata()
res.UpdateLeaderAddr(srv0.Datacenter, srv0.Addr.String())

conn, err := pool.ClientConnLeader()
require.NoError(t, err)

@@ -144,7 +334,8 @@ func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) {
require.Equal(t, first.ServerName, servers[0].name)

// Update the leader address and make another request.
res.UpdateLeaderAddr(servers[1].addr.String())
srv1 := servers[1].Metadata()
res.UpdateLeaderAddr(srv1.Datacenter, srv1.Addr.String())

resp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)

@@ -162,11 +353,16 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Rebalance(t *testing.T) {
count := 5
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})

for i := 0; i < count; i++ {
name := fmt.Sprintf("server-%d", i)
srv := newTestServer(t, name, "dc1")
srv := newTestServer(t, name, "dc1", nil)
res.AddServer(srv.Metadata())
t.Cleanup(srv.shutdown)
}

@@ -211,11 +407,16 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {

res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})

for _, dc := range dcs {
name := "server-0-" + dc
srv := newTestServer(t, name, dc)
srv := newTestServer(t, name, dc, nil)
res.AddServer(srv.Metadata())
t.Cleanup(srv.shutdown)
}

@@ -67,17 +67,17 @@ func (s *ServerResolverBuilder) NewRebalancer(dc string) func() {
}
}

// ServerForAddr returns server metadata for a server with the specified address.
func (s *ServerResolverBuilder) ServerForAddr(addr string) (*metadata.Server, error) {
// ServerForGlobalAddr returns server metadata for a server with the specified globally unique address.
func (s *ServerResolverBuilder) ServerForGlobalAddr(globalAddr string) (*metadata.Server, error) {
s.lock.RLock()
defer s.lock.RUnlock()

for _, server := range s.servers {
if server.Addr.String() == addr {
if DCPrefix(server.Datacenter, server.Addr.String()) == globalAddr {
return server, nil
}
}
return nil, fmt.Errorf("failed to find Consul server for address %q", addr)
return nil, fmt.Errorf("failed to find Consul server for global address %q", globalAddr)
}

// Build returns a new serverResolver for the given ClientConn. The resolver

@@ -161,6 +161,12 @@ func uniqueID(server *metadata.Server) string {
return server.Datacenter + "-" + server.ID
}

// DCPrefix prefixes the given string with a datacenter for use in
// disambiguation.
func DCPrefix(datacenter, suffix string) string {
return datacenter + "-" + suffix
}

// RemoveServer updates the resolvers' states with the given server removed.
func (s *ServerResolverBuilder) RemoveServer(server *metadata.Server) {
s.lock.Lock()

@@ -186,7 +192,8 @@ func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address {
}

addrs = append(addrs, resolver.Address{
Addr: server.Addr.String(),
// NOTE: the address persisted here is only dialable using our custom dialer
Addr: DCPrefix(server.Datacenter, server.Addr.String()),
Type: resolver.Backend,
ServerName: server.Name,
})

@@ -195,11 +202,11 @@ func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address {
}

// UpdateLeaderAddr updates the leader address in the local DC's resolver.
func (s *ServerResolverBuilder) UpdateLeaderAddr(leaderAddr string) {
func (s *ServerResolverBuilder) UpdateLeaderAddr(datacenter, addr string) {
s.lock.Lock()
defer s.lock.Unlock()

s.leaderResolver.addr = leaderAddr
s.leaderResolver.globalAddr = DCPrefix(datacenter, addr)
s.leaderResolver.updateClientConn()
}

@@ -262,7 +269,7 @@ func (r *serverResolver) Close() {
func (*serverResolver) ResolveNow(resolver.ResolveNowOption) {}

type leaderResolver struct {
addr string
globalAddr string
clientConn resolver.ClientConn
}

@@ -271,12 +278,13 @@ func (l leaderResolver) ResolveNow(resolver.ResolveNowOption) {}
func (l leaderResolver) Close() {}

func (l leaderResolver) updateClientConn() {
if l.addr == "" || l.clientConn == nil {
if l.globalAddr == "" || l.clientConn == nil {
return
}
addrs := []resolver.Address{
{
Addr: l.addr,
// NOTE: the address persisted here is only dialable using our custom dialer
Addr: l.globalAddr,
Type: resolver.Backend,
ServerName: "leader",
},

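The resolver changes above hinge on one small convention: every address published to gRPC is prefixed with its datacenter, and the dialer resolves it back through `ServerForGlobalAddr`. A rough sketch of that round trip, with invented address values and a `builder` assumed to already be populated via `AddServer`:

```go
package example // illustrative sketch only

import "github.com/hashicorp/consul/agent/grpc/resolver"

// crossDC reports whether a resolver-published "global" address refers to a
// server outside the local datacenter; the dc and address are placeholders.
func crossDC(builder *resolver.ServerResolverBuilder, localDC string) (bool, error) {
	// DCPrefix simply joins datacenter and address with a dash:
	//   resolver.DCPrefix("dc2", "10.1.2.3:8300") == "dc2-10.1.2.3:8300"
	// getDCAddrs and the leader resolver now publish these composite strings.
	globalAddr := resolver.DCPrefix("dc2", "10.1.2.3:8300")

	// The pool's custom dialer resolves the composite key back to server
	// metadata, which is how it learns the target datacenter.
	server, err := builder.ServerForGlobalAddr(globalAddr)
	if err != nil {
		return false, err
	}
	return server.Datacenter != localDC, nil
}
```
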
@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
"strconv"
"sync/atomic"
"testing"
"time"

@@ -17,6 +18,7 @@ import (
"github.com/hashicorp/consul/agent/grpc/internal/testservice"
"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/tlsutil"
)

@@ -31,22 +33,29 @@ type testServer struct {
func (s testServer) Metadata() *metadata.Server {
return &metadata.Server{
ID: s.name,
Name: s.name + "." + s.dc,
ShortName: s.name,
Datacenter: s.dc,
Addr: s.addr,
UseTLS: s.rpc.tlsConf != nil,
}
}

func newTestServer(t *testing.T, name string, dc string) testServer {
func newTestServer(t *testing.T, name string, dc string, tlsConf *tlsutil.Configurator) testServer {
addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")}
handler := NewHandler(addr, func(server *grpc.Server) {
testservice.RegisterSimpleServer(server, &simple{name: name, dc: dc})
})

lis, err := net.Listen("tcp", "127.0.0.1:0")
ports := freeport.MustTake(1)
t.Cleanup(func() {
freeport.Return(ports)
})

lis, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(ports[0])))
require.NoError(t, err)

rpc := &fakeRPCListener{t: t, handler: handler}
rpc := &fakeRPCListener{t: t, handler: handler, tlsConf: tlsConf}

g := errgroup.Group{}
g.Go(func() error {

@@ -107,11 +116,12 @@ func (s *simple) Something(_ context.Context, _ *testservice.Req) (*testservice.
// For now, since this logic is in agent/consul, we can't easily use Server.listen
// so we fake it.
type fakeRPCListener struct {
t *testing.T
handler *Handler
shutdown bool
tlsConf *tlsutil.Configurator
tlsConnEstablished int32
t *testing.T
handler *Handler
shutdown bool
tlsConf *tlsutil.Configurator
tlsConnEstablished int32
alpnConnEstablished int32
}

func (f *fakeRPCListener) listen(listener net.Listener) error {

@@ -129,6 +139,26 @@ func (f *fakeRPCListener) listen(listener net.Listener) error {
}

func (f *fakeRPCListener) handleConn(conn net.Conn) {
if f.tlsConf != nil && f.tlsConf.MutualTLSCapable() {
// See if actually this is native TLS multiplexed onto the old
// "type-byte" system.

peekedConn, nativeTLS, err := pool.PeekForTLS(conn)
if err != nil {
if err != io.EOF {
fmt.Printf("ERROR: failed to read first byte: %v\n", err)
}
conn.Close()
return
}

if nativeTLS {
f.handleNativeTLSConn(peekedConn)
return
}
conn = peekedConn
}

buf := make([]byte, 1)

if _, err := conn.Read(buf); err != nil {

@@ -166,3 +196,32 @@ func (f *fakeRPCListener) handleConn(conn net.Conn) {
conn.Close()
}
}

func (f *fakeRPCListener) handleNativeTLSConn(conn net.Conn) {
tlscfg := f.tlsConf.IncomingALPNRPCConfig(pool.RPCNextProtos)
tlsConn := tls.Server(conn, tlscfg)

// Force the handshake to conclude.
if err := tlsConn.Handshake(); err != nil {
fmt.Printf("ERROR: TLS handshake failed: %v", err)
conn.Close()
return
}

conn.SetReadDeadline(time.Time{})

var (
cs = tlsConn.ConnectionState()
nextProto = cs.NegotiatedProtocol
)

switch nextProto {
case pool.ALPN_RPCGRPC:
atomic.AddInt32(&f.alpnConnEstablished, 1)
f.handler.Handle(tlsConn)

default:
fmt.Printf("ERROR: discarding RPC for unknown negotiated protocol %q\n", nextProto)
conn.Close()
}
}

@@ -20,6 +20,8 @@ func (t RPCType) ALPNString() string {
return ALPN_RPCGossip
case RPCTLSInsecure:
return "" // unsupported
case RPCGRPC:
return ALPN_RPCGRPC
default:
return "" // unsupported
}

@@ -28,19 +30,19 @@ func (t RPCType) ALPNString() string {
const (
// keep numbers unique.
RPCConsul RPCType = 0
RPCRaft = 1
RPCMultiplex = 2 // Old Muxado byte, no longer supported.
RPCTLS = 3
RPCMultiplexV2 = 4
RPCSnapshot = 5
RPCGossip = 6
RPCRaft RPCType = 1
RPCMultiplex RPCType = 2 // Old Muxado byte, no longer supported.
RPCTLS RPCType = 3
RPCMultiplexV2 RPCType = 4
RPCSnapshot RPCType = 5
RPCGossip RPCType = 6
// RPCTLSInsecure is used to flag RPC calls that require verify
// incoming to be disabled, even when it is turned on in the
// configuration. At the time of writing there is only AutoEncrypt.Sign
// that is supported and it might be the only one there
// ever is.
RPCTLSInsecure = 7
RPCGRPC = 8
RPCTLSInsecure RPCType = 7
RPCGRPC RPCType = 8

// RPCMaxTypeValue is the maximum rpc type byte value currently used for the
// various protocols riding over our "rpc" port.

@@ -79,6 +81,7 @@ var RPCNextProtos = []string{
ALPN_RPCMultiplexV2,
ALPN_RPCSnapshot,
ALPN_RPCGossip,
ALPN_RPCGRPC,
ALPN_WANGossipPacket,
ALPN_WANGossipStream,
}

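Since the RPC type constants above are now explicitly typed as `RPCType` rather than untyped integers, call sites that write the leading type byte need an explicit conversion, which is why hunks elsewhere in this commit switch to `byte(pool.RPCGRPC)` and friends. A small sketch of the pattern:

```go
package example // illustrative sketch only

import (
	"net"

	"github.com/hashicorp/consul/agent/pool"
)

// writeGRPCTypeByte shows the pattern the diff converges on: with the RPC
// type constants declared as pool.RPCType, writing the connection's first
// byte needs an explicit byte() conversion.
func writeGRPCTypeByte(conn net.Conn) error {
	if _, err := conn.Write([]byte{byte(pool.RPCGRPC)}); err != nil {
		conn.Close()
		return err
	}
	return nil
}

// The same typed constant also maps onto an ALPN protocol name, which is what
// mesh-gateway dials negotiate instead of sending a leading type byte:
var grpcNextProto = pool.RPCGRPC.ALPNString() // == pool.ALPN_RPCGRPC
```
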
@@ -10,8 +10,9 @@ import (
"testing"
"time"

"github.com/hashicorp/consul/tlsutil"
"github.com/stretchr/testify/require"

"github.com/hashicorp/consul/tlsutil"
)

func TestPeekForTLS_not_TLS(t *testing.T) {

@@ -30,6 +31,7 @@ func TestPeekForTLS_not_TLS(t *testing.T) {
RPCSnapshot,
RPCGossip,
RPCTLSInsecure,
RPCGRPC,
} {
cases = append(cases, testcase{
name: fmt.Sprintf("tcp rpc type byte %d", rpcType),

@@ -76,6 +78,7 @@ func TestPeekForTLS_actual_TLS(t *testing.T) {
RPCSnapshot,
RPCGossip,
RPCTLSInsecure,
RPCGRPC,
} {
cases = append(cases, testcase{
name: fmt.Sprintf("tcp rpc type byte %d", rpcType),

@@ -2,6 +2,7 @@ package pool

import (
"container/list"
"context"
"crypto/tls"
"fmt"
"log"

@@ -11,14 +12,15 @@ import (
"sync/atomic"
"time"

msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/yamux"

"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/tlsutil"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/yamux"
)

const defaultDialTimeout = 10 * time.Second
const DefaultDialTimeout = 10 * time.Second

// muxSession is used to provide an interface for a stream multiplexer.
type muxSession interface {

@@ -291,21 +293,24 @@ func (p *ConnPool) DialTimeout(
) (net.Conn, HalfCloser, error) {
p.once.Do(p.init)

if p.Server && p.GatewayResolver != nil && p.TLSConfigurator != nil && dc != p.Datacenter {
if p.Server &&
p.GatewayResolver != nil &&
p.TLSConfigurator != nil &&
dc != p.Datacenter {
// NOTE: TLS is required on this branch.
return DialTimeoutWithRPCTypeViaMeshGateway(
nextProto := actualRPCType.ALPNString()
if nextProto == "" {
return nil, nil, fmt.Errorf("rpc type %d cannot be routed through a mesh gateway", actualRPCType)
}
return DialRPCViaMeshGateway(
context.Background(),
dc,
nodeName,
addr,
p.SrcAddr,
p.TLSConfigurator.OutgoingALPNRPCWrapper(),
actualRPCType,
RPCTLS,
// gateway stuff
nextProto,
p.Server,
p.TLSConfigurator,
p.GatewayResolver,
p.Datacenter,
)
}

@@ -319,7 +324,7 @@ func (p *ConnPool) dial(
tlsRPCType RPCType,
) (net.Conn, HalfCloser, error) {
// Try to dial the conn
d := &net.Dialer{LocalAddr: p.SrcAddr, Timeout: defaultDialTimeout}
d := &net.Dialer{LocalAddr: p.SrcAddr, Timeout: DefaultDialTimeout}
conn, err := d.Dial("tcp", addr.String())
if err != nil {
return nil, nil, err

@@ -372,62 +377,49 @@ func (p *ConnPool) dial(
return conn, hc, nil
}

// DialTimeoutWithRPCTypeViaMeshGateway dials the destination node and sets up
// the connection to be the correct RPC type using ALPN. This currently is
// exclusively used to dial other servers in foreign datacenters via mesh
// gateways.
//
// NOTE: There is a close mirror of this method in agent/consul/wanfed/wanfed.go:dial
func DialTimeoutWithRPCTypeViaMeshGateway(
dc string,
nodeName string,
addr net.Addr,
src *net.TCPAddr,
wrapper tlsutil.ALPNWrapper,
actualRPCType RPCType,
tlsRPCType RPCType,
// gateway stuff
// DialRPCViaMeshGateway dials the destination node and sets up the connection
// to be the correct RPC type using ALPN. This currently is exclusively used to
// dial other servers in foreign datacenters via mesh gateways.
func DialRPCViaMeshGateway(
ctx context.Context,
dc string, // (metadata.Server).Datacenter
nodeName string, // (metadata.Server).ShortName
srcAddr *net.TCPAddr,
alpnWrapper tlsutil.ALPNWrapper,
nextProto string,
dialingFromServer bool,
tlsConfigurator *tlsutil.Configurator,
gatewayResolver func(string) string,
thisDatacenter string,
) (net.Conn, HalfCloser, error) {
if !dialingFromServer {
return nil, nil, fmt.Errorf("must dial via mesh gateways from a server agent")
} else if gatewayResolver == nil {
return nil, nil, fmt.Errorf("gatewayResolver is nil")
} else if tlsConfigurator == nil {
return nil, nil, fmt.Errorf("tlsConfigurator is nil")
} else if dc == thisDatacenter {
return nil, nil, fmt.Errorf("cannot dial servers in the same datacenter via a mesh gateway")
} else if wrapper == nil {
} else if alpnWrapper == nil {
return nil, nil, fmt.Errorf("cannot dial via a mesh gateway when outgoing TLS is disabled")
}

nextProto := actualRPCType.ALPNString()
if nextProto == "" {
return nil, nil, fmt.Errorf("rpc type %d cannot be routed through a mesh gateway", actualRPCType)
}

gwAddr := gatewayResolver(dc)
if gwAddr == "" {
return nil, nil, structs.ErrDCNotAvailable
}

dialer := &net.Dialer{LocalAddr: src, Timeout: defaultDialTimeout}
dialer := &net.Dialer{LocalAddr: srcAddr, Timeout: DefaultDialTimeout}

rawConn, err := dialer.Dial("tcp", gwAddr)
rawConn, err := dialer.DialContext(ctx, "tcp", gwAddr)
if err != nil {
return nil, nil, err
}

if tcp, ok := rawConn.(*net.TCPConn); ok {
_ = tcp.SetKeepAlive(true)
_ = tcp.SetNoDelay(true)
if nextProto != ALPN_RPCGRPC {
// agent/grpc/client.go:dial() handles this in another way for gRPC
if tcp, ok := rawConn.(*net.TCPConn); ok {
_ = tcp.SetKeepAlive(true)
_ = tcp.SetNoDelay(true)
}
}

// NOTE: now we wrap the connection in a TLS client.
tlsConn, err := wrapper(dc, nodeName, nextProto, rawConn)
tlsConn, err := alpnWrapper(dc, nodeName, nextProto, rawConn)
if err != nil {
return nil, nil, err
}

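Both the wanfed transport and the gRPC dialer earlier in this commit now funnel through the renamed `pool.DialRPCViaMeshGateway` helper shown above. A minimal sketch of a call, assuming a TLS configurator and a gateway-locator callback supplied by the caller; the datacenter and node values are placeholders:

```go
package example // illustrative sketch only

import (
	"context"
	"net"

	"github.com/hashicorp/consul/agent/pool"
	"github.com/hashicorp/consul/tlsutil"
)

// dialRemoteServer is a rough sketch of calling the renamed helper.
func dialRemoteServer(tlsConfigurator *tlsutil.Configurator, pickGateway func(dc string) string) (net.Conn, error) {
	conn, _, err := pool.DialRPCViaMeshGateway(
		context.Background(),
		"dc2",      // (metadata.Server).Datacenter of the target
		"server-2", // (metadata.Server).ShortName of the target
		nil,        // optional source *net.TCPAddr
		tlsConfigurator.OutgoingALPNRPCWrapper(),
		pool.ALPN_RPCGRPC, // next-proto negotiated across the gateway hop
		true,              // only server agents may dial via mesh gateways
		pickGateway,       // returns a mesh gateway address for the target DC
	)
	// structs.ErrDCNotAvailable is returned when no gateway is known for the DC.
	return conn, err
}
```
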
@@ -106,9 +106,22 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer) (BaseDeps, error)
d.ViewStore = submatview.NewStore(d.Logger.Named("viewstore"))
d.ConnPool = newConnPool(cfg, d.Logger, d.TLSConfigurator)

builder := resolver.NewServerResolverBuilder(resolver.Config{})
builder := resolver.NewServerResolverBuilder(resolver.Config{
// Set the authority to something sufficiently unique so any usage in
// tests would be self-isolating in the global resolver map, while also
// not incurring a huge penalty for non-test code.
Authority: cfg.Datacenter + "." + string(cfg.NodeID),
})
resolver.Register(builder)
d.GRPCConnPool = grpc.NewClientConnPool(builder, grpc.TLSWrapper(d.TLSConfigurator.OutgoingRPCWrapper()), d.TLSConfigurator.UseTLS)
d.GRPCConnPool = grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
Servers: builder,
SrcAddr: d.ConnPool.SrcAddr,
TLSWrapper: grpc.TLSWrapper(d.TLSConfigurator.OutgoingRPCWrapper()),
ALPNWrapper: grpc.ALPNWrapper(d.TLSConfigurator.OutgoingALPNRPCWrapper()),
UseTLSForDC: d.TLSConfigurator.UseTLS,
DialingFromServer: cfg.ServerMode,
DialingFromDatacenter: cfg.Datacenter,
})
d.LeaderForwarder = builder

d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), builder)