Merge pull request #10334 from hashicorp/dnephin/grpc-fix-resolver-data-race

grpc: fix resolver data race
This commit is contained in:
Daniel Nephin 2021-06-02 13:23:27 -04:00 committed by hc-github-team-consul-core
parent 8197d2c063
commit 749a0b01c3
8 changed files with 84 additions and 59 deletions

View File

@ -12,7 +12,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
gogrpc "google.golang.org/grpc" gogrpc "google.golang.org/grpc"
grpcresolver "google.golang.org/grpc/resolver"
grpc "github.com/hashicorp/consul/agent/grpc" grpc "github.com/hashicorp/consul/agent/grpc"
"github.com/hashicorp/consul/agent/grpc/resolver" "github.com/hashicorp/consul/agent/grpc/resolver"
@ -338,8 +337,11 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
} }
func newClientWithGRPCResolver(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) { func newClientWithGRPCResolver(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) {
builder := resolver.NewServerResolverBuilder(resolver.Config{Scheme: t.Name()}) builder := resolver.NewServerResolverBuilder(resolver.Config{Authority: t.Name()})
registerWithGRPC(builder) resolver.Register(builder)
t.Cleanup(func() {
resolver.Deregister(builder.Authority())
})
_, config := testClientConfig(t) _, config := testClientConfig(t)
for _, op := range ops { for _, op := range ops {
@ -361,19 +363,6 @@ func newClientWithGRPCResolver(t *testing.T, ops ...func(*Config)) (*Client, *re
return client, builder return client, builder
} }
var grpcRegisterLock sync.Mutex
// registerWithGRPC registers the grpc/resolver.Builder as a grpc/resolver.
// This function exists to synchronize registrations with a lock.
// grpc/resolver.Register expects all registration to happen at init and does
// not allow for concurrent registration. This function exists to support
// parallel testing.
func registerWithGRPC(b grpcresolver.Builder) {
grpcRegisterLock.Lock()
defer grpcRegisterLock.Unlock()
grpcresolver.Register(b)
}
type testLogger interface { type testLogger interface {
Logf(format string, args ...interface{}) Logf(format string, args ...interface{})
} }

View File

@ -25,10 +25,10 @@ type ClientConnPool struct {
type ServerLocator interface { type ServerLocator interface {
// ServerForAddr is used to look up server metadata from an address. // ServerForAddr is used to look up server metadata from an address.
ServerForAddr(addr string) (*metadata.Server, error) ServerForAddr(addr string) (*metadata.Server, error)
// Scheme returns the url scheme to use to dial the server. This is primarily // Authority returns the target authority to use to dial the server. This is primarily
// needed for testing multiple agents in parallel, because gRPC requires the // needed for testing multiple agents in parallel, because gRPC requires the
// resolver to be registered globally. // resolver to be registered globally.
Scheme() string Authority() string
} }
// TLSWrapper wraps a non-TLS connection and returns a connection with TLS // TLSWrapper wraps a non-TLS connection and returns a connection with TLS
@ -58,7 +58,7 @@ func (c *ClientConnPool) ClientConn(datacenter string) (*grpc.ClientConn, error)
} }
conn, err := grpc.Dial( conn, err := grpc.Dial(
fmt.Sprintf("%s:///server.%s", c.servers.Scheme(), datacenter), fmt.Sprintf("consul://%s/server.%s", c.servers.Authority(), datacenter),
// use WithInsecure mode here because we handle the TLS wrapping in the // use WithInsecure mode here because we handle the TLS wrapping in the
// custom dialer based on logic around whether the server has TLS enabled. // custom dialer based on logic around whether the server has TLS enabled.
grpc.WithInsecure(), grpc.WithInsecure(),

View File

@ -117,7 +117,7 @@ func newConfig(t *testing.T) resolver.Config {
n := t.Name() n := t.Name()
s := strings.Replace(n, "/", "", -1) s := strings.Replace(n, "/", "", -1)
s = strings.Replace(s, "_", "", -1) s = strings.Replace(s, "_", "", -1)
return resolver.Config{Scheme: strings.ToLower(s)} return resolver.Config{Authority: strings.ToLower(s)}
} }
func TestClientConnPool_IntegrationWithGRPCResolver_Rebalance(t *testing.T) { func TestClientConnPool_IntegrationWithGRPCResolver_Rebalance(t *testing.T) {
@ -195,3 +195,10 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {
require.Equal(t, resp.Datacenter, dc) require.Equal(t, resp.Datacenter, dc)
} }
} }
func registerWithGRPC(t *testing.T, b *resolver.ServerResolverBuilder) {
resolver.Register(b)
t.Cleanup(func() {
resolver.Deregister(b.Authority())
})
}

View File

@ -0,0 +1,54 @@
package resolver
import (
"fmt"
"sync"
"google.golang.org/grpc/resolver"
)
// registry of ServerResolverBuilder. This type exists because grpc requires that
// resolvers are registered globally before any requests are made. This is
// incompatible with our resolver implementation and testing strategy, which
// requires a different Resolver for each test.
type registry struct {
lock sync.RWMutex
byAuthority map[string]*ServerResolverBuilder
}
func (r *registry) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) {
r.lock.RLock()
defer r.lock.RUnlock()
res, ok := r.byAuthority[target.Authority]
if !ok {
return nil, fmt.Errorf("no resolver registered for %v", target.Authority)
}
return res.Build(target, cc, opts)
}
func (r *registry) Scheme() string {
return "consul"
}
var _ resolver.Builder = (*registry)(nil)
var reg = &registry{byAuthority: make(map[string]*ServerResolverBuilder)}
func init() {
resolver.Register(reg)
}
// Register a ServerResolverBuilder with the global registry.
func Register(res *ServerResolverBuilder) {
reg.lock.Lock()
defer reg.lock.Unlock()
reg.byAuthority[res.Authority()] = res
}
// Deregister the ServerResolverBuilder associated with the authority. Only used
// for testing.
func Deregister(authority string) {
reg.lock.Lock()
defer reg.lock.Unlock()
delete(reg.byAuthority, authority)
}

View File

@ -15,9 +15,7 @@ import (
// ServerResolverBuilder tracks the current server list and keeps any // ServerResolverBuilder tracks the current server list and keeps any
// ServerResolvers updated when changes occur. // ServerResolvers updated when changes occur.
type ServerResolverBuilder struct { type ServerResolverBuilder struct {
// scheme used to query the server. Defaults to consul. Used to support cfg Config
// parallel testing because gRPC registers resolvers globally.
scheme string
// servers is an index of Servers by Server.ID. The map contains server IDs // servers is an index of Servers by Server.ID. The map contains server IDs
// for all datacenters. // for all datacenters.
servers map[string]*metadata.Server servers map[string]*metadata.Server
@ -28,25 +26,22 @@ type ServerResolverBuilder struct {
lock sync.RWMutex lock sync.RWMutex
} }
var _ resolver.Builder = (*ServerResolverBuilder)(nil)
type Config struct { type Config struct {
// Scheme used to connect to the server. Defaults to consul. // Authority used to query the server. Defaults to "". Used to support
Scheme string // parallel testing because gRPC registers resolvers globally.
Authority string
} }
func NewServerResolverBuilder(cfg Config) *ServerResolverBuilder { func NewServerResolverBuilder(cfg Config) *ServerResolverBuilder {
if cfg.Scheme == "" {
cfg.Scheme = "consul"
}
return &ServerResolverBuilder{ return &ServerResolverBuilder{
scheme: cfg.Scheme, cfg: cfg,
servers: make(map[string]*metadata.Server), servers: make(map[string]*metadata.Server),
resolvers: make(map[resolver.ClientConn]*serverResolver), resolvers: make(map[resolver.ClientConn]*serverResolver),
} }
} }
// Rebalance shuffles the server list for resolvers in all datacenters. // NewRebalancer returns a function which shuffles the server list for resolvers
// in all datacenters.
func (s *ServerResolverBuilder) NewRebalancer(dc string) func() { func (s *ServerResolverBuilder) NewRebalancer(dc string) func() {
shuffler := rand.New(rand.NewSource(time.Now().UnixNano())) shuffler := rand.New(rand.NewSource(time.Now().UnixNano()))
return func() { return func() {
@ -112,7 +107,9 @@ func (s *ServerResolverBuilder) Build(target resolver.Target, cc resolver.Client
return resolver, nil return resolver, nil
} }
func (s *ServerResolverBuilder) Scheme() string { return s.scheme } func (s *ServerResolverBuilder) Authority() string {
return s.cfg.Authority
}
// AddServer updates the resolvers' states to include the new server's address. // AddServer updates the resolvers' states to include the new server's address.
func (s *ServerResolverBuilder) AddServer(server *metadata.Server) { func (s *ServerResolverBuilder) AddServer(server *metadata.Server) {

View File

@ -13,7 +13,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/resolver"
"github.com/hashicorp/consul/agent/grpc/internal/testservice" "github.com/hashicorp/consul/agent/grpc/internal/testservice"
"github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/metadata"
@ -167,10 +166,3 @@ func (f *fakeRPCListener) handleConn(conn net.Conn) {
conn.Close() conn.Close()
} }
} }
func registerWithGRPC(t *testing.T, b resolver.Builder) {
resolver.Register(b)
t.Cleanup(func() {
resolver.UnregisterForTesting(b.Scheme())
})
}

View File

@ -11,7 +11,6 @@ import (
"github.com/armon/go-metrics/prometheus" "github.com/armon/go-metrics/prometheus"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
grpcresolver "google.golang.org/grpc/resolver"
autoconf "github.com/hashicorp/consul/agent/auto-config" autoconf "github.com/hashicorp/consul/agent/auto-config"
"github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/agent/cache"
@ -105,7 +104,7 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer) (BaseDeps, error)
d.ConnPool = newConnPool(cfg, d.Logger, d.TLSConfigurator) d.ConnPool = newConnPool(cfg, d.Logger, d.TLSConfigurator)
builder := resolver.NewServerResolverBuilder(resolver.Config{}) builder := resolver.NewServerResolverBuilder(resolver.Config{})
registerWithGRPC(builder) resolver.Register(builder)
d.GRPCConnPool = grpc.NewClientConnPool(builder, grpc.TLSWrapper(d.TLSConfigurator.OutgoingRPCWrapper()), d.TLSConfigurator.UseTLS) d.GRPCConnPool = grpc.NewClientConnPool(builder, grpc.TLSWrapper(d.TLSConfigurator.OutgoingRPCWrapper()), d.TLSConfigurator.UseTLS)
d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), builder) d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), builder)
@ -169,19 +168,6 @@ func newConnPool(config *config.RuntimeConfig, logger hclog.Logger, tls *tlsutil
return pool return pool
} }
var registerLock sync.Mutex
// registerWithGRPC registers the grpc/resolver.Builder as a grpc/resolver.
// This function exists to synchronize registrations with a lock.
// grpc/resolver.Register expects all registration to happen at init and does
// not allow for concurrent registration. This function exists to support
// parallel testing.
func registerWithGRPC(b grpcresolver.Builder) {
registerLock.Lock()
defer registerLock.Unlock()
grpcresolver.Register(b)
}
// getPrometheusDefs reaches into every slice of prometheus defs we've defined in each part of the agent, and appends // getPrometheusDefs reaches into every slice of prometheus defs we've defined in each part of the agent, and appends
// all of our slices into one nice slice of definitions per metric type for the Consul agent to pass to go-metrics. // all of our slices into one nice slice of definitions per metric type for the Consul agent to pass to go-metrics.
func getPrometheusDefs(cfg lib.TelemetryConfig) ([]prometheus.GaugeDefinition, []prometheus.CounterDefinition, []prometheus.SummaryDefinition) { func getPrometheusDefs(cfg lib.TelemetryConfig) ([]prometheus.GaugeDefinition, []prometheus.CounterDefinition, []prometheus.SummaryDefinition) {

View File

@ -30,7 +30,7 @@ type Store struct {
// idleTTL is the duration of time an entry should remain in the Store after the // idleTTL is the duration of time an entry should remain in the Store after the
// last request for that entry has been terminated. It is a field on the struct // last request for that entry has been terminated. It is a field on the struct
// so that it can be patched in tests without need a lock. // so that it can be patched in tests without needing a global lock.
idleTTL time.Duration idleTTL time.Duration
} }
@ -122,8 +122,8 @@ func (s *Store) Get(ctx context.Context, req Request) (Result, error) {
defer cancel() defer cancel()
result, err := materializer.getFromView(ctx, info.MinIndex) result, err := materializer.getFromView(ctx, info.MinIndex)
// context.DeadlineExceeded is translated to nil to match the behaviour of // context.DeadlineExceeded is translated to nil to match the timeout
// agent/cache.Cache.Get. // behaviour of agent/cache.Cache.Get.
if err == nil || errors.Is(err, context.DeadlineExceeded) { if err == nil || errors.Is(err, context.DeadlineExceeded) {
return result, nil return result, nil
} }