grpc: fix data race in balancer registration (#16229)

Registering gRPC balancers is thread-unsafe because they are stored in a
global map variable that is accessed without holding a lock. Therefore,
it's expected that balancers are registered _once_ at the beginning of
your program (e.g. in a package `init` function) and certainly not after
you've started dialing connections, etc.

> NOTE: this function must only be called during initialization time
> (i.e. in an init() function), and is not thread-safe.

While this is fine for us in production, it's challenging for tests that
spin up multiple agents in-memory. We currently register a balancer per-
agent which holds agent-specific state that cannot safely be shared.

This commit introduces our own registry that _is_ thread-safe, and
implements the Builder interface such that we can call gRPC's `Register`
method once, on start-up. It uses the same pattern as our resolver
registry where we use the dial target's host (aka "authority"), which is
unique per-agent, to determine which builder to use.
This commit is contained in:
Dan Upton 2023-02-28 10:18:38 +00:00 committed by GitHub
parent 3cbbd63ba2
commit 73b9b407ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 179 additions and 104 deletions

View File

@ -1599,10 +1599,7 @@ func (a *Agent) ShutdownAgent() error {
a.stopLicenseManager() a.stopLicenseManager()
// this would be cancelled anyways (by the closing of the shutdown ch) but a.baseDeps.Close()
// this should help them to be stopped more quickly
a.baseDeps.AutoConfig.Stop()
a.baseDeps.MetricsConfig.Cancel()
a.stateLock.Lock() a.stateLock.Lock()
defer a.stateLock.Unlock() defer a.stateLock.Unlock()

View File

@ -522,9 +522,13 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
resolverBuilder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, c.NodeName+"-"+c.Datacenter)) resolverBuilder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, c.NodeName+"-"+c.Datacenter))
resolver.Register(resolverBuilder) resolver.Register(resolverBuilder)
t.Cleanup(func() {
resolver.Deregister(resolverBuilder.Authority())
})
balancerBuilder := balancer.NewBuilder(resolverBuilder.Authority(), testutil.Logger(t)) balancerBuilder := balancer.NewBuilder(resolverBuilder.Authority(), testutil.Logger(t))
balancerBuilder.Register() balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)
r := router.NewRouter( r := router.NewRouter(
logger, logger,
@ -559,7 +563,6 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
UseTLSForDC: tls.UseTLS, UseTLSForDC: tls.UseTLS,
DialingFromServer: true, DialingFromServer: true,
DialingFromDatacenter: c.Datacenter, DialingFromDatacenter: c.Datacenter,
BalancerBuilder: balancerBuilder,
}), }),
LeaderForwarder: resolverBuilder, LeaderForwarder: resolverBuilder,
NewRequestRecorderFunc: middleware.NewRequestRecorder, NewRequestRecorderFunc: middleware.NewRequestRecorder,

View File

@ -1165,7 +1165,7 @@ func TestRPC_LocalTokenStrippedOnForward_GRPC(t *testing.T) {
var conn *grpc.ClientConn var conn *grpc.ClientConn
{ {
client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, func(c *Config) { client, resolverBuilder := newClientWithGRPCPlumbing(t, func(c *Config) {
c.Datacenter = "dc2" c.Datacenter = "dc2"
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
c.RPCConfig.EnableStreaming = true c.RPCConfig.EnableStreaming = true
@ -1177,7 +1177,6 @@ func TestRPC_LocalTokenStrippedOnForward_GRPC(t *testing.T) {
Servers: resolverBuilder, Servers: resolverBuilder,
DialingFromServer: false, DialingFromServer: false,
DialingFromDatacenter: "dc2", DialingFromDatacenter: "dc2",
BalancerBuilder: balancerBuilder,
}) })
conn, err = pool.ClientConn("dc2") conn, err = pool.ClientConn("dc2")

View File

@ -39,7 +39,7 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer server.Shutdown() defer server.Shutdown()
client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing) client, resolverBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
// Try to join // Try to join
testrpc.WaitForLeader(t, server.RPC, "dc1") testrpc.WaitForLeader(t, server.RPC, "dc1")
@ -71,7 +71,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
UseTLSForDC: client.tlsConfigurator.UseTLS, UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true, DialingFromServer: true,
DialingFromDatacenter: "dc1", DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder,
}) })
conn, err := pool.ClientConn("dc1") conn, err := pool.ClientConn("dc1")
require.NoError(t, err) require.NoError(t, err)
@ -116,7 +115,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
UseTLSForDC: client.tlsConfigurator.UseTLS, UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true, DialingFromServer: true,
DialingFromDatacenter: "dc1", DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder,
}) })
conn, err := pool.ClientConn("dc1") conn, err := pool.ClientConn("dc1")
require.NoError(t, err) require.NoError(t, err)
@ -191,7 +189,7 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) {
defer server.Shutdown() defer server.Shutdown()
// Set up a client with valid certs and verify_outgoing = true // Set up a client with valid certs and verify_outgoing = true
client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing) client, resolverBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
testrpc.WaitForLeader(t, server.RPC, "dc1") testrpc.WaitForLeader(t, server.RPC, "dc1")
@ -204,7 +202,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) {
UseTLSForDC: client.tlsConfigurator.UseTLS, UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true, DialingFromServer: true,
DialingFromDatacenter: "dc1", DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder,
}) })
conn, err := pool.ClientConn("dc1") conn, err := pool.ClientConn("dc1")
require.NoError(t, err) require.NoError(t, err)
@ -284,7 +281,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
codec := rpcClient(t, server) codec := rpcClient(t, server)
defer codec.Close() defer codec.Close()
client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t) client, resolverBuilder := newClientWithGRPCPlumbing(t)
// Try to join // Try to join
testrpc.WaitForLeader(t, server.RPC, "dc1") testrpc.WaitForLeader(t, server.RPC, "dc1")
@ -346,7 +343,6 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
UseTLSForDC: client.tlsConfigurator.UseTLS, UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true, DialingFromServer: true,
DialingFromDatacenter: "dc1", DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder,
}) })
conn, err := pool.ClientConn("dc1") conn, err := pool.ClientConn("dc1")
require.NoError(t, err) require.NoError(t, err)
@ -376,7 +372,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
"at least some of the subscribers should have received non-snapshot updates") "at least some of the subscribers should have received non-snapshot updates")
} }
func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder, *balancer.Builder) { func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) {
_, config := testClientConfig(t) _, config := testClientConfig(t)
for _, op := range ops { for _, op := range ops {
op(config) op(config)
@ -392,6 +388,7 @@ func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *re
balancerBuilder := balancer.NewBuilder(resolverBuilder.Authority(), testutil.Logger(t)) balancerBuilder := balancer.NewBuilder(resolverBuilder.Authority(), testutil.Logger(t))
balancerBuilder.Register() balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)
deps := newDefaultDeps(t, config) deps := newDefaultDeps(t, config)
deps.Router = router.NewRouter( deps.Router = router.NewRouter(
@ -406,7 +403,7 @@ func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *re
t.Cleanup(func() { t.Cleanup(func() {
client.Shutdown() client.Shutdown()
}) })
return client, resolverBuilder, balancerBuilder return client, resolverBuilder
} }
type testLogger interface { type testLogger interface {

View File

@ -65,21 +65,25 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
// NewBuilder constructs a new Builder with the given name. // NewBuilder constructs a new Builder. Calling Register will add the Builder
func NewBuilder(name string, logger hclog.Logger) *Builder { // to our global registry under the given "authority" such that it will be used
// when dialing targets in the form "consul-internal://<authority>/...", this
// allows us to add and remove balancers for different in-memory agents during
// tests.
func NewBuilder(authority string, logger hclog.Logger) *Builder {
return &Builder{ return &Builder{
name: name, authority: authority,
logger: logger, logger: logger,
byTarget: make(map[string]*list.List), byTarget: make(map[string]*list.List),
shuffler: randomShuffler(), shuffler: randomShuffler(),
} }
} }
// Builder implements gRPC's balancer.Builder interface to construct balancers. // Builder implements gRPC's balancer.Builder interface to construct balancers.
type Builder struct { type Builder struct {
name string authority string
logger hclog.Logger logger hclog.Logger
shuffler shuffler shuffler shuffler
mu sync.Mutex mu sync.Mutex
byTarget map[string]*list.List byTarget map[string]*list.List
@ -129,19 +133,15 @@ func (b *Builder) removeBalancer(targetURL string, elem *list.Element) {
} }
} }
// Name implements the gRPC Balancer interface by returning its given name. // Register the Builder in our global registry. Users should call Deregister
func (b *Builder) Name() string { return b.name } // when finished using the Builder to clean-up global state.
// gRPC's balancer.Register method is not thread-safe, so we guard our calls
// with a global lock (as it may be called from parallel tests).
var registerLock sync.Mutex
// Register the Builder in gRPC's global registry using its given name.
func (b *Builder) Register() { func (b *Builder) Register() {
registerLock.Lock() globalRegistry.register(b.authority, b)
defer registerLock.Unlock() }
gbalancer.Register(b) // Deregister the Builder from our global registry to clean up state.
func (b *Builder) Deregister() {
globalRegistry.deregister(b.authority)
} }
// Rebalance randomizes the priority order of servers for the given target to // Rebalance randomizes the priority order of servers for the given target to

View File

@ -21,6 +21,8 @@ import (
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice" "github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
"github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/sdk/testutil/retry"
@ -34,12 +36,13 @@ func TestBalancer(t *testing.T) {
server1 := runServer(t, "server1") server1 := runServer(t, "server1")
server2 := runServer(t, "server2") server2 := runServer(t, "server2")
target, _ := stubResolver(t, server1, server2) target, authority, _ := stubResolver(t, server1, server2)
balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t)) balancerBuilder := NewBuilder(authority, testutil.Logger(t))
balancerBuilder.Register() balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)
conn := dial(t, target, balancerBuilder) conn := dial(t, target)
client := testservice.NewSimpleClient(conn) client := testservice.NewSimpleClient(conn)
var serverName string var serverName string
@ -78,12 +81,13 @@ func TestBalancer(t *testing.T) {
server1 := runServer(t, "server1") server1 := runServer(t, "server1")
server2 := runServer(t, "server2") server2 := runServer(t, "server2")
target, _ := stubResolver(t, server1, server2) target, authority, _ := stubResolver(t, server1, server2)
balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t)) balancerBuilder := NewBuilder(authority, testutil.Logger(t))
balancerBuilder.Register() balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)
conn := dial(t, target, balancerBuilder) conn := dial(t, target)
client := testservice.NewSimpleClient(conn) client := testservice.NewSimpleClient(conn)
// Figure out which server we're talking to now, and which we should switch to. // Figure out which server we're talking to now, and which we should switch to.
@ -123,10 +127,11 @@ func TestBalancer(t *testing.T) {
server1 := runServer(t, "server1") server1 := runServer(t, "server1")
server2 := runServer(t, "server2") server2 := runServer(t, "server2")
target, _ := stubResolver(t, server1, server2) target, authority, _ := stubResolver(t, server1, server2)
balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t)) balancerBuilder := NewBuilder(authority, testutil.Logger(t))
balancerBuilder.Register() balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)
// Provide a custom prioritizer that causes Rebalance to choose whichever // Provide a custom prioritizer that causes Rebalance to choose whichever
// server didn't get our first request. // server didn't get our first request.
@ -137,7 +142,7 @@ func TestBalancer(t *testing.T) {
}) })
} }
conn := dial(t, target, balancerBuilder) conn := dial(t, target)
client := testservice.NewSimpleClient(conn) client := testservice.NewSimpleClient(conn)
// Figure out which server we're talking to now. // Figure out which server we're talking to now.
@ -177,12 +182,13 @@ func TestBalancer(t *testing.T) {
server1 := runServer(t, "server1") server1 := runServer(t, "server1")
server2 := runServer(t, "server2") server2 := runServer(t, "server2")
target, res := stubResolver(t, server1, server2) target, authority, res := stubResolver(t, server1, server2)
balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t)) balancerBuilder := NewBuilder(authority, testutil.Logger(t))
balancerBuilder.Register() balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)
conn := dial(t, target, balancerBuilder) conn := dial(t, target)
client := testservice.NewSimpleClient(conn) client := testservice.NewSimpleClient(conn)
// Figure out which server we're talking to now. // Figure out which server we're talking to now.
@ -233,7 +239,7 @@ func TestBalancer(t *testing.T) {
}) })
} }
func stubResolver(t *testing.T, servers ...*server) (string, *manual.Resolver) { func stubResolver(t *testing.T, servers ...*server) (string, string, *manual.Resolver) {
t.Helper() t.Helper()
addresses := make([]resolver.Address, len(servers)) addresses := make([]resolver.Address, len(servers))
@ -249,7 +255,10 @@ func stubResolver(t *testing.T, servers ...*server) (string, *manual.Resolver) {
resolver.Register(r) resolver.Register(r)
t.Cleanup(func() { resolver.UnregisterForTesting(scheme) }) t.Cleanup(func() { resolver.UnregisterForTesting(scheme) })
return fmt.Sprintf("%s://", scheme), r authority, err := uuid.GenerateUUID()
require.NoError(t, err)
return fmt.Sprintf("%s://%s", scheme, authority), authority, r
} }
func runServer(t *testing.T, name string) *server { func runServer(t *testing.T, name string) *server {
@ -309,12 +318,12 @@ func (s *server) Something(context.Context, *testservice.Req) (*testservice.Resp
return &testservice.Resp{ServerName: s.name}, nil return &testservice.Resp{ServerName: s.name}, nil
} }
func dial(t *testing.T, target string, builder *Builder) *grpc.ClientConn { func dial(t *testing.T, target string) *grpc.ClientConn {
conn, err := grpc.Dial( conn, err := grpc.Dial(
target, target,
grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultServiceConfig( grpc.WithDefaultServiceConfig(
fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, builder.Name()), fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, BuilderName),
), ),
) )
t.Cleanup(func() { t.Cleanup(func() {

View File

@ -0,0 +1,69 @@
package balancer
import (
"fmt"
"sync"
gbalancer "google.golang.org/grpc/balancer"
)
// BuilderName should be given in gRPC service configuration to enable our
// custom balancer. It refers to this package's global registry, rather than
// an instance of Builder to enable us to add and remove builders at runtime,
// specifically during tests.
const BuilderName = "consul-internal"
// gRPC's balancer.Register method is thread-unsafe because it mutates a global
// map without holding a lock. As such, it's expected that you register custom
// balancers once at the start of your program (e.g. a package init function).
//
// In production, this is fine. Agents register a single instance of our builder
// and use it for the duration. Tests are where this becomes problematic, as we
// spin up several agents in-memory and register/deregister a builder for each,
// with its own agent-specific state, logger, etc.
//
// To avoid data races, we call gRPC's Register method once, on-package init,
// with a global registry struct that implements the Builder interface but
// delegates the building to N instances of our Builder that are registered and
// deregistered at runtime. We use the dial target's host (aka "authority"),
// which is unique per-agent, to pick the correct builder.
func init() {
gbalancer.Register(globalRegistry)
}
var globalRegistry = &registry{
byAuthority: make(map[string]*Builder),
}
// registry implements gRPC's balancer.Builder interface and delegates the
// actual building to the per-authority *Builder instances registered at
// runtime. This lets us call gRPC's thread-unsafe Register exactly once, in
// the package init function, while still supporting many in-memory agents.
type registry struct {
// mu guards byAuthority because builders are registered and deregistered at
// runtime (one per in-memory agent in tests).
mu sync.RWMutex
byAuthority map[string]*Builder
}
// Build implements gRPC's balancer.Builder. It uses the dial target's host
// (aka "authority"), which is unique per-agent, to look up the Builder that
// was registered for that agent and delegates construction to it.
//
// It panics if no Builder is registered for the authority: that indicates a
// programming error (dialing before calling Register).
func (r *registry) Build(cc gbalancer.ClientConn, opts gbalancer.BuildOptions) gbalancer.Balancer {
r.mu.RLock()
defer r.mu.RUnlock()
auth := opts.Target.URL.Host
builder, ok := r.byAuthority[auth]
if !ok {
panic(fmt.Sprintf("no gRPC balancer builder registered for authority: %q", auth))
}
return builder.Build(cc, opts)
}
func (r *registry) Name() string { return BuilderName }
// register adds (or replaces) the Builder to be used for dial targets with
// the given authority. Safe for concurrent use, unlike gRPC's Register.
func (r *registry) register(auth string, builder *Builder) {
r.mu.Lock()
defer r.mu.Unlock()
r.byAuthority[auth] = builder
}
// deregister removes the Builder for the given authority so its state can be
// cleaned up (e.g. when an in-memory test agent is torn down). Safe for
// concurrent use.
func (r *registry) deregister(auth string) {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.byAuthority, auth)
}

View File

@ -8,12 +8,12 @@ import (
"time" "time"
"google.golang.org/grpc" "google.golang.org/grpc"
gbalancer "google.golang.org/grpc/balancer"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/armon/go-metrics" "github.com/armon/go-metrics"
"github.com/hashicorp/consul/agent/grpc-internal/balancer"
agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware" agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
"github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/pool"
@ -22,8 +22,8 @@ import (
// grpcServiceConfig is provided as the default service config. // grpcServiceConfig is provided as the default service config.
// //
// It configures our custom balancer (via the %s directive to interpolate its // It configures our custom balancer which will automatically switch servers
// name) which will automatically switch servers on error. // on error.
// //
// It also enables gRPC's built-in automatic retries for RESOURCE_EXHAUSTED // It also enables gRPC's built-in automatic retries for RESOURCE_EXHAUSTED
// errors *only*, as this is the status code servers will return for an // errors *only*, as this is the status code servers will return for an
@ -41,7 +41,7 @@ import (
// but we're working on generating them automatically from the protobuf files // but we're working on generating them automatically from the protobuf files
const grpcServiceConfig = ` const grpcServiceConfig = `
{ {
"loadBalancingConfig": [{"%s":{}}], "loadBalancingConfig": [{"` + balancer.BuilderName + `":{}}],
"methodConfig": [ "methodConfig": [
{ {
"name": [{}], "name": [{}],
@ -131,12 +131,11 @@ const grpcServiceConfig = `
// ClientConnPool creates and stores a connection for each datacenter. // ClientConnPool creates and stores a connection for each datacenter.
type ClientConnPool struct { type ClientConnPool struct {
dialer dialer dialer dialer
servers ServerLocator servers ServerLocator
gwResolverDep gatewayResolverDep gwResolverDep gatewayResolverDep
conns map[string]*grpc.ClientConn conns map[string]*grpc.ClientConn
connsLock sync.Mutex connsLock sync.Mutex
balancerBuilder gbalancer.Builder
} }
type ServerLocator interface { type ServerLocator interface {
@ -198,21 +197,14 @@ type ClientConnPoolConfig struct {
// DialingFromDatacenter is the datacenter of the consul agent using this // DialingFromDatacenter is the datacenter of the consul agent using this
// pool. // pool.
DialingFromDatacenter string DialingFromDatacenter string
// BalancerBuilder is a builder for the gRPC balancer that will be used.
BalancerBuilder gbalancer.Builder
} }
// NewClientConnPool create new GRPC client pool to connect to servers using // NewClientConnPool create new GRPC client pool to connect to servers using
// GRPC over RPC. // GRPC over RPC.
func NewClientConnPool(cfg ClientConnPoolConfig) *ClientConnPool { func NewClientConnPool(cfg ClientConnPoolConfig) *ClientConnPool {
if cfg.BalancerBuilder == nil {
panic("missing required BalancerBuilder")
}
c := &ClientConnPool{ c := &ClientConnPool{
servers: cfg.Servers, servers: cfg.Servers,
conns: make(map[string]*grpc.ClientConn), conns: make(map[string]*grpc.ClientConn),
balancerBuilder: cfg.BalancerBuilder,
} }
c.dialer = newDialer(cfg, &c.gwResolverDep) c.dialer = newDialer(cfg, &c.gwResolverDep)
return c return c
@ -251,9 +243,7 @@ func (c *ClientConnPool) dial(datacenter string, serverType string) (*grpc.Clien
grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(c.dialer), grpc.WithContextDialer(c.dialer),
grpc.WithStatsHandler(agentmiddleware.NewStatsHandler(metrics.Default(), metricsLabels)), grpc.WithStatsHandler(agentmiddleware.NewStatsHandler(metrics.Default(), metricsLabels)),
grpc.WithDefaultServiceConfig( grpc.WithDefaultServiceConfig(grpcServiceConfig),
fmt.Sprintf(grpcServiceConfig, c.balancerBuilder.Name()),
),
// Keep alive parameters are based on the same default ones we used for // Keep alive parameters are based on the same default ones we used for
// Yamux. These are somewhat arbitrary but we did observe in scale testing // Yamux. These are somewhat arbitrary but we did observe in scale testing
// that the gRPC defaults (servers send keepalives only every 2 hours, // that the gRPC defaults (servers send keepalives only every 2 hours,

View File

@ -13,7 +13,6 @@ import (
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
gbalancer "google.golang.org/grpc/balancer"
"github.com/hashicorp/consul/agent/grpc-internal/balancer" "github.com/hashicorp/consul/agent/grpc-internal/balancer"
"github.com/hashicorp/consul/agent/grpc-internal/resolver" "github.com/hashicorp/consul/agent/grpc-internal/resolver"
@ -143,7 +142,8 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
// if this test is failing because of expired certificates // if this test is failing because of expired certificates
// use the procedure in test/CA-GENERATION.md // use the procedure in test/CA-GENERATION.md
res := resolver.NewServerResolverBuilder(newConfig(t)) res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
registerWithGRPC(t, res, bb)
tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{ tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
InternalRPC: tlsutil.ProtocolConfig{ InternalRPC: tlsutil.ProtocolConfig{
@ -168,7 +168,6 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
UseTLSForDC: tlsConf.UseTLS, UseTLSForDC: tlsConf.UseTLS,
DialingFromServer: true, DialingFromServer: true,
DialingFromDatacenter: "dc1", DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder(t, res.Authority()),
}) })
conn, err := pool.ClientConn("dc1") conn, err := pool.ClientConn("dc1")
@ -191,7 +190,8 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T)
gwAddr := ipaddr.FormatAddressPort("127.0.0.1", freeport.GetOne(t)) gwAddr := ipaddr.FormatAddressPort("127.0.0.1", freeport.GetOne(t))
res := resolver.NewServerResolverBuilder(newConfig(t)) res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
registerWithGRPC(t, res, bb)
tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{ tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
InternalRPC: tlsutil.ProtocolConfig{ InternalRPC: tlsutil.ProtocolConfig{
@ -244,7 +244,6 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T)
UseTLSForDC: tlsConf.UseTLS, UseTLSForDC: tlsConf.UseTLS,
DialingFromServer: true, DialingFromServer: true,
DialingFromDatacenter: "dc2", DialingFromDatacenter: "dc2",
BalancerBuilder: balancerBuilder(t, res.Authority()),
}) })
pool.SetGatewayResolver(func(addr string) string { pool.SetGatewayResolver(func(addr string) string {
return gwAddr return gwAddr
@ -267,13 +266,13 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T)
func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) { func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
count := 4 count := 4
res := resolver.NewServerResolverBuilder(newConfig(t)) res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
registerWithGRPC(t, res, bb)
pool := NewClientConnPool(ClientConnPoolConfig{ pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res, Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue, UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true, DialingFromServer: true,
DialingFromDatacenter: "dc1", DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder(t, res.Authority()),
}) })
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
@ -303,13 +302,13 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) { func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) {
count := 3 count := 3
res := resolver.NewServerResolverBuilder(newConfig(t)) res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
registerWithGRPC(t, res, bb)
pool := NewClientConnPool(ClientConnPoolConfig{ pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res, Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue, UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true, DialingFromServer: true,
DialingFromDatacenter: "dc1", DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder(t, res.Authority()),
}) })
var servers []testServer var servers []testServer
@ -356,13 +355,13 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {
dcs := []string{"dc1", "dc2", "dc3"} dcs := []string{"dc1", "dc2", "dc3"}
res := resolver.NewServerResolverBuilder(newConfig(t)) res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
registerWithGRPC(t, res, bb)
pool := NewClientConnPool(ClientConnPoolConfig{ pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res, Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue, UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true, DialingFromServer: true,
DialingFromDatacenter: "dc1", DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder(t, res.Authority()),
}) })
for _, dc := range dcs { for _, dc := range dcs {
@ -386,18 +385,11 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {
} }
} }
func registerWithGRPC(t *testing.T, b *resolver.ServerResolverBuilder) { func registerWithGRPC(t *testing.T, rb *resolver.ServerResolverBuilder, bb *balancer.Builder) {
resolver.Register(b) resolver.Register(rb)
bb.Register()
t.Cleanup(func() { t.Cleanup(func() {
resolver.Deregister(b.Authority()) resolver.Deregister(rb.Authority())
bb.Deregister()
}) })
} }
func balancerBuilder(t *testing.T, name string) gbalancer.Builder {
t.Helper()
bb := balancer.NewBuilder(name, testutil.Logger(t))
bb.Register()
return bb
}

View File

@ -6,6 +6,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/types" "github.com/hashicorp/consul/types"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
@ -13,6 +14,7 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/hashicorp/consul/agent/grpc-internal/balancer"
"github.com/hashicorp/consul/agent/grpc-internal/resolver" "github.com/hashicorp/consul/agent/grpc-internal/resolver"
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice" "github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
) )
@ -27,7 +29,8 @@ func TestHandler_PanicRecoveryInterceptor(t *testing.T) {
}) })
res := resolver.NewServerResolverBuilder(newConfig(t)) res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
registerWithGRPC(t, res, bb)
srv := newPanicTestServer(t, logger, "server-1", "dc1", nil) srv := newPanicTestServer(t, logger, "server-1", "dc1", nil)
res.AddServer(types.AreaWAN, srv.Metadata()) res.AddServer(types.AreaWAN, srv.Metadata())
@ -38,7 +41,6 @@ func TestHandler_PanicRecoveryInterceptor(t *testing.T) {
UseTLSForDC: useTLSForDcAlwaysTrue, UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true, DialingFromServer: true,
DialingFromDatacenter: "dc1", DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder(t, res.Authority()),
}) })
conn, err := pool.ClientConn("dc1") conn, err := pool.ClientConn("dc1")

View File

@ -1693,8 +1693,9 @@ func newDefaultDeps(t *testing.T, c *consul.Config) consul.Deps {
Datacenter: c.Datacenter, Datacenter: c.Datacenter,
} }
balancerBuilder := balancer.NewBuilder(t.Name(), testutil.Logger(t)) balancerBuilder := balancer.NewBuilder(builder.Authority(), testutil.Logger(t))
balancerBuilder.Register() balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)
return consul.Deps{ return consul.Deps{
EventPublisher: stream.NewEventPublisher(10 * time.Second), EventPublisher: stream.NewEventPublisher(10 * time.Second),
@ -1709,7 +1710,6 @@ func newDefaultDeps(t *testing.T, c *consul.Config) consul.Deps {
UseTLSForDC: tls.UseTLS, UseTLSForDC: tls.UseTLS,
DialingFromServer: true, DialingFromServer: true,
DialingFromDatacenter: c.Datacenter, DialingFromDatacenter: c.Datacenter,
BalancerBuilder: balancerBuilder,
}), }),
LeaderForwarder: builder, LeaderForwarder: builder,
EnterpriseDeps: newDefaultDepsEnterprise(t, logger, c), EnterpriseDeps: newDefaultDepsEnterprise(t, logger, c),

View File

@ -53,6 +53,8 @@ type BaseDeps struct {
Cache *cache.Cache Cache *cache.Cache
ViewStore *submatview.Store ViewStore *submatview.Store
WatchedFiles []string WatchedFiles []string
deregisterBalancer, deregisterResolver func()
} }
type ConfigLoader func(source config.Source) (config.LoadResult, error) type ConfigLoader func(source config.Source) (config.LoadResult, error)
@ -122,14 +124,16 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer, providedLogger hcl
Authority: cfg.Datacenter + "." + string(cfg.NodeID), Authority: cfg.Datacenter + "." + string(cfg.NodeID),
}) })
resolver.Register(resolverBuilder) resolver.Register(resolverBuilder)
d.deregisterResolver = func() {
resolver.Deregister(resolverBuilder.Authority())
}
balancerBuilder := balancer.NewBuilder( balancerBuilder := balancer.NewBuilder(
// Balancer name doesn't really matter, we set it to the resolver authority
// to keep it unique for tests.
resolverBuilder.Authority(), resolverBuilder.Authority(),
d.Logger.Named("grpc.balancer"), d.Logger.Named("grpc.balancer"),
) )
balancerBuilder.Register() balancerBuilder.Register()
d.deregisterBalancer = balancerBuilder.Deregister
d.GRPCConnPool = grpcInt.NewClientConnPool(grpcInt.ClientConnPoolConfig{ d.GRPCConnPool = grpcInt.NewClientConnPool(grpcInt.ClientConnPoolConfig{
Servers: resolverBuilder, Servers: resolverBuilder,
@ -139,7 +143,6 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer, providedLogger hcl
UseTLSForDC: d.TLSConfigurator.UseTLS, UseTLSForDC: d.TLSConfigurator.UseTLS,
DialingFromServer: cfg.ServerMode, DialingFromServer: cfg.ServerMode,
DialingFromDatacenter: cfg.Datacenter, DialingFromDatacenter: cfg.Datacenter,
BalancerBuilder: balancerBuilder,
}) })
d.LeaderForwarder = resolverBuilder d.LeaderForwarder = resolverBuilder
@ -189,6 +192,20 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer, providedLogger hcl
return d, nil return d, nil
} }
// Close cleans up any state and goroutines associated with bd's members not
// handled by something else (e.g. the agent stop channel).
func (bd BaseDeps) Close() {
bd.AutoConfig.Stop()
bd.MetricsConfig.Cancel()
// The deregister funcs are only set if NewBaseDeps registered the gRPC
// balancer/resolver builders, so guard each call against nil.
if fn := bd.deregisterBalancer; fn != nil {
fn()
}
if fn := bd.deregisterResolver; fn != nil {
fn()
}
}
// grpcLogInitOnce because the test suite will call NewBaseDeps in many tests and // grpcLogInitOnce because the test suite will call NewBaseDeps in many tests and
// causes data races when it is re-initialized. // causes data races when it is re-initialized.
var grpcLogInitOnce sync.Once var grpcLogInitOnce sync.Once