grpc: fix data race in balancer registration (#16229)

Registering gRPC balancers is thread-unsafe because they are stored in a
global map variable that is accessed without holding a lock. Therefore,
it's expected that balancers are registered _once_ at the beginning of
your program (e.g. in a package `init` function) and certainly not after
you've started dialing connections, etc.

> NOTE: this function must only be called during initialization time
> (i.e. in an init() function), and is not thread-safe.

While this is fine for us in production, it's challenging for tests that
spin up multiple agents in-memory. We currently register a balancer per-
agent which holds agent-specific state that cannot safely be shared.

This commit introduces our own registry that _is_ thread-safe, and
implements the Builder interface such that we can call gRPC's `Register`
method once, on start-up. It uses the same pattern as our resolver
registry where we use the dial target's host (aka "authority"), which is
unique per-agent, to determine which builder to use.
This commit is contained in:
Dan Upton 2023-02-28 10:18:38 +00:00 committed by GitHub
parent 3cbbd63ba2
commit 73b9b407ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 179 additions and 104 deletions

View File

@ -1599,10 +1599,7 @@ func (a *Agent) ShutdownAgent() error {
a.stopLicenseManager()
// this would be cancelled anyways (by the closing of the shutdown ch) but
// this should help them to be stopped more quickly
a.baseDeps.AutoConfig.Stop()
a.baseDeps.MetricsConfig.Cancel()
a.baseDeps.Close()
a.stateLock.Lock()
defer a.stateLock.Unlock()

View File

@ -522,9 +522,13 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
resolverBuilder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, c.NodeName+"-"+c.Datacenter))
resolver.Register(resolverBuilder)
t.Cleanup(func() {
resolver.Deregister(resolverBuilder.Authority())
})
balancerBuilder := balancer.NewBuilder(resolverBuilder.Authority(), testutil.Logger(t))
balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)
r := router.NewRouter(
logger,
@ -559,7 +563,6 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
UseTLSForDC: tls.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: c.Datacenter,
BalancerBuilder: balancerBuilder,
}),
LeaderForwarder: resolverBuilder,
NewRequestRecorderFunc: middleware.NewRequestRecorder,

View File

@ -1165,7 +1165,7 @@ func TestRPC_LocalTokenStrippedOnForward_GRPC(t *testing.T) {
var conn *grpc.ClientConn
{
client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, func(c *Config) {
client, resolverBuilder := newClientWithGRPCPlumbing(t, func(c *Config) {
c.Datacenter = "dc2"
c.PrimaryDatacenter = "dc1"
c.RPCConfig.EnableStreaming = true
@ -1177,7 +1177,6 @@ func TestRPC_LocalTokenStrippedOnForward_GRPC(t *testing.T) {
Servers: resolverBuilder,
DialingFromServer: false,
DialingFromDatacenter: "dc2",
BalancerBuilder: balancerBuilder,
})
conn, err = pool.ClientConn("dc2")

View File

@ -39,7 +39,7 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
require.NoError(t, err)
defer server.Shutdown()
client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
client, resolverBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
// Try to join
testrpc.WaitForLeader(t, server.RPC, "dc1")
@ -71,7 +71,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder,
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
@ -116,7 +115,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder,
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
@ -191,7 +189,7 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) {
defer server.Shutdown()
// Set up a client with valid certs and verify_outgoing = true
client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
client, resolverBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
testrpc.WaitForLeader(t, server.RPC, "dc1")
@ -204,7 +202,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) {
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder,
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
@ -284,7 +281,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
codec := rpcClient(t, server)
defer codec.Close()
client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t)
client, resolverBuilder := newClientWithGRPCPlumbing(t)
// Try to join
testrpc.WaitForLeader(t, server.RPC, "dc1")
@ -346,7 +343,6 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder,
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
@ -376,7 +372,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
"at least some of the subscribers should have received non-snapshot updates")
}
func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder, *balancer.Builder) {
func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) {
_, config := testClientConfig(t)
for _, op := range ops {
op(config)
@ -392,6 +388,7 @@ func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *re
balancerBuilder := balancer.NewBuilder(resolverBuilder.Authority(), testutil.Logger(t))
balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)
deps := newDefaultDeps(t, config)
deps.Router = router.NewRouter(
@ -406,7 +403,7 @@ func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *re
t.Cleanup(func() {
client.Shutdown()
})
return client, resolverBuilder, balancerBuilder
return client, resolverBuilder
}
type testLogger interface {

View File

@ -65,21 +65,25 @@ import (
"google.golang.org/grpc/status"
)
// NewBuilder constructs a new Builder with the given name.
func NewBuilder(name string, logger hclog.Logger) *Builder {
// NewBuilder constructs a new Builder. Calling Register will add the Builder
// to our global registry under the given "authority" such that it will be used
// when dialing targets in the form "consul-internal://<authority>/...", this
// allows us to add and remove balancers for different in-memory agents during
// tests.
func NewBuilder(authority string, logger hclog.Logger) *Builder {
return &Builder{
name: name,
logger: logger,
byTarget: make(map[string]*list.List),
shuffler: randomShuffler(),
authority: authority,
logger: logger,
byTarget: make(map[string]*list.List),
shuffler: randomShuffler(),
}
}
// Builder implements gRPC's balancer.Builder interface to construct balancers.
type Builder struct {
name string
logger hclog.Logger
shuffler shuffler
authority string
logger hclog.Logger
shuffler shuffler
mu sync.Mutex
byTarget map[string]*list.List
@ -129,19 +133,15 @@ func (b *Builder) removeBalancer(targetURL string, elem *list.Element) {
}
}
// Name implements the gRPC Balancer interface by returning its given name.
func (b *Builder) Name() string { return b.name }
// gRPC's balancer.Register method is not thread-safe, so we guard our calls
// with a global lock (as it may be called from parallel tests).
var registerLock sync.Mutex
// Register the Builder in gRPC's global registry using its given name.
// Register the Builder in our global registry. Users should call Deregister
// when finished using the Builder to clean-up global state.
func (b *Builder) Register() {
registerLock.Lock()
defer registerLock.Unlock()
globalRegistry.register(b.authority, b)
}
gbalancer.Register(b)
// Deregister the Builder from our global registry to clean up state.
func (b *Builder) Deregister() {
globalRegistry.deregister(b.authority)
}
// Rebalance randomizes the priority order of servers for the given target to

View File

@ -21,6 +21,8 @@ import (
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/sdk/testutil/retry"
@ -34,12 +36,13 @@ func TestBalancer(t *testing.T) {
server1 := runServer(t, "server1")
server2 := runServer(t, "server2")
target, _ := stubResolver(t, server1, server2)
target, authority, _ := stubResolver(t, server1, server2)
balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
balancerBuilder := NewBuilder(authority, testutil.Logger(t))
balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)
conn := dial(t, target, balancerBuilder)
conn := dial(t, target)
client := testservice.NewSimpleClient(conn)
var serverName string
@ -78,12 +81,13 @@ func TestBalancer(t *testing.T) {
server1 := runServer(t, "server1")
server2 := runServer(t, "server2")
target, _ := stubResolver(t, server1, server2)
target, authority, _ := stubResolver(t, server1, server2)
balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
balancerBuilder := NewBuilder(authority, testutil.Logger(t))
balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)
conn := dial(t, target, balancerBuilder)
conn := dial(t, target)
client := testservice.NewSimpleClient(conn)
// Figure out which server we're talking to now, and which we should switch to.
@ -123,10 +127,11 @@ func TestBalancer(t *testing.T) {
server1 := runServer(t, "server1")
server2 := runServer(t, "server2")
target, _ := stubResolver(t, server1, server2)
target, authority, _ := stubResolver(t, server1, server2)
balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
balancerBuilder := NewBuilder(authority, testutil.Logger(t))
balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)
// Provide a custom prioritizer that causes Rebalance to choose whichever
// server didn't get our first request.
@ -137,7 +142,7 @@ func TestBalancer(t *testing.T) {
})
}
conn := dial(t, target, balancerBuilder)
conn := dial(t, target)
client := testservice.NewSimpleClient(conn)
// Figure out which server we're talking to now.
@ -177,12 +182,13 @@ func TestBalancer(t *testing.T) {
server1 := runServer(t, "server1")
server2 := runServer(t, "server2")
target, res := stubResolver(t, server1, server2)
target, authority, res := stubResolver(t, server1, server2)
balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
balancerBuilder := NewBuilder(authority, testutil.Logger(t))
balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)
conn := dial(t, target, balancerBuilder)
conn := dial(t, target)
client := testservice.NewSimpleClient(conn)
// Figure out which server we're talking to now.
@ -233,7 +239,7 @@ func TestBalancer(t *testing.T) {
})
}
func stubResolver(t *testing.T, servers ...*server) (string, *manual.Resolver) {
func stubResolver(t *testing.T, servers ...*server) (string, string, *manual.Resolver) {
t.Helper()
addresses := make([]resolver.Address, len(servers))
@ -249,7 +255,10 @@ func stubResolver(t *testing.T, servers ...*server) (string, *manual.Resolver) {
resolver.Register(r)
t.Cleanup(func() { resolver.UnregisterForTesting(scheme) })
return fmt.Sprintf("%s://", scheme), r
authority, err := uuid.GenerateUUID()
require.NoError(t, err)
return fmt.Sprintf("%s://%s", scheme, authority), authority, r
}
func runServer(t *testing.T, name string) *server {
@ -309,12 +318,12 @@ func (s *server) Something(context.Context, *testservice.Req) (*testservice.Resp
return &testservice.Resp{ServerName: s.name}, nil
}
func dial(t *testing.T, target string, builder *Builder) *grpc.ClientConn {
func dial(t *testing.T, target string) *grpc.ClientConn {
conn, err := grpc.Dial(
target,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultServiceConfig(
fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, builder.Name()),
fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, BuilderName),
),
)
t.Cleanup(func() {

View File

@ -0,0 +1,69 @@
package balancer
import (
"fmt"
"sync"
gbalancer "google.golang.org/grpc/balancer"
)
// BuilderName should be given in gRPC service configuration to enable our
// custom balancer. It refers to this package's global registry (rather than
// an instance of Builder) to enable us to add and remove builders at runtime,
// specifically during tests.
const BuilderName = "consul-internal"
// gRPC's balancer.Register method is thread-unsafe because it mutates a global
// map without holding a lock. As such, it's expected that you register custom
// balancers once at the start of your program (e.g. a package init function).
//
// In production, this is fine. Agents register a single instance of our builder
// and use it for the duration. Tests are where this becomes problematic, as we
// spin up several agents in-memory and register/deregister a builder for each,
// with its own agent-specific state, logger, etc.
//
// To avoid data races, we call gRPC's Register method once, on package init,
// with a global registry struct that implements the Builder interface but
// delegates the building to N instances of our Builder that are registered and
// deregistered at runtime. We use the dial target's host (aka "authority"),
// which is unique per-agent, to pick the correct builder.
func init() {
	gbalancer.Register(globalRegistry)
}
// globalRegistry is the singleton registered with gRPC (in init above); all
// balancer builds for the "consul-internal" policy are routed through it to
// the per-authority Builder instances.
var globalRegistry = &registry{
	byAuthority: make(map[string]*Builder),
}
// registry implements gRPC's balancer.Builder interface by delegating to
// per-authority Builder instances that may be registered and deregistered
// at runtime (e.g. by tests spinning up multiple in-memory agents).
type registry struct {
	mu          sync.RWMutex // guards byAuthority
	byAuthority map[string]*Builder
}
// Build implements gRPC's balancer.Builder interface. It extracts the dial
// target's authority (the URL host) from the build options and delegates to
// the Builder registered under that authority, panicking if none has been
// registered (a programmer error: Register was never called for this agent).
func (r *registry) Build(cc gbalancer.ClientConn, opts gbalancer.BuildOptions) gbalancer.Balancer {
	r.mu.RLock()
	defer r.mu.RUnlock()

	authority := opts.Target.URL.Host
	if builder, ok := r.byAuthority[authority]; ok {
		return builder.Build(cc, opts)
	}
	panic(fmt.Sprintf("no gRPC balancer builder registered for authority: %q", authority))
}
// Name implements gRPC's balancer.Builder interface by returning the fixed
// policy name ("consul-internal") under which the global registry is known.
func (r *registry) Name() string { return BuilderName }
// register makes builder available for connections dialing targets of the
// form "consul-internal://<auth>/...". Registering the same authority again
// replaces the previously registered Builder.
func (r *registry) register(auth string, builder *Builder) {
	r.mu.Lock()
	r.byAuthority[auth] = builder
	r.mu.Unlock()
}
// deregister removes the Builder registered under the given authority,
// cleaning up global state when an agent (e.g. an in-memory test agent)
// is finished with it. Deregistering an unknown authority is a no-op.
func (r *registry) deregister(auth string) {
	r.mu.Lock()
	delete(r.byAuthority, auth)
	r.mu.Unlock()
}

View File

@ -8,12 +8,12 @@ import (
"time"
"google.golang.org/grpc"
gbalancer "google.golang.org/grpc/balancer"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/agent/grpc-internal/balancer"
agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/pool"
@ -22,8 +22,8 @@ import (
// grpcServiceConfig is provided as the default service config.
//
// It configures our custom balancer (via the %s directive to interpolate its
// name) which will automatically switch servers on error.
// It configures our custom balancer which will automatically switch servers
// on error.
//
// It also enables gRPC's built-in automatic retries for RESOURCE_EXHAUSTED
// errors *only*, as this is the status code servers will return for an
@ -41,7 +41,7 @@ import (
// but we're working on generating them automatically from the protobuf files
const grpcServiceConfig = `
{
"loadBalancingConfig": [{"%s":{}}],
"loadBalancingConfig": [{"` + balancer.BuilderName + `":{}}],
"methodConfig": [
{
"name": [{}],
@ -131,12 +131,11 @@ const grpcServiceConfig = `
// ClientConnPool creates and stores a connection for each datacenter.
type ClientConnPool struct {
dialer dialer
servers ServerLocator
gwResolverDep gatewayResolverDep
conns map[string]*grpc.ClientConn
connsLock sync.Mutex
balancerBuilder gbalancer.Builder
dialer dialer
servers ServerLocator
gwResolverDep gatewayResolverDep
conns map[string]*grpc.ClientConn
connsLock sync.Mutex
}
type ServerLocator interface {
@ -198,21 +197,14 @@ type ClientConnPoolConfig struct {
// DialingFromDatacenter is the datacenter of the consul agent using this
// pool.
DialingFromDatacenter string
// BalancerBuilder is a builder for the gRPC balancer that will be used.
BalancerBuilder gbalancer.Builder
}
// NewClientConnPool create new GRPC client pool to connect to servers using
// GRPC over RPC.
func NewClientConnPool(cfg ClientConnPoolConfig) *ClientConnPool {
if cfg.BalancerBuilder == nil {
panic("missing required BalancerBuilder")
}
c := &ClientConnPool{
servers: cfg.Servers,
conns: make(map[string]*grpc.ClientConn),
balancerBuilder: cfg.BalancerBuilder,
servers: cfg.Servers,
conns: make(map[string]*grpc.ClientConn),
}
c.dialer = newDialer(cfg, &c.gwResolverDep)
return c
@ -251,9 +243,7 @@ func (c *ClientConnPool) dial(datacenter string, serverType string) (*grpc.Clien
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(c.dialer),
grpc.WithStatsHandler(agentmiddleware.NewStatsHandler(metrics.Default(), metricsLabels)),
grpc.WithDefaultServiceConfig(
fmt.Sprintf(grpcServiceConfig, c.balancerBuilder.Name()),
),
grpc.WithDefaultServiceConfig(grpcServiceConfig),
// Keep alive parameters are based on the same default ones we used for
// Yamux. These are somewhat arbitrary but we did observe in scale testing
// that the gRPC defaults (servers send keepalives only every 2 hours,

View File

@ -13,7 +13,6 @@ import (
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
gbalancer "google.golang.org/grpc/balancer"
"github.com/hashicorp/consul/agent/grpc-internal/balancer"
"github.com/hashicorp/consul/agent/grpc-internal/resolver"
@ -143,7 +142,8 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
// if this test is failing because of expired certificates
// use the procedure in test/CA-GENERATION.md
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
registerWithGRPC(t, res, bb)
tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
InternalRPC: tlsutil.ProtocolConfig{
@ -168,7 +168,6 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
UseTLSForDC: tlsConf.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder(t, res.Authority()),
})
conn, err := pool.ClientConn("dc1")
@ -191,7 +190,8 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T)
gwAddr := ipaddr.FormatAddressPort("127.0.0.1", freeport.GetOne(t))
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
registerWithGRPC(t, res, bb)
tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
InternalRPC: tlsutil.ProtocolConfig{
@ -244,7 +244,6 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T)
UseTLSForDC: tlsConf.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc2",
BalancerBuilder: balancerBuilder(t, res.Authority()),
})
pool.SetGatewayResolver(func(addr string) string {
return gwAddr
@ -267,13 +266,13 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T)
func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
count := 4
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
registerWithGRPC(t, res, bb)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder(t, res.Authority()),
})
for i := 0; i < count; i++ {
@ -303,13 +302,13 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) {
count := 3
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
registerWithGRPC(t, res, bb)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder(t, res.Authority()),
})
var servers []testServer
@ -356,13 +355,13 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {
dcs := []string{"dc1", "dc2", "dc3"}
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
registerWithGRPC(t, res, bb)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder(t, res.Authority()),
})
for _, dc := range dcs {
@ -386,18 +385,11 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {
}
}
func registerWithGRPC(t *testing.T, b *resolver.ServerResolverBuilder) {
resolver.Register(b)
func registerWithGRPC(t *testing.T, rb *resolver.ServerResolverBuilder, bb *balancer.Builder) {
resolver.Register(rb)
bb.Register()
t.Cleanup(func() {
resolver.Deregister(b.Authority())
resolver.Deregister(rb.Authority())
bb.Deregister()
})
}
func balancerBuilder(t *testing.T, name string) gbalancer.Builder {
t.Helper()
bb := balancer.NewBuilder(name, testutil.Logger(t))
bb.Register()
return bb
}

View File

@ -6,6 +6,7 @@ import (
"testing"
"time"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/types"
"github.com/hashicorp/go-hclog"
@ -13,6 +14,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/hashicorp/consul/agent/grpc-internal/balancer"
"github.com/hashicorp/consul/agent/grpc-internal/resolver"
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
)
@ -27,7 +29,8 @@ func TestHandler_PanicRecoveryInterceptor(t *testing.T) {
})
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
registerWithGRPC(t, res, bb)
srv := newPanicTestServer(t, logger, "server-1", "dc1", nil)
res.AddServer(types.AreaWAN, srv.Metadata())
@ -38,7 +41,6 @@ func TestHandler_PanicRecoveryInterceptor(t *testing.T) {
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder(t, res.Authority()),
})
conn, err := pool.ClientConn("dc1")

View File

@ -1693,8 +1693,9 @@ func newDefaultDeps(t *testing.T, c *consul.Config) consul.Deps {
Datacenter: c.Datacenter,
}
balancerBuilder := balancer.NewBuilder(t.Name(), testutil.Logger(t))
balancerBuilder := balancer.NewBuilder(builder.Authority(), testutil.Logger(t))
balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)
return consul.Deps{
EventPublisher: stream.NewEventPublisher(10 * time.Second),
@ -1709,7 +1710,6 @@ func newDefaultDeps(t *testing.T, c *consul.Config) consul.Deps {
UseTLSForDC: tls.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: c.Datacenter,
BalancerBuilder: balancerBuilder,
}),
LeaderForwarder: builder,
EnterpriseDeps: newDefaultDepsEnterprise(t, logger, c),

View File

@ -53,6 +53,8 @@ type BaseDeps struct {
Cache *cache.Cache
ViewStore *submatview.Store
WatchedFiles []string
deregisterBalancer, deregisterResolver func()
}
type ConfigLoader func(source config.Source) (config.LoadResult, error)
@ -122,14 +124,16 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer, providedLogger hcl
Authority: cfg.Datacenter + "." + string(cfg.NodeID),
})
resolver.Register(resolverBuilder)
d.deregisterResolver = func() {
resolver.Deregister(resolverBuilder.Authority())
}
balancerBuilder := balancer.NewBuilder(
// Balancer name doesn't really matter, we set it to the resolver authority
// to keep it unique for tests.
resolverBuilder.Authority(),
d.Logger.Named("grpc.balancer"),
)
balancerBuilder.Register()
d.deregisterBalancer = balancerBuilder.Deregister
d.GRPCConnPool = grpcInt.NewClientConnPool(grpcInt.ClientConnPoolConfig{
Servers: resolverBuilder,
@ -139,7 +143,6 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer, providedLogger hcl
UseTLSForDC: d.TLSConfigurator.UseTLS,
DialingFromServer: cfg.ServerMode,
DialingFromDatacenter: cfg.Datacenter,
BalancerBuilder: balancerBuilder,
})
d.LeaderForwarder = resolverBuilder
@ -189,6 +192,20 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer, providedLogger hcl
return d, nil
}
// Close cleans up any state and goroutines associated with bd's members not
// handled by something else (e.g. the agent stop channel).
func (bd BaseDeps) Close() {
	bd.AutoConfig.Stop()
	bd.MetricsConfig.Cancel()

	// Deregister the gRPC balancer first, then the resolver; either may be
	// nil if NewBaseDeps did not complete successfully.
	for _, deregister := range []func(){bd.deregisterBalancer, bd.deregisterResolver} {
		if deregister != nil {
			deregister()
		}
	}
}
// grpcLogInitOnce because the test suite will call NewBaseDeps in many tests and
// causes data races when it is re-initialized.
var grpcLogInitOnce sync.Once