diff --git a/.changelog/15892.txt b/.changelog/15892.txt new file mode 100644 index 0000000000..972261120b --- /dev/null +++ b/.changelog/15892.txt @@ -0,0 +1,3 @@ +```release-note:enhancement +grpc: client agents will switch server on error, and automatically retry on `RESOURCE_EXHAUSTED` responses +``` diff --git a/agent/consul/client_test.go b/agent/consul/client_test.go index da1f462c30..15af555094 100644 --- a/agent/consul/client_test.go +++ b/agent/consul/client_test.go @@ -21,6 +21,7 @@ import ( "github.com/hashicorp/consul/agent/consul/stream" "github.com/hashicorp/consul/agent/grpc-external/limiter" grpc "github.com/hashicorp/consul/agent/grpc-internal" + "github.com/hashicorp/consul/agent/grpc-internal/balancer" "github.com/hashicorp/consul/agent/grpc-internal/resolver" "github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/router" @@ -519,9 +520,18 @@ func newDefaultDeps(t *testing.T, c *Config) Deps { tls, err := tlsutil.NewConfigurator(c.TLSConfig, logger) require.NoError(t, err, "failed to create tls configuration") - builder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, c.NodeName+"-"+c.Datacenter)) - r := router.NewRouter(logger, c.Datacenter, fmt.Sprintf("%s.%s", c.NodeName, c.Datacenter), builder) - resolver.Register(builder) + resolverBuilder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, c.NodeName+"-"+c.Datacenter)) + resolver.Register(resolverBuilder) + + balancerBuilder := balancer.NewBuilder(resolverBuilder.Authority(), testutil.Logger(t)) + balancerBuilder.Register() + + r := router.NewRouter( + logger, + c.Datacenter, + fmt.Sprintf("%s.%s", c.NodeName, c.Datacenter), + grpc.NewTracker(resolverBuilder, balancerBuilder), + ) connPool := &pool.ConnPool{ Server: false, @@ -544,13 +554,14 @@ func newDefaultDeps(t *testing.T, c *Config) Deps { Router: r, ConnPool: connPool, GRPCConnPool: grpc.NewClientConnPool(grpc.ClientConnPoolConfig{ - Servers: builder, + Servers: resolverBuilder, TLSWrapper: grpc.TLSWrapper(tls.OutgoingRPCWrapper()), UseTLSForDC: tls.UseTLS, DialingFromServer: true, DialingFromDatacenter: c.Datacenter, + BalancerBuilder: balancerBuilder, }), - LeaderForwarder: builder, + LeaderForwarder: resolverBuilder, NewRequestRecorderFunc: middleware.NewRequestRecorder, GetNetRPCInterceptorFunc: middleware.GetNetRPCInterceptor, EnterpriseDeps: newDefaultDepsEnterprise(t, logger, c), diff --git a/agent/consul/rate/handler_test.go b/agent/consul/rate/handler_test.go index 318907331f..112d74541d 100644 --- a/agent/consul/rate/handler_test.go +++ b/agent/consul/rate/handler_test.go @@ -238,7 +238,6 @@ func TestNewHandlerWithLimiter_CallsUpdateConfig(t *testing.T) { GlobalWriteConfig: writeCfg, GlobalMode: ModeEnforcing, } - logger := hclog.NewNullLogger() NewHandlerWithLimiter(*cfg, mockRateLimiter, logger) mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 2) diff --git a/agent/consul/rpc_test.go b/agent/consul/rpc_test.go index 2dce38ed0e..512d9a4c3f 100644 --- a/agent/consul/rpc_test.go +++ b/agent/consul/rpc_test.go @@ -1163,7 +1163,7 @@ func TestRPC_LocalTokenStrippedOnForward_GRPC(t *testing.T) { var conn *grpc.ClientConn { - client, builder := newClientWithGRPCResolver(t, func(c *Config) { + client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, func(c *Config) { c.Datacenter = "dc2" c.PrimaryDatacenter = "dc1" c.RPCConfig.EnableStreaming = true @@ -1172,9 +1172,10 @@ func TestRPC_LocalTokenStrippedOnForward_GRPC(t *testing.T) { testrpc.WaitForTestAgent(t, client.RPC, "dc2", testrpc.WithToken("root")) pool := agent_grpc.NewClientConnPool(agent_grpc.ClientConnPoolConfig{ - Servers: builder, + Servers: resolverBuilder, DialingFromServer: false, DialingFromDatacenter: "dc2", + BalancerBuilder: balancerBuilder, }) conn, err = pool.ClientConn("dc2") diff --git a/agent/consul/subscribe_backend_test.go b/agent/consul/subscribe_backend_test.go index 26bd3f90b1..770f4a61d9 100644 --- a/agent/consul/subscribe_backend_test.go +++ b/agent/consul/subscribe_backend_test.go @@ -15,11 +15,13 @@ import ( gogrpc "google.golang.org/grpc" grpc "github.com/hashicorp/consul/agent/grpc-internal" + "github.com/hashicorp/consul/agent/grpc-internal/balancer" "github.com/hashicorp/consul/agent/grpc-internal/resolver" "github.com/hashicorp/consul/agent/router" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/proto/pbservice" "github.com/hashicorp/consul/proto/pbsubscribe" + "github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/testrpc" ) @@ -37,7 +39,7 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) { require.NoError(t, err) defer server.Shutdown() - client, builder := newClientWithGRPCResolver(t, configureTLS, clientConfigVerifyOutgoing) + client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing) // Try to join testrpc.WaitForLeader(t, server.RPC, "dc1") @@ -64,11 +66,12 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) { // Start a Subscribe call to our streaming endpoint from the client. { pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{ - Servers: builder, + Servers: resolverBuilder, TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), UseTLSForDC: client.tlsConfigurator.UseTLS, DialingFromServer: true, DialingFromDatacenter: "dc1", + BalancerBuilder: balancerBuilder, }) conn, err := pool.ClientConn("dc1") require.NoError(t, err) @@ -108,11 +111,12 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) { // Start a Subscribe call to our streaming endpoint from the server's loopback client. { pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{ - Servers: builder, + Servers: resolverBuilder, TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), UseTLSForDC: client.tlsConfigurator.UseTLS, DialingFromServer: true, DialingFromDatacenter: "dc1", + BalancerBuilder: balancerBuilder, }) conn, err := pool.ClientConn("dc1") require.NoError(t, err) @@ -187,7 +191,7 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) { defer server.Shutdown() // Set up a client with valid certs and verify_outgoing = true - client, builder := newClientWithGRPCResolver(t, configureTLS, clientConfigVerifyOutgoing) + client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing) testrpc.WaitForLeader(t, server.RPC, "dc1") @@ -195,11 +199,12 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) { joinLAN(t, client, server) pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{ - Servers: builder, + Servers: resolverBuilder, TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), UseTLSForDC: client.tlsConfigurator.UseTLS, DialingFromServer: true, DialingFromDatacenter: "dc1", + BalancerBuilder: balancerBuilder, }) conn, err := pool.ClientConn("dc1") require.NoError(t, err) @@ -279,7 +284,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T codec := rpcClient(t, server) defer codec.Close() - client, builder := newClientWithGRPCResolver(t) + client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t) // Try to join testrpc.WaitForLeader(t, server.RPC, "dc1") @@ -336,11 +341,12 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T }() pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{ - Servers: builder, + Servers: resolverBuilder, TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), UseTLSForDC: client.tlsConfigurator.UseTLS, DialingFromServer: true, DialingFromDatacenter: "dc1", + BalancerBuilder: balancerBuilder, }) conn, err := pool.ClientConn("dc1") require.NoError(t, err) @@ -370,33 +376,37 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T "at least some of the subscribers should have received non-snapshot updates") } -func newClientWithGRPCResolver(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) { +func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder, *balancer.Builder) { _, config := testClientConfig(t) for _, op := range ops { op(config) } - builder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, + resolverBuilder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, "client."+config.Datacenter+"."+string(config.NodeID))) - resolver.Register(builder) + resolver.Register(resolverBuilder) t.Cleanup(func() { - resolver.Deregister(builder.Authority()) + resolver.Deregister(resolverBuilder.Authority()) }) + balancerBuilder := balancer.NewBuilder(resolverBuilder.Authority(), testutil.Logger(t)) + balancerBuilder.Register() + deps := newDefaultDeps(t, config) deps.Router = router.NewRouter( deps.Logger, config.Datacenter, fmt.Sprintf("%s.%s", config.NodeName, config.Datacenter), - builder) + grpc.NewTracker(resolverBuilder, balancerBuilder), + ) client, err := NewClient(config, deps) require.NoError(t, err) t.Cleanup(func() { client.Shutdown() }) - return client, builder + return client, resolverBuilder, balancerBuilder } type testLogger interface { diff --git a/agent/grpc-internal/balancer/balancer.go b/agent/grpc-internal/balancer/balancer.go new file mode 100644 index 0000000000..efd349c82c --- /dev/null +++ b/agent/grpc-internal/balancer/balancer.go @@ -0,0 +1,483 @@ +// package balancer implements a custom gRPC load balancer. +// +// Similarly to gRPC's built-in "pick_first" balancer, our balancer will pin the +// client to a single connection/server. However, it will switch servers as soon +// as an RPC error occurs (e.g. if the client has exhausted its rate limit on +// that server). It also provides a method that will be called periodically by +// the Consul router to randomize the connection priorities to rebalance load. +// +// Our balancer aims to keep exactly one TCP connection (to the current server) +// open at a time. This is different to gRPC's "round_robin" and "base" balancers +// which connect to *all* resolved addresses up-front so that you can quickly +// cycle between them - which we want to avoid because of the overhead on the +// servers. It's also slightly different to gRPC's "pick_first" balancer which +// will attempt to remain connected to the same server as long its address is +// returned by the resolver - we previously had to work around this behavior in +// order to shuffle the servers, which had some unfortunate side effects as +// documented in this issue: https://github.com/hashicorp/consul/issues/10603. +// +// If a server is in a perpetually bad state, the balancer's standard error +// handling will steer away from it but it will *not* be removed from the set +// and will remain in a TRANSIENT_FAILURE state to possibly be retried in the +// future. It is expected that Consul's router will remove servers from the +// resolver which have been network partitioned etc. +// +// Quick primer on how gRPC's different components work together: +// +// - Targets (e.g. consul://.../server.dc1) represent endpoints/collections of +// hosts. They're what you pass as the first argument to grpc.Dial. +// +// - ClientConns represent logical connections to targets. Each ClientConn may +// have many SubConns (and therefore TCP connections to different hosts). +// +// - SubConns represent connections to a single host. They map 1:1 with TCP +// connections (that's actually a bit of a lie, but true for our purposes). +// +// - Resolvers are responsible for turning Targets into sets of addresses (e.g. +// via DNS resolution) and updating the ClientConn when they change. They map +// 1:1 with ClientConns. gRPC creates them for a ClientConn using the builder +// registered for the Target's scheme (i.e. the protocol part of the URL). +// +// - Balancers are responsible for turning resolved addresses into SubConns and +// a Picker. They're called whenever the Resolver updates the ClientConn's +// state (e.g. with new addresses) or when the SubConns change state. +// +// Like Resolvers, they also map 1:1 with ClientConns and are created using a +// builder registered with a name that is specified in the "service config". +// +// - Pickers are responsible for deciding which SubConn will be used for an RPC. +package balancer + +import ( + "container/list" + "errors" + "fmt" + "math/rand" + "sort" + "sync" + "time" + + "github.com/hashicorp/go-hclog" + gbalancer "google.golang.org/grpc/balancer" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/status" +) + +// NewBuilder constructs a new Builder with the given name. +func NewBuilder(name string, logger hclog.Logger) *Builder { + return &Builder{ + name: name, + logger: logger, + byTarget: make(map[string]*list.List), + shuffler: randomShuffler(), + } +} + +// Builder implements gRPC's balancer.Builder interface to construct balancers. +type Builder struct { + name string + logger hclog.Logger + shuffler shuffler + + mu sync.Mutex + byTarget map[string]*list.List +} + +// Build is called by gRPC (e.g. on grpc.Dial) to construct a balancer for the +// given ClientConn. +func (b *Builder) Build(cc gbalancer.ClientConn, opts gbalancer.BuildOptions) gbalancer.Balancer { + b.mu.Lock() + defer b.mu.Unlock() + + targetURL := opts.Target.URL.String() + + logger := b.logger.With("target", targetURL) + logger.Trace("creating balancer") + + bal := newBalancer(cc, opts.Target, logger) + + byTarget, ok := b.byTarget[targetURL] + if !ok { + byTarget = list.New() + b.byTarget[targetURL] = byTarget + } + elem := byTarget.PushBack(bal) + + bal.closeFn = func() { + logger.Trace("removing balancer") + b.removeBalancer(targetURL, elem) + } + + return bal +} + +// removeBalancer is called when a Balancer is closed to remove it from our list. +func (b *Builder) removeBalancer(targetURL string, elem *list.Element) { + b.mu.Lock() + defer b.mu.Unlock() + + byTarget, ok := b.byTarget[targetURL] + if !ok { + return + } + byTarget.Remove(elem) + + if byTarget.Len() == 0 { + delete(b.byTarget, targetURL) + } +} + +// Name implements the gRPC Balancer interface by returning its given name. +func (b *Builder) Name() string { return b.name } + +// gRPC's balancer.Register method is not thread-safe, so we guard our calls +// with a global lock (as it may be called from parallel tests). +var registerLock sync.Mutex + +// Register the Builder in gRPC's global registry using its given name. +func (b *Builder) Register() { + registerLock.Lock() + defer registerLock.Unlock() + + gbalancer.Register(b) +} + +// Rebalance randomizes the priority order of servers for the given target to +// rebalance load. +func (b *Builder) Rebalance(target resolver.Target) { + b.mu.Lock() + defer b.mu.Unlock() + + byTarget, ok := b.byTarget[target.URL.String()] + if !ok { + return + } + + for item := byTarget.Front(); item != nil; item = item.Next() { + item.Value.(*balancer).shuffleServerOrder(b.shuffler) + } +} + +func newBalancer(conn gbalancer.ClientConn, target resolver.Target, logger hclog.Logger) *balancer { + return &balancer{ + conn: conn, + target: target, + logger: logger, + servers: resolver.NewAddressMap(), + } +} + +type balancer struct { + conn gbalancer.ClientConn + target resolver.Target + logger hclog.Logger + closeFn func() + + mu sync.Mutex + subConn gbalancer.SubConn + connState connectivity.State + connError error + currentServer *serverInfo + servers *resolver.AddressMap +} + +type serverInfo struct { + addr resolver.Address + index int // determines the order in which servers will be attempted. + lastFailed time.Time // used to steer away from servers that recently returned errors. +} + +// String returns a log-friendly representation of the server. +func (si *serverInfo) String() string { + if si == nil { + return "" + } + return si.addr.Addr +} + +// Close is called by gRPC when the Balancer is no longer needed (e.g. when the +// ClientConn is closed by the application). +func (b *balancer) Close() { b.closeFn() } + +// ResolverError is called by gRPC when the resolver reports an error. It puts +// the connection into a TRANSIENT_FAILURE state. +func (b *balancer) ResolverError(err error) { + b.mu.Lock() + defer b.mu.Unlock() + + b.logger.Trace("resolver error", "error", err) + b.handleErrorLocked(err) +} + +// UpdateClientConnState is called by gRPC when the ClientConn changes state, +// such as when the resolver produces new addresses. +func (b *balancer) UpdateClientConnState(state gbalancer.ClientConnState) error { + b.mu.Lock() + defer b.mu.Unlock() + + newAddrs := resolver.NewAddressMap() + + // Add any new addresses. + for _, addr := range state.ResolverState.Addresses { + newAddrs.Set(addr, struct{}{}) + + if _, have := b.servers.Get(addr); !have { + b.logger.Trace("adding server address", "address", addr.Addr) + + b.servers.Set(addr, &serverInfo{ + addr: addr, + index: b.servers.Len(), + }) + } + } + + // Delete any addresses that have been removed. + for _, addr := range b.servers.Keys() { + if _, have := newAddrs.Get(addr); !have { + b.logger.Trace("removing server address", "address", addr.Addr) + b.servers.Delete(addr) + } + } + + if b.servers.Len() == 0 { + b.switchServerLocked(nil) + b.handleErrorLocked(errors.New("resolver produced no addresses")) + return gbalancer.ErrBadResolverState + } + + b.maybeSwitchServerLocked() + return nil +} + +// UpdateSubConnState is called by gRPC when a SubConn changes state, such as +// when transitioning from CONNECTING to READY. +func (b *balancer) UpdateSubConnState(sc gbalancer.SubConn, state gbalancer.SubConnState) { + b.mu.Lock() + defer b.mu.Unlock() + + if sc != b.subConn { + return + } + + b.logger.Trace("sub-connection state changed", "server", b.currentServer, "state", state.ConnectivityState) + b.connState = state.ConnectivityState + b.connError = state.ConnectionError + + // Note: it's not clear whether this can actually happen or not. It would mean + // the sub-conn was shut down by something other than us calling RemoveSubConn. + if state.ConnectivityState == connectivity.Shutdown { + b.switchServerLocked(nil) + return + } + + b.updatePickerLocked() +} + +// handleErrorLocked puts the ClientConn into a TRANSIENT_FAILURE state and +// causes the picker to return the given error on Pick. +// +// Note: b.mu must be held when calling this method. +func (b *balancer) handleErrorLocked(err error) { + b.connState = connectivity.TransientFailure + b.connError = fmt.Errorf("resolver error: %w", err) + b.updatePickerLocked() +} + +// maybeSwitchServerLocked switches server if the one we're currently connected +// to is no longer our preference (e.g. based on error state). +// +// Note: b.mu must be held when calling this method. +func (b *balancer) maybeSwitchServerLocked() { + if ideal := b.idealServerLocked(); ideal != b.currentServer { + b.switchServerLocked(ideal) + } +} + +// idealServerLocked determines which server we should currently be connected to +// when taking the error state and rebalance-shuffling into consideration. +// +// Returns nil if there isn't a suitable server. +// +// Note: b.mu must be held when calling this method. +func (b *balancer) idealServerLocked() *serverInfo { + candidates := make([]*serverInfo, b.servers.Len()) + for idx, v := range b.servers.Values() { + candidates[idx] = v.(*serverInfo) + } + + if len(candidates) == 0 { + return nil + } + + sort.Slice(candidates, func(a, b int) bool { + ca, cb := candidates[a], candidates[b] + + return ca.lastFailed.Before(cb.lastFailed) || + (ca.lastFailed.Equal(cb.lastFailed) && ca.index < cb.index) + }) + return candidates[0] +} + +// switchServerLocked switches to the given server, creating a new connection +// and tearing down the previous connection. +// +// It's expected for either/neither/both of b.currentServer and newServer to be nil. +// +// Note: b.mu must be held when calling this method. +func (b *balancer) switchServerLocked(newServer *serverInfo) { + b.logger.Debug("switching server", "from", b.currentServer, "to", newServer) + + prevConn := b.subConn + b.currentServer = newServer + + if newServer == nil { + b.subConn = nil + } else { + var err error + b.subConn, err = b.conn.NewSubConn([]resolver.Address{newServer.addr}, gbalancer.NewSubConnOptions{}) + if err == nil { + b.subConn.Connect() + b.connState = connectivity.Connecting + } else { + b.logger.Trace("failed to create sub-connection", "addr", newServer.addr, "error", err) + b.handleErrorLocked(fmt.Errorf("failed to create sub-connection: %w", err)) + return + } + } + + b.updatePickerLocked() + + if prevConn != nil { + b.conn.RemoveSubConn(prevConn) + } +} + +// updatePickerLocked updates the ClientConn's Picker based on the balancer's +// current state. +// +// Note: b.mu must be held when calling this method. +func (b *balancer) updatePickerLocked() { + var p gbalancer.Picker + switch b.connState { + case connectivity.Connecting: + p = errPicker{err: gbalancer.ErrNoSubConnAvailable} + case connectivity.TransientFailure: + p = errPicker{err: b.connError} + case connectivity.Idle: + p = idlePicker{conn: b.subConn} + case connectivity.Ready: + srv := b.currentServer + + p = readyPicker{ + conn: b.subConn, + errFn: func(err error) { + b.witnessError(srv, err) + }, + } + default: + // Note: shutdown state is handled in UpdateSubConnState. + b.logger.Trace("connection in unexpected state", "state", b.connState) + } + + b.conn.UpdateState(gbalancer.State{ + ConnectivityState: b.connState, + Picker: p, + }) +} + +// witnessError marks the given server as having failed and triggers a switch +// if required. +func (b *balancer) witnessError(server *serverInfo, err error) { + // The following status codes represent errors that probably won't be solved + // by switching servers, so we shouldn't bother disrupting in-flight streams. + switch status.Code(err) { + case codes.Canceled, + codes.InvalidArgument, + codes.NotFound, + codes.AlreadyExists, + codes.PermissionDenied, + codes.Unauthenticated: + return + } + + b.mu.Lock() + defer b.mu.Unlock() + + b.logger.Trace("witnessed RPC error", "server", server, "error", err) + server.lastFailed = time.Now() + b.maybeSwitchServerLocked() +} + +// shuffleServerOrder re-prioritizes the servers using the given shuffler, it +// also unsets the lastFailed timestamp (to prevent us *never* connecting to a +// server that previously failed). +func (b *balancer) shuffleServerOrder(shuffler shuffler) { + b.mu.Lock() + defer b.mu.Unlock() + + b.logger.Trace("shuffling server order") + + addrs := b.servers.Keys() + shuffler(addrs) + + for idx, addr := range addrs { + v, ok := b.servers.Get(addr) + if !ok { + continue + } + + srv := v.(*serverInfo) + srv.index = idx + srv.lastFailed = time.Time{} + } + b.maybeSwitchServerLocked() +} + +// errPicker returns the given error on Pick. +type errPicker struct{ err error } + +func (p errPicker) Pick(gbalancer.PickInfo) (gbalancer.PickResult, error) { + return gbalancer.PickResult{}, p.err +} + +// idlePicker attempts to re-establish the given (idle) connection on Pick. +type idlePicker struct{ conn gbalancer.SubConn } + +func (p idlePicker) Pick(gbalancer.PickInfo) (gbalancer.PickResult, error) { + p.conn.Connect() + return gbalancer.PickResult{}, gbalancer.ErrNoSubConnAvailable +} + +// readyPicker returns the given connection on Pick. errFn will be called if +// the RPC fails (i.e. to switch to another server). +type readyPicker struct { + conn gbalancer.SubConn + errFn func(error) +} + +func (p readyPicker) Pick(info gbalancer.PickInfo) (gbalancer.PickResult, error) { + return gbalancer.PickResult{ + SubConn: p.conn, + Done: func(done gbalancer.DoneInfo) { + if err := done.Err; err != nil { + p.errFn(err) + } + }, + }, nil +} + +// shuffler is used to change the priority order of servers, to rebalance load. +type shuffler func([]resolver.Address) + +// randomShuffler randomizes the priority order. +func randomShuffler() shuffler { + rand := rand.New(rand.NewSource(time.Now().UnixNano())) + + return func(addrs []resolver.Address) { + rand.Shuffle(len(addrs), func(a, b int) { + addrs[a], addrs[b] = addrs[b], addrs[a] + }) + } +} diff --git a/agent/grpc-internal/balancer/balancer_test.go b/agent/grpc-internal/balancer/balancer_test.go new file mode 100644 index 0000000000..830092ab3c --- /dev/null +++ b/agent/grpc-internal/balancer/balancer_test.go @@ -0,0 +1,327 @@ +package balancer + +import ( + "context" + "fmt" + "math/rand" + "net" + "net/url" + "sort" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/resolver/manual" + "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" + + "github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/sdk/testutil/retry" +) + +func TestBalancer(t *testing.T) { + t.Run("remains pinned to the same server", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + server1 := runServer(t, "server1") + server2 := runServer(t, "server2") + + target, _ := stubResolver(t, server1, server2) + + balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t)) + balancerBuilder.Register() + + conn := dial(t, target, balancerBuilder) + client := testservice.NewSimpleClient(conn) + + var serverName string + for i := 0; i < 5; i++ { + rsp, err := client.Something(ctx, &testservice.Req{}) + require.NoError(t, err) + + if i == 0 { + serverName = rsp.ServerName + } else { + require.Equal(t, serverName, rsp.ServerName) + } + } + + var pinnedServer, otherServer *server + switch serverName { + case server1.name: + pinnedServer, otherServer = server1, server2 + case server2.name: + pinnedServer, otherServer = server2, server1 + } + require.Equal(t, 1, + pinnedServer.openConnections(), + "pinned server should have 1 connection", + ) + require.Zero(t, + otherServer.openConnections(), + "other server should have no connections", + ) + }) + + t.Run("switches server on-error", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + server1 := runServer(t, "server1") + server2 := runServer(t, "server2") + + target, _ := stubResolver(t, server1, server2) + + balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t)) + balancerBuilder.Register() + + conn := dial(t, target, balancerBuilder) + client := testservice.NewSimpleClient(conn) + + // Figure out which server we're talking to now, and which we should switch to. + rsp, err := client.Something(ctx, &testservice.Req{}) + require.NoError(t, err) + + var initialServer, otherServer *server + switch rsp.ServerName { + case server1.name: + initialServer, otherServer = server1, server2 + case server2.name: + initialServer, otherServer = server2, server1 + } + + // Next request should fail (we don't have retries configured). + initialServer.err = status.Error(codes.ResourceExhausted, "rate limit exceeded") + _, err = client.Something(ctx, &testservice.Req{}) + require.Error(t, err) + + // Following request should succeed (against the other server). + rsp, err = client.Something(ctx, &testservice.Req{}) + require.NoError(t, err) + require.Equal(t, otherServer.name, rsp.ServerName) + + retry.Run(t, func(r *retry.R) { + require.Zero(r, + initialServer.openConnections(), + "connection to previous server should have been torn down", + ) + }) + }) + + t.Run("rebalance changes the server", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + server1 := runServer(t, "server1") + server2 := runServer(t, "server2") + + target, _ := stubResolver(t, server1, server2) + + balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t)) + balancerBuilder.Register() + + // Provide a custom prioritizer that causes Rebalance to choose whichever + // server didn't get our first request. + var otherServer *server + balancerBuilder.shuffler = func(addrs []resolver.Address) { + sort.Slice(addrs, func(a, b int) bool { + return addrs[a].Addr == otherServer.addr + }) + } + + conn := dial(t, target, balancerBuilder) + client := testservice.NewSimpleClient(conn) + + // Figure out which server we're talking to now. + rsp, err := client.Something(ctx, &testservice.Req{}) + require.NoError(t, err) + + var initialServer *server + switch rsp.ServerName { + case server1.name: + initialServer, otherServer = server1, server2 + case server2.name: + initialServer, otherServer = server2, server1 + } + + // Trigger a rebalance. + targetURL, err := url.Parse(target) + require.NoError(t, err) + balancerBuilder.Rebalance(resolver.Target{URL: *targetURL}) + + // Following request should hit the other server. + rsp, err = client.Something(ctx, &testservice.Req{}) + require.NoError(t, err) + require.Equal(t, otherServer.name, rsp.ServerName) + + retry.Run(t, func(r *retry.R) { + require.Zero(r, + initialServer.openConnections(), + "connection to previous server should have been torn down", + ) + }) + }) + + t.Run("resolver removes the server", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + server1 := runServer(t, "server1") + server2 := runServer(t, "server2") + + target, res := stubResolver(t, server1, server2) + + balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t)) + balancerBuilder.Register() + + conn := dial(t, target, balancerBuilder) + client := testservice.NewSimpleClient(conn) + + // Figure out which server we're talking to now. + rsp, err := client.Something(ctx, &testservice.Req{}) + require.NoError(t, err) + var initialServer, otherServer *server + switch rsp.ServerName { + case server1.name: + initialServer, otherServer = server1, server2 + case server2.name: + initialServer, otherServer = server2, server1 + } + + // Remove the server's address. + res.UpdateState(resolver.State{ + Addresses: []resolver.Address{ + {Addr: otherServer.addr}, + }, + }) + + // Following request should hit the other server. + rsp, err = client.Something(ctx, &testservice.Req{}) + require.NoError(t, err) + require.Equal(t, otherServer.name, rsp.ServerName) + + retry.Run(t, func(r *retry.R) { + require.Zero(r, + initialServer.openConnections(), + "connection to previous server should have been torn down", + ) + }) + + // Remove the other server too. + res.UpdateState(resolver.State{ + Addresses: []resolver.Address{}, + }) + + _, err = client.Something(ctx, &testservice.Req{}) + require.Error(t, err) + require.Contains(t, err.Error(), "resolver produced no addresses") + + retry.Run(t, func(r *retry.R) { + require.Zero(r, + otherServer.openConnections(), + "connection to other server should have been torn down", + ) + }) + }) +} + +func stubResolver(t *testing.T, servers ...*server) (string, *manual.Resolver) { + t.Helper() + + addresses := make([]resolver.Address, len(servers)) + for idx, s := range servers { + addresses[idx] = resolver.Address{Addr: s.addr} + } + + scheme := fmt.Sprintf("consul-%d-%d", time.Now().UnixNano(), rand.Int()) + + r := manual.NewBuilderWithScheme(scheme) + r.InitialState(resolver.State{Addresses: addresses}) + + resolver.Register(r) + t.Cleanup(func() { resolver.UnregisterForTesting(scheme) }) + + return fmt.Sprintf("%s://", scheme), r +} + +func runServer(t *testing.T, name string) *server { + t.Helper() + + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + s := &server{ + name: name, + addr: lis.Addr().String(), + } + + gs := grpc.NewServer( + grpc.StatsHandler(s), + ) + testservice.RegisterSimpleServer(gs, s) + go gs.Serve(lis) + + var once sync.Once + s.shutdown = func() { once.Do(gs.Stop) } + t.Cleanup(s.shutdown) + + return s +} + +type server struct { + name string + addr string + err error + + c int32 + shutdown func() +} + +func (s *server) openConnections() int { return int(atomic.LoadInt32(&s.c)) } + +func (*server) HandleRPC(context.Context, stats.RPCStats) {} +func (*server) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { return ctx } +func (*server) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { return ctx } + +func (s *server) HandleConn(_ context.Context, cs stats.ConnStats) { + switch cs.(type) { + case *stats.ConnBegin: + atomic.AddInt32(&s.c, 1) + case *stats.ConnEnd: + atomic.AddInt32(&s.c, -1) + } +} + +func (*server) Flow(*testservice.Req, testservice.Simple_FlowServer) error { return nil } + +func (s *server) Something(context.Context, *testservice.Req) (*testservice.Resp, error) { + if s.err != nil { + return nil, s.err + } + return &testservice.Resp{ServerName: s.name}, nil +} + +func dial(t *testing.T, target string, builder *Builder) *grpc.ClientConn { + conn, err := grpc.Dial( + target, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultServiceConfig( + fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, builder.Name()), + ), + ) + t.Cleanup(func() { + if err := conn.Close(); err != nil { + t.Logf("error closing connection: %v", err) + } + }) + require.NoError(t, err) + return conn +} diff --git a/agent/grpc-internal/balancer/custombalancer.go b/agent/grpc-internal/balancer/custombalancer.go deleted file mode 100644 index c3c5409d39..0000000000 --- a/agent/grpc-internal/balancer/custombalancer.go +++ /dev/null @@ -1,87 +0,0 @@ -package balancer - -import ( - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/resolver" -) - -func init() { - balancer.Register(newCustomPickfirstBuilder()) -} - -// logger is referenced in pickfirst.go. -// The gRPC library uses the same component name. -var logger = grpclog.Component("balancer") - -func newCustomPickfirstBuilder() balancer.Builder { - return &customPickfirstBuilder{} -} - -type customPickfirstBuilder struct{} - -func (*customPickfirstBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer { - return &customPickfirstBalancer{ - pickfirstBalancer: pickfirstBalancer{cc: cc}, - } -} - -func (*customPickfirstBuilder) Name() string { - return "pick_first_custom" -} - -// customPickfirstBalancer overrides UpdateClientConnState of pickfirstBalancer. -type customPickfirstBalancer struct { - pickfirstBalancer - - activeAddr resolver.Address -} - -func (b *customPickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState) error { - for _, a := range state.ResolverState.Addresses { - // This hack preserves an existing behavior in our client-side - // load balancing where if the first address in a shuffled list - // of addresses matched the currently connected address, it would - // be an effective no-op. - if a.Equal(b.activeAddr) { - break - } - - // Attempt to make a new SubConn with a single address so we can - // track a successful connection explicitly. If we were to pass - // a list of addresses, we cannot assume the first address was - // successful and there is no way to extract the connected address. - sc, err := b.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{}) - if err != nil { - logger.Warningf("balancer.customPickfirstBalancer: failed to create new SubConn: %v", err) - continue - } - - if b.subConn != nil { - b.cc.RemoveSubConn(b.subConn) - } - - // Copy-pasted from pickfirstBalancer.UpdateClientConnState. - { - b.subConn = sc - b.state = connectivity.Idle - b.cc.UpdateState(balancer.State{ - ConnectivityState: connectivity.Idle, - Picker: &picker{result: balancer.PickResult{SubConn: b.subConn}}, - }) - b.subConn.Connect() - } - - b.activeAddr = a - - // We now have a new subConn with one address. - // Break the loop and call UpdateClientConnState - // with the full set of addresses. - break - } - - // This will load the full set of addresses but leave the - // newly created subConn alone. - return b.pickfirstBalancer.UpdateClientConnState(state) -} diff --git a/agent/grpc-internal/balancer/pickfirst.go b/agent/grpc-internal/balancer/pickfirst.go deleted file mode 100644 index 45edcddce2..0000000000 --- a/agent/grpc-internal/balancer/pickfirst.go +++ /dev/null @@ -1,189 +0,0 @@ -// NOTICE: This file is a copy of grpc's pick_first implementation [1]. -// It is preserved as-is with the init() removed for easier updating. -// -// [1]: https://github.com/grpc/grpc-go/blob/v1.49.x/pickfirst.go - -/* - * - * Copyright 2017 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package balancer - -import ( - "errors" - "fmt" - - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/connectivity" -) - -// PickFirstBalancerName is the name of the pick_first balancer. -const PickFirstBalancerName = "pick_first_original" - -func newPickfirstBuilder() balancer.Builder { - return &pickfirstBuilder{} -} - -type pickfirstBuilder struct{} - -func (*pickfirstBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer { - return &pickfirstBalancer{cc: cc} -} - -func (*pickfirstBuilder) Name() string { - return PickFirstBalancerName -} - -type pickfirstBalancer struct { - state connectivity.State - cc balancer.ClientConn - subConn balancer.SubConn -} - -func (b *pickfirstBalancer) ResolverError(err error) { - if logger.V(2) { - logger.Infof("pickfirstBalancer: ResolverError called with error %v", err) - } - if b.subConn == nil { - b.state = connectivity.TransientFailure - } - - if b.state != connectivity.TransientFailure { - // The picker will not change since the balancer does not currently - // report an error. - return - } - b.cc.UpdateState(balancer.State{ - ConnectivityState: connectivity.TransientFailure, - Picker: &picker{err: fmt.Errorf("name resolver error: %v", err)}, - }) -} - -func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState) error { - if len(state.ResolverState.Addresses) == 0 { - // The resolver reported an empty address list. Treat it like an error by - // calling b.ResolverError. - if b.subConn != nil { - // Remove the old subConn. All addresses were removed, so it is no longer - // valid. - b.cc.RemoveSubConn(b.subConn) - b.subConn = nil - } - b.ResolverError(errors.New("produced zero addresses")) - return balancer.ErrBadResolverState - } - - if b.subConn != nil { - b.cc.UpdateAddresses(b.subConn, state.ResolverState.Addresses) - return nil - } - - subConn, err := b.cc.NewSubConn(state.ResolverState.Addresses, balancer.NewSubConnOptions{}) - if err != nil { - if logger.V(2) { - logger.Errorf("pickfirstBalancer: failed to NewSubConn: %v", err) - } - b.state = connectivity.TransientFailure - b.cc.UpdateState(balancer.State{ - ConnectivityState: connectivity.TransientFailure, - Picker: &picker{err: fmt.Errorf("error creating connection: %v", err)}, - }) - return balancer.ErrBadResolverState - } - b.subConn = subConn - b.state = connectivity.Idle - b.cc.UpdateState(balancer.State{ - ConnectivityState: connectivity.Idle, - Picker: &picker{result: balancer.PickResult{SubConn: b.subConn}}, - }) - b.subConn.Connect() - return nil -} - -func (b *pickfirstBalancer) UpdateSubConnState(subConn balancer.SubConn, state balancer.SubConnState) { - if logger.V(2) { - logger.Infof("pickfirstBalancer: UpdateSubConnState: %p, %v", subConn, state) - } - if b.subConn != subConn { - if logger.V(2) { - logger.Infof("pickfirstBalancer: ignored state change because subConn is not recognized") - } - return - } - b.state = state.ConnectivityState - if state.ConnectivityState == connectivity.Shutdown { - b.subConn = nil - return - } - - switch state.ConnectivityState { - case connectivity.Ready: - b.cc.UpdateState(balancer.State{ - ConnectivityState: state.ConnectivityState, - Picker: &picker{result: balancer.PickResult{SubConn: subConn}}, - }) - case connectivity.Connecting: - b.cc.UpdateState(balancer.State{ - ConnectivityState: state.ConnectivityState, - Picker: &picker{err: balancer.ErrNoSubConnAvailable}, - }) - case connectivity.Idle: - b.cc.UpdateState(balancer.State{ - ConnectivityState: state.ConnectivityState, - Picker: &idlePicker{subConn: subConn}, - }) - case connectivity.TransientFailure: - b.cc.UpdateState(balancer.State{ - ConnectivityState: state.ConnectivityState, - Picker: &picker{err: state.ConnectionError}, - }) - } -} - -func (b *pickfirstBalancer) Close() { -} - -func (b *pickfirstBalancer) ExitIdle() { - if b.subConn != nil && b.state == connectivity.Idle { - b.subConn.Connect() - } -} - -type picker struct { - result balancer.PickResult - err error -} - -func (p *picker) Pick(balancer.PickInfo) (balancer.PickResult, error) { - return p.result, p.err -} - -// idlePicker is used when the SubConn is IDLE and kicks the SubConn into -// CONNECTING when Pick is called. -type idlePicker struct { - subConn balancer.SubConn -} - -func (i *idlePicker) Pick(balancer.PickInfo) (balancer.PickResult, error) { - i.subConn.Connect() - return balancer.PickResult{}, balancer.ErrNoSubConnAvailable -} - -// Intentionally removed -// func init() { -// balancer.Register(newPickfirstBuilder()) -// } diff --git a/agent/grpc-internal/client.go b/agent/grpc-internal/client.go index 9a1e8402a7..a0e44e4153 100644 --- a/agent/grpc-internal/client.go +++ b/agent/grpc-internal/client.go @@ -8,26 +8,62 @@ import ( "time" "google.golang.org/grpc" + gbalancer "google.golang.org/grpc/balancer" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" "github.com/armon/go-metrics" - _ "github.com/hashicorp/consul/agent/grpc-internal/balancer" - agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware" "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/tlsutil" ) +// grpcServiceConfig is provided as the default service config. +// +// It configures our custom balancer (via the %s directive to interpolate its +// name) which will automatically switch servers on error. +// +// It also enables gRPC's built-in automatic retries for RESOURCE_EXHAUSTED +// errors *only*, as this is the status code servers will return for an +// operation that failed due to an exhausted rate limit that might succeed +// against a different server (i.e. the one the balancer just switched to). +// +// Note: the empty object in methodConfig.name is what enables retries for all +// services/methods. +// +// See: +// - https://github.com/grpc/grpc/blob/master/doc/service_config.md +// - https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto +// +// TODO(boxofrad): we can use the rate limit annotations to figure out which +// methods are reads (and therefore safe to retry whatever the status code). +const grpcServiceConfig = ` +{ + "loadBalancingConfig": [{"%s":{}}], + "methodConfig": [ + { + "name": [{}], + "retryPolicy": { + "MaxAttempts": 5, + "BackoffMultiplier": 2, + "InitialBackoff": "1s", + "MaxBackoff": "5s", + "RetryableStatusCodes": ["RESOURCE_EXHAUSTED"] + } + } + ] +}` + // ClientConnPool creates and stores a connection for each datacenter. type ClientConnPool struct { - dialer dialer - servers ServerLocator - gwResolverDep gatewayResolverDep - conns map[string]*grpc.ClientConn - connsLock sync.Mutex + dialer dialer + servers ServerLocator + gwResolverDep gatewayResolverDep + conns map[string]*grpc.ClientConn + connsLock sync.Mutex + balancerBuilder gbalancer.Builder } type ServerLocator interface { @@ -89,14 +125,21 @@ type ClientConnPoolConfig struct { // DialingFromDatacenter is the datacenter of the consul agent using this // pool. DialingFromDatacenter string + + // BalancerBuilder is a builder for the gRPC balancer that will be used. + BalancerBuilder gbalancer.Builder } // NewClientConnPool create new GRPC client pool to connect to servers using // GRPC over RPC. func NewClientConnPool(cfg ClientConnPoolConfig) *ClientConnPool { + if cfg.BalancerBuilder == nil { + panic("missing required BalancerBuilder") + } c := &ClientConnPool{ - servers: cfg.Servers, - conns: make(map[string]*grpc.ClientConn), + servers: cfg.Servers, + conns: make(map[string]*grpc.ClientConn), + balancerBuilder: cfg.BalancerBuilder, } c.dialer = newDialer(cfg, &c.gwResolverDep) return c @@ -134,9 +177,10 @@ func (c *ClientConnPool) dial(datacenter string, serverType string) (*grpc.Clien // custom dialer based on logic around whether the server has TLS enabled. grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(c.dialer), - grpc.WithDisableRetry(), grpc.WithStatsHandler(agentmiddleware.NewStatsHandler(metrics.Default(), metricsLabels)), - grpc.WithDefaultServiceConfig(`{"loadBalancingPolicy":"pick_first_custom"}`), + grpc.WithDefaultServiceConfig( + fmt.Sprintf(grpcServiceConfig, c.balancerBuilder.Name()), + ), // Keep alive parameters are based on the same default ones we used for // Yamux. These are somewhat arbitrary but we did observe in scale testing // that the gRPC defaults (servers send keepalives only every 2 hours, diff --git a/agent/grpc-internal/client_test.go b/agent/grpc-internal/client_test.go index d9d264d803..ebd0601ad4 100644 --- a/agent/grpc-internal/client_test.go +++ b/agent/grpc-internal/client_test.go @@ -13,12 +13,15 @@ import ( "github.com/hashicorp/go-hclog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + gbalancer "google.golang.org/grpc/balancer" + "github.com/hashicorp/consul/agent/grpc-internal/balancer" "github.com/hashicorp/consul/agent/grpc-internal/resolver" "github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice" "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/ipaddr" "github.com/hashicorp/consul/sdk/freeport" + "github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/tlsutil" "github.com/hashicorp/consul/types" ) @@ -165,6 +168,7 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) { UseTLSForDC: tlsConf.UseTLS, DialingFromServer: true, DialingFromDatacenter: "dc1", + BalancerBuilder: balancerBuilder(t, res.Authority()), }) conn, err := pool.ClientConn("dc1") @@ -240,6 +244,7 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T) UseTLSForDC: tlsConf.UseTLS, DialingFromServer: true, DialingFromDatacenter: "dc2", + BalancerBuilder: balancerBuilder(t, res.Authority()), }) pool.SetGatewayResolver(func(addr string) string { return gwAddr @@ -268,6 +273,7 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) { UseTLSForDC: useTLSForDcAlwaysTrue, DialingFromServer: true, DialingFromDatacenter: "dc1", + BalancerBuilder: balancerBuilder(t, res.Authority()), }) for i := 0; i < count; i++ { @@ -303,6 +309,7 @@ func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) { UseTLSForDC: useTLSForDcAlwaysTrue, DialingFromServer: true, DialingFromDatacenter: "dc1", + BalancerBuilder: balancerBuilder(t, res.Authority()), }) var servers []testServer @@ -345,59 +352,6 @@ func newConfig(t *testing.T) resolver.Config { return resolver.Config{Authority: strings.ToLower(s)} } -func TestClientConnPool_IntegrationWithGRPCResolver_Rebalance(t *testing.T) { - count := 5 - res := resolver.NewServerResolverBuilder(newConfig(t)) - registerWithGRPC(t, res) - pool := NewClientConnPool(ClientConnPoolConfig{ - Servers: res, - UseTLSForDC: useTLSForDcAlwaysTrue, - DialingFromServer: true, - DialingFromDatacenter: "dc1", - }) - - for i := 0; i < count; i++ { - name := fmt.Sprintf("server-%d", i) - srv := newSimpleTestServer(t, name, "dc1", nil) - res.AddServer(types.AreaWAN, srv.Metadata()) - t.Cleanup(srv.shutdown) - } - - conn, err := pool.ClientConn("dc1") - require.NoError(t, err) - client := testservice.NewSimpleClient(conn) - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - t.Cleanup(cancel) - - first, err := client.Something(ctx, &testservice.Req{}) - require.NoError(t, err) - - t.Run("rebalance a different DC, does nothing", func(t *testing.T) { - res.NewRebalancer("dc-other")() - - resp, err := client.Something(ctx, &testservice.Req{}) - require.NoError(t, err) - require.Equal(t, resp.ServerName, first.ServerName) - }) - - t.Run("rebalance the dc", func(t *testing.T) { - // Rebalance is random, but if we repeat it a few times it should give us a - // new server. - attempts := 100 - for i := 0; i < attempts; i++ { - res.NewRebalancer("dc1")() - - resp, err := client.Something(ctx, &testservice.Req{}) - require.NoError(t, err) - if resp.ServerName != first.ServerName { - return - } - } - t.Fatalf("server was not rebalanced after %v attempts", attempts) - }) -} - func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) { dcs := []string{"dc1", "dc2", "dc3"} @@ -408,6 +362,7 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) { UseTLSForDC: useTLSForDcAlwaysTrue, DialingFromServer: true, DialingFromDatacenter: "dc1", + BalancerBuilder: balancerBuilder(t, res.Authority()), }) for _, dc := range dcs { @@ -437,3 +392,12 @@ func registerWithGRPC(t *testing.T, b *resolver.ServerResolverBuilder) { resolver.Deregister(b.Authority()) }) } + +func balancerBuilder(t *testing.T, name string) gbalancer.Builder { + t.Helper() + + bb := balancer.NewBuilder(name, testutil.Logger(t)) + bb.Register() + + return bb +} diff --git a/agent/grpc-internal/handler_test.go b/agent/grpc-internal/handler_test.go index 4f093ac65e..96f6f036e0 100644 --- a/agent/grpc-internal/handler_test.go +++ b/agent/grpc-internal/handler_test.go @@ -38,6 +38,7 @@ func TestHandler_PanicRecoveryInterceptor(t *testing.T) { UseTLSForDC: useTLSForDcAlwaysTrue, DialingFromServer: true, DialingFromDatacenter: "dc1", + BalancerBuilder: balancerBuilder(t, res.Authority()), }) conn, err := pool.ClientConn("dc1") diff --git a/agent/grpc-internal/resolver/resolver.go b/agent/grpc-internal/resolver/resolver.go index 87275449ef..b5b76de6ef 100644 --- a/agent/grpc-internal/resolver/resolver.go +++ b/agent/grpc-internal/resolver/resolver.go @@ -2,10 +2,8 @@ package resolver import ( "fmt" - "math/rand" "strings" "sync" - "time" "google.golang.org/grpc/resolver" @@ -43,31 +41,6 @@ func NewServerResolverBuilder(cfg Config) *ServerResolverBuilder { } } -// NewRebalancer returns a function which shuffles the server list for resolvers -// in all datacenters. -func (s *ServerResolverBuilder) NewRebalancer(dc string) func() { - shuffler := rand.New(rand.NewSource(time.Now().UnixNano())) - return func() { - s.lock.RLock() - defer s.lock.RUnlock() - - for _, resolver := range s.resolvers { - if resolver.datacenter != dc { - continue - } - // Shuffle the list of addresses using the last list given to the resolver. - resolver.addrLock.Lock() - addrs := resolver.addrs - shuffler.Shuffle(len(addrs), func(i, j int) { - addrs[i], addrs[j] = addrs[j], addrs[i] - }) - // Pass the shuffled list to the resolver. - resolver.updateAddrsLocked(addrs) - resolver.addrLock.Unlock() - } - } -} - // ServerForGlobalAddr returns server metadata for a server with the specified globally unique address. func (s *ServerResolverBuilder) ServerForGlobalAddr(globalAddr string) (*metadata.Server, error) { s.lock.RLock() @@ -265,12 +238,7 @@ var _ resolver.Resolver = (*serverResolver)(nil) func (r *serverResolver) updateAddrs(addrs []resolver.Address) { r.addrLock.Lock() defer r.addrLock.Unlock() - r.updateAddrsLocked(addrs) -} -// updateAddrsLocked updates this serverResolver's ClientConn to use the given -// set of addrs. addrLock must be held by caller. -func (r *serverResolver) updateAddrsLocked(addrs []resolver.Address) { r.clientConn.UpdateState(resolver.State{Addresses: addrs}) r.addrs = addrs } diff --git a/agent/grpc-internal/tracker.go b/agent/grpc-internal/tracker.go new file mode 100644 index 0000000000..779e116e8b --- /dev/null +++ b/agent/grpc-internal/tracker.go @@ -0,0 +1,46 @@ +package internal + +import ( + "fmt" + "net/url" + + gresolver "google.golang.org/grpc/resolver" + + "github.com/hashicorp/consul/agent/grpc-internal/balancer" + "github.com/hashicorp/consul/agent/grpc-internal/resolver" + "github.com/hashicorp/consul/agent/metadata" + "github.com/hashicorp/consul/types" +) + +// NewTracker returns an implementation of the router.ServerTracker interface +// backed by the given ServerResolverBuilder and Balancer. +func NewTracker(rb *resolver.ServerResolverBuilder, bb *balancer.Builder) *Tracker { + return &Tracker{rb, bb} +} + +// Tracker satisfies the ServerTracker interface the router manager uses to +// register/deregister servers and trigger rebalances. +type Tracker struct { + rb *resolver.ServerResolverBuilder + bb *balancer.Builder +} + +// AddServer adds the given server to the resolver. +func (t *Tracker) AddServer(a types.AreaID, s *metadata.Server) { t.rb.AddServer(a, s) } + +// RemoveServer removes the given server from the resolver. +func (t *Tracker) RemoveServer(a types.AreaID, s *metadata.Server) { t.rb.RemoveServer(a, s) } + +// NewRebalancer returns a function that can be called to randomize the +// priority order of connections for the given datacenter. +func (t *Tracker) NewRebalancer(dc string) func() { + return func() { + t.bb.Rebalance(gresolver.Target{ + URL: url.URL{ + Scheme: "consul", + Host: t.rb.Authority(), + Path: fmt.Sprintf("server.%s", dc), + }, + }) + } +} diff --git a/agent/rpc/peering/service_test.go b/agent/rpc/peering/service_test.go index 7d7cccc3b2..207709ac1f 100644 --- a/agent/rpc/peering/service_test.go +++ b/agent/rpc/peering/service_test.go @@ -28,6 +28,7 @@ import ( external "github.com/hashicorp/consul/agent/grpc-external" "github.com/hashicorp/consul/agent/grpc-external/limiter" grpc "github.com/hashicorp/consul/agent/grpc-internal" + "github.com/hashicorp/consul/agent/grpc-internal/balancer" "github.com/hashicorp/consul/agent/grpc-internal/resolver" agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware" "github.com/hashicorp/consul/agent/pool" @@ -1692,6 +1693,9 @@ func newDefaultDeps(t *testing.T, c *consul.Config) consul.Deps { Datacenter: c.Datacenter, } + balancerBuilder := balancer.NewBuilder(t.Name(), testutil.Logger(t)) + balancerBuilder.Register() + return consul.Deps{ EventPublisher: stream.NewEventPublisher(10 * time.Second), Logger: logger, @@ -1705,6 +1709,7 @@ func newDefaultDeps(t *testing.T, c *consul.Config) consul.Deps { UseTLSForDC: tls.UseTLS, DialingFromServer: true, DialingFromDatacenter: c.Datacenter, + BalancerBuilder: balancerBuilder, }), LeaderForwarder: builder, EnterpriseDeps: newDefaultDepsEnterprise(t, logger, c), diff --git a/agent/setup.go b/agent/setup.go index b014996dfa..f5e0e8981d 100644 --- a/agent/setup.go +++ b/agent/setup.go @@ -21,6 +21,7 @@ import ( "github.com/hashicorp/consul/agent/consul/xdscapacity" "github.com/hashicorp/consul/agent/grpc-external/limiter" grpcInt "github.com/hashicorp/consul/agent/grpc-internal" + "github.com/hashicorp/consul/agent/grpc-internal/balancer" "github.com/hashicorp/consul/agent/grpc-internal/resolver" grpcWare "github.com/hashicorp/consul/agent/grpc-middleware" "github.com/hashicorp/consul/agent/hcp" @@ -111,25 +112,40 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer, providedLogger hcl d.ViewStore = submatview.NewStore(d.Logger.Named("viewstore")) d.ConnPool = newConnPool(cfg, d.Logger, d.TLSConfigurator) - builder := resolver.NewServerResolverBuilder(resolver.Config{ + resolverBuilder := resolver.NewServerResolverBuilder(resolver.Config{ // Set the authority to something sufficiently unique so any usage in // tests would be self-isolating in the global resolver map, while also // not incurring a huge penalty for non-test code. Authority: cfg.Datacenter + "." + string(cfg.NodeID), }) - resolver.Register(builder) + resolver.Register(resolverBuilder) + + balancerBuilder := balancer.NewBuilder( + // Balancer name doesn't really matter, we set it to the resolver authority + // to keep it unique for tests. + resolverBuilder.Authority(), + d.Logger.Named("grpc.balancer"), + ) + balancerBuilder.Register() + d.GRPCConnPool = grpcInt.NewClientConnPool(grpcInt.ClientConnPoolConfig{ - Servers: builder, + Servers: resolverBuilder, SrcAddr: d.ConnPool.SrcAddr, TLSWrapper: grpcInt.TLSWrapper(d.TLSConfigurator.OutgoingRPCWrapper()), ALPNWrapper: grpcInt.ALPNWrapper(d.TLSConfigurator.OutgoingALPNRPCWrapper()), UseTLSForDC: d.TLSConfigurator.UseTLS, DialingFromServer: cfg.ServerMode, DialingFromDatacenter: cfg.Datacenter, + BalancerBuilder: balancerBuilder, }) - d.LeaderForwarder = builder + d.LeaderForwarder = resolverBuilder - d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), builder) + d.Router = router.NewRouter( + d.Logger, + cfg.Datacenter, + fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), + grpcInt.NewTracker(resolverBuilder, balancerBuilder), + ) // this needs to happen prior to creating auto-config as some of the dependencies // must also be passed to auto-config