From f936ca5aea6a745dc1f649fa53180cd65cdd6e7a Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Wed, 9 Sep 2020 16:37:43 -0400 Subject: [PATCH 1/8] grpc: client conn pool and resolver Extracted from 936522a13c07e8b732b6fde61bba23d05f7b9a70 Co-authored-by: Paul Banks --- agent/consul/grpc_client.go | 117 +++++++++++++ agent/consul/grpc_resolver.go | 240 ++++++++++++++++++++++++++ agent/router/grpc.go | 20 +++ agent/router/manager.go | 32 +++- agent/router/manager_internal_test.go | 6 +- agent/router/manager_test.go | 8 +- agent/router/router.go | 22 ++- agent/router/router_test.go | 2 +- 8 files changed, 426 insertions(+), 21 deletions(-) create mode 100644 agent/consul/grpc_client.go create mode 100644 agent/consul/grpc_resolver.go create mode 100644 agent/router/grpc.go diff --git a/agent/consul/grpc_client.go b/agent/consul/grpc_client.go new file mode 100644 index 0000000000..6e6d3df115 --- /dev/null +++ b/agent/consul/grpc_client.go @@ -0,0 +1,117 @@ +package consul + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/hashicorp/go-hclog" + "google.golang.org/grpc" + + "github.com/hashicorp/consul/agent/metadata" + "github.com/hashicorp/consul/agent/pool" + "github.com/hashicorp/consul/tlsutil" +) + +type ServerProvider interface { + Servers() []*metadata.Server +} + +type GRPCClient struct { + scheme string + serverProvider ServerProvider + tlsConfigurator *tlsutil.Configurator + grpcConns map[string]*grpc.ClientConn + grpcConnLock sync.Mutex +} + +func NewGRPCClient(logger hclog.Logger, serverProvider ServerProvider, tlsConfigurator *tlsutil.Configurator, scheme string) *GRPCClient { + // Note we don't actually use the logger anywhere yet but I guess it was added + // for future compatibility... + return &GRPCClient{ + scheme: scheme, + serverProvider: serverProvider, + tlsConfigurator: tlsConfigurator, + grpcConns: make(map[string]*grpc.ClientConn), + } +} + +func (c *GRPCClient) GRPCConn(datacenter string) (*grpc.ClientConn, error) { + c.grpcConnLock.Lock() + defer c.grpcConnLock.Unlock() + + // If there's an existing ClientConn for the given DC, return it. + if conn, ok := c.grpcConns[datacenter]; ok { + return conn, nil + } + + dialer := newDialer(c.serverProvider, c.tlsConfigurator.OutgoingRPCWrapper()) + conn, err := grpc.Dial(fmt.Sprintf("%s:///server.%s", c.scheme, datacenter), + // use WithInsecure mode here because we handle the TLS wrapping in the + // custom dialer based on logic around whether the server has TLS enabled. + grpc.WithInsecure(), + grpc.WithContextDialer(dialer), + grpc.WithDisableRetry(), + grpc.WithStatsHandler(grpcStatsHandler), + grpc.WithBalancerName("pick_first")) + if err != nil { + return nil, err + } + + c.grpcConns[datacenter] = conn + + return conn, nil +} + +// newDialer returns a gRPC dialer function that conditionally wraps the connection +// with TLS depending on the given useTLS value. +func newDialer(serverProvider ServerProvider, wrapper tlsutil.DCWrapper) func(context.Context, string) (net.Conn, error) { + return func(ctx context.Context, addr string) (net.Conn, error) { + d := net.Dialer{} + conn, err := d.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + + // Check if TLS is enabled for the server. 
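+ // The dialer is only handed an address by gRPC, so it looks that address
+ // back up in the ServerProvider (which is kept up to date as servers join
+ // and leave) to learn whether this particular server expects TLS.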
+ var found bool + var server *metadata.Server + for _, s := range serverProvider.Servers() { + if s.Addr.String() == addr { + found = true + server = s + } + } + if !found { + return nil, fmt.Errorf("could not find Consul server for address %q", addr) + } + + if server.UseTLS { + if wrapper == nil { + return nil, fmt.Errorf("TLS enabled but got nil TLS wrapper") + } + + // Switch the connection into TLS mode + if _, err := conn.Write([]byte{byte(pool.RPCTLS)}); err != nil { + conn.Close() + return nil, err + } + + // Wrap the connection in a TLS client + tlsConn, err := wrapper(server.Datacenter, conn) + if err != nil { + conn.Close() + return nil, err + } + conn = tlsConn + } + + _, err = conn.Write([]byte{pool.RPCGRPC}) + if err != nil { + return nil, err + } + + return conn, nil + } +} diff --git a/agent/consul/grpc_resolver.go b/agent/consul/grpc_resolver.go new file mode 100644 index 0000000000..883bf4e24d --- /dev/null +++ b/agent/consul/grpc_resolver.go @@ -0,0 +1,240 @@ +package consul + +import ( + "math/rand" + "strings" + "sync" + "time" + + "github.com/hashicorp/consul/agent/metadata" + "github.com/hashicorp/consul/agent/router" + "github.com/hashicorp/serf/serf" + "google.golang.org/grpc/resolver" +) + +var registerLock sync.Mutex + +// registerResolverBuilder registers our custom grpc resolver with the given scheme. +func registerResolverBuilder(scheme, datacenter string, shutdownCh <-chan struct{}) *ServerResolverBuilder { + registerLock.Lock() + defer registerLock.Unlock() + grpcResolverBuilder := NewServerResolverBuilder(scheme, datacenter, shutdownCh) + resolver.Register(grpcResolverBuilder) + return grpcResolverBuilder +} + +// ServerResolverBuilder tracks the current server list and keeps any +// ServerResolvers updated when changes occur. +type ServerResolverBuilder struct { + // Allow overriding the scheme to support parallel tests, since + // the resolver builder is registered globally. + scheme string + datacenter string + servers map[string]*metadata.Server + resolvers map[resolver.ClientConn]*ServerResolver + shutdownCh <-chan struct{} + lock sync.Mutex +} + +func NewServerResolverBuilder(scheme, datacenter string, shutdownCh <-chan struct{}) *ServerResolverBuilder { + return &ServerResolverBuilder{ + scheme: scheme, + datacenter: datacenter, + servers: make(map[string]*metadata.Server), + resolvers: make(map[resolver.ClientConn]*ServerResolver), + } +} + +// periodicServerRebalance periodically reshuffles the order of server addresses +// within the resolvers to ensure the load is balanced across servers. +func (s *ServerResolverBuilder) periodicServerRebalance(serf *serf.Serf) { + // Compute the rebalance timer based on the number of local servers and nodes. + rebalanceDuration := router.ComputeRebalanceTimer(s.serversInDC(s.datacenter), serf.NumNodes()) + timer := time.NewTimer(rebalanceDuration) + + for { + select { + case <-timer.C: + s.rebalanceResolvers() + + // Re-compute the wait duration. + newTimerDuration := router.ComputeRebalanceTimer(s.serversInDC(s.datacenter), serf.NumNodes()) + timer.Reset(newTimerDuration) + case <-s.shutdownCh: + timer.Stop() + return + } + } +} + +// rebalanceResolvers shuffles the server list for resolvers in all datacenters. +func (s *ServerResolverBuilder) rebalanceResolvers() { + s.lock.Lock() + defer s.lock.Unlock() + + for _, resolver := range s.resolvers { + // Shuffle the list of addresses using the last list given to the resolver. 
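+ // Reordering matters because the ClientConn uses the pick_first balancer:
+ // re-sending a shuffled list (see updateAddrsLocked) is what actually moves
+ // an existing connection onto a different server.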
+ resolver.addrLock.Lock() + addrs := resolver.lastAddrs + rand.Shuffle(len(addrs), func(i, j int) { + addrs[i], addrs[j] = addrs[j], addrs[i] + }) + // Pass the shuffled list to the resolver. + resolver.updateAddrsLocked(addrs) + resolver.addrLock.Unlock() + } +} + +// serversInDC returns the number of servers in the given datacenter. +func (s *ServerResolverBuilder) serversInDC(dc string) int { + s.lock.Lock() + defer s.lock.Unlock() + + var serverCount int + for _, server := range s.servers { + if server.Datacenter == dc { + serverCount++ + } + } + + return serverCount +} + +// Servers returns metadata for all currently known servers. This is used +// by grpc.ClientConn through our custom dialer. +func (s *ServerResolverBuilder) Servers() []*metadata.Server { + s.lock.Lock() + defer s.lock.Unlock() + + servers := make([]*metadata.Server, 0, len(s.servers)) + for _, server := range s.servers { + servers = append(servers, server) + } + return servers +} + +// Build returns a new ServerResolver for the given ClientConn. The resolver +// will keep the ClientConn's state updated based on updates from Serf. +func (s *ServerResolverBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) { + s.lock.Lock() + defer s.lock.Unlock() + + // If there's already a resolver for this datacenter, return it. + datacenter := strings.TrimPrefix(target.Endpoint, "server.") + if resolver, ok := s.resolvers[cc]; ok { + return resolver, nil + } + + // Make a new resolver for the dc and add it to the list of active ones. + resolver := &ServerResolver{ + datacenter: datacenter, + clientConn: cc, + } + resolver.updateAddrs(s.getDCAddrs(datacenter)) + resolver.closeCallback = func() { + s.lock.Lock() + defer s.lock.Unlock() + delete(s.resolvers, cc) + } + + s.resolvers[cc] = resolver + + return resolver, nil +} + +func (s *ServerResolverBuilder) Scheme() string { return s.scheme } + +// AddServer updates the resolvers' states to include the new server's address. +func (s *ServerResolverBuilder) AddServer(server *metadata.Server) { + s.lock.Lock() + defer s.lock.Unlock() + + s.servers[server.ID] = server + + addrs := s.getDCAddrs(server.Datacenter) + for _, resolver := range s.resolvers { + if resolver.datacenter == server.Datacenter { + resolver.updateAddrs(addrs) + } + } +} + +// RemoveServer updates the resolvers' states with the given server removed. +func (s *ServerResolverBuilder) RemoveServer(server *metadata.Server) { + s.lock.Lock() + defer s.lock.Unlock() + + delete(s.servers, server.ID) + + addrs := s.getDCAddrs(server.Datacenter) + for _, resolver := range s.resolvers { + if resolver.datacenter == server.Datacenter { + resolver.updateAddrs(addrs) + } + } +} + +// getDCAddrs returns a list of the server addresses for the given datacenter. +// This method assumes the lock is held. +func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address { + var addrs []resolver.Address + for _, server := range s.servers { + if server.Datacenter != dc { + continue + } + + addrs = append(addrs, resolver.Address{ + Addr: server.Addr.String(), + Type: resolver.Backend, + ServerName: server.Name, + }) + } + return addrs +} + +// ServerResolver is a grpc Resolver that will keep a grpc.ClientConn up to date +// on the list of server addresses to use. 
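+// It implements the resolver.Resolver interface: gRPC calls Close and
+// ResolveNow, while ServerResolverBuilder pushes new address lists into it
+// via updateAddrs as servers are added, removed, or rebalanced.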
+type ServerResolver struct { + datacenter string + clientConn resolver.ClientConn + closeCallback func() + + lastAddrs []resolver.Address + addrLock sync.Mutex +} + +// updateAddrs updates this ServerResolver's ClientConn to use the given set of +// addrs. +func (r *ServerResolver) updateAddrs(addrs []resolver.Address) { + r.addrLock.Lock() + defer r.addrLock.Unlock() + r.updateAddrsLocked(addrs) +} + +// updateAddrsLocked updates this ServerResolver's ClientConn to use the given +// set of addrs. addrLock must be held by calleer. +func (r *ServerResolver) updateAddrsLocked(addrs []resolver.Address) { + // Only pass the first address initially, which will cause the + // balancer to spin down the connection for its previous first address + // if it is different. If we don't do this, it will keep using the old + // first address as long as it is still in the list, making it impossible to + // rebalance until that address is removed. + var firstAddr []resolver.Address + if len(addrs) > 0 { + firstAddr = []resolver.Address{addrs[0]} + } + r.clientConn.UpdateState(resolver.State{Addresses: firstAddr}) + + // Call UpdateState again with the entire list of addrs in case we need them + // for failover. + r.clientConn.UpdateState(resolver.State{Addresses: addrs}) + + r.lastAddrs = addrs +} + +func (s *ServerResolver) Close() { + s.closeCallback() +} + +// Unneeded since we only update the ClientConn when our server list changes. +func (*ServerResolver) ResolveNow(o resolver.ResolveNowOption) {} diff --git a/agent/router/grpc.go b/agent/router/grpc.go new file mode 100644 index 0000000000..0a50992811 --- /dev/null +++ b/agent/router/grpc.go @@ -0,0 +1,20 @@ +package router + +import "github.com/hashicorp/consul/agent/metadata" + +// ServerTracker is a wrapper around consul.ServerResolverBuilder to prevent a +// cyclic import dependency. +type ServerTracker interface { + AddServer(*metadata.Server) + RemoveServer(*metadata.Server) +} + +// NoOpServerTracker is a ServerTracker that does nothing. Used when gRPC is not +// enabled. +type NoOpServerTracker struct{} + +// AddServer implements ServerTracker +func (NoOpServerTracker) AddServer(*metadata.Server) {} + +// RemoveServer implements ServerTracker +func (NoOpServerTracker) RemoveServer(*metadata.Server) {} diff --git a/agent/router/manager.go b/agent/router/manager.go index 9c7d805976..0e64c5c904 100644 --- a/agent/router/manager.go +++ b/agent/router/manager.go @@ -98,6 +98,10 @@ type Manager struct { // client.ConnPool. connPoolPinger Pinger + // grpcServerTracker is used to balance grpc connections across servers, + // and has callbacks for adding or removing a server. + grpcServerTracker ServerTracker + // serverName has the name of the managers's server. This is used to // short-circuit pinging to itself. serverName string @@ -119,6 +123,7 @@ type Manager struct { func (m *Manager) AddServer(s *metadata.Server) { m.listLock.Lock() defer m.listLock.Unlock() + m.grpcServerTracker.AddServer(s) l := m.getServerList() // Check if this server is known @@ -251,6 +256,11 @@ func (m *Manager) CheckServers(fn func(srv *metadata.Server) bool) { _ = m.checkServers(fn) } +// Servers returns the current list of servers. +func (m *Manager) Servers() []*metadata.Server { + return m.getServerList().servers +} + // getServerList is a convenience method which hides the locking semantics // of atomic.Value from the caller. 
func (m *Manager) getServerList() serverList { @@ -267,15 +277,19 @@ func (m *Manager) saveServerList(l serverList) { } // New is the only way to safely create a new Manager struct. -func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfCluster, connPoolPinger Pinger, serverName string) (m *Manager) { +func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfCluster, connPoolPinger Pinger, tracker ServerTracker, serverName string) (m *Manager) { if logger == nil { logger = hclog.New(&hclog.LoggerOptions{}) } + if tracker == nil { + tracker = NoOpServerTracker{} + } m = new(Manager) m.logger = logger.Named(logging.Manager) m.clusterInfo = clusterInfo // can't pass *consul.Client: import cycle m.connPoolPinger = connPoolPinger // can't pass *consul.ConnPool: import cycle + m.grpcServerTracker = tracker m.rebalanceTimer = time.NewTimer(clientRPCMinReuseDuration) m.shutdownCh = shutdownCh m.serverName = serverName @@ -478,6 +492,7 @@ func (m *Manager) reconcileServerList(l *serverList) bool { func (m *Manager) RemoveServer(s *metadata.Server) { m.listLock.Lock() defer m.listLock.Unlock() + m.grpcServerTracker.RemoveServer(s) l := m.getServerList() // Remove the server if known @@ -498,17 +513,22 @@ func (m *Manager) RemoveServer(s *metadata.Server) { func (m *Manager) refreshServerRebalanceTimer() time.Duration { l := m.getServerList() numServers := len(l.servers) + connRebalanceTimeout := ComputeRebalanceTimer(numServers, m.clusterInfo.NumNodes()) + + m.rebalanceTimer.Reset(connRebalanceTimeout) + return connRebalanceTimeout +} + +// ComputeRebalanceTimer returns a time to wait before rebalancing connections given +// a number of servers and LAN nodes. +func ComputeRebalanceTimer(numServers, numLANMembers int) time.Duration { // Limit this connection's life based on the size (and health) of the // cluster. Never rebalance a connection more frequently than // connReuseLowWatermarkDuration, and make sure we never exceed // clusterWideRebalanceConnsPerSec operations/s across numLANMembers. clusterWideRebalanceConnsPerSec := float64(numServers * newRebalanceConnsPerSecPerServer) connReuseLowWatermarkDuration := clientRPCMinReuseDuration + lib.RandomStagger(clientRPCMinReuseDuration/clientRPCJitterFraction) - numLANMembers := m.clusterInfo.NumNodes() - connRebalanceTimeout := lib.RateScaledInterval(clusterWideRebalanceConnsPerSec, connReuseLowWatermarkDuration, numLANMembers) - - m.rebalanceTimer.Reset(connRebalanceTimeout) - return connRebalanceTimeout + return lib.RateScaledInterval(clusterWideRebalanceConnsPerSec, connReuseLowWatermarkDuration, numLANMembers) } // ResetRebalanceTimer resets the rebalance timer. 
This method exists for diff --git a/agent/router/manager_internal_test.go b/agent/router/manager_internal_test.go index 76d9512168..63838f4972 100644 --- a/agent/router/manager_internal_test.go +++ b/agent/router/manager_internal_test.go @@ -54,14 +54,14 @@ func (s *fauxSerf) NumNodes() int { func testManager() (m *Manager) { logger := GetBufferedLogger() shutdownCh := make(chan struct{}) - m = New(logger, shutdownCh, &fauxSerf{numNodes: 16384}, &fauxConnPool{}, "") + m = New(logger, shutdownCh, &fauxSerf{numNodes: 16384}, &fauxConnPool{}, nil, "") return m } func testManagerFailProb(failPct float64) (m *Manager) { logger := GetBufferedLogger() shutdownCh := make(chan struct{}) - m = New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, "") + m = New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, nil, "") return m } @@ -300,7 +300,7 @@ func TestManagerInternal_refreshServerRebalanceTimer(t *testing.T) { shutdownCh := make(chan struct{}) for _, s := range clusters { - m := New(logger, shutdownCh, &fauxSerf{numNodes: s.numNodes}, &fauxConnPool{}, "") + m := New(logger, shutdownCh, &fauxSerf{numNodes: s.numNodes}, &fauxConnPool{}, nil, "") for i := 0; i < s.numServers; i++ { nodeName := fmt.Sprintf("s%02d", i) m.AddServer(&metadata.Server{Name: nodeName}) diff --git a/agent/router/manager_test.go b/agent/router/manager_test.go index c7e1f299ca..6c3a83816c 100644 --- a/agent/router/manager_test.go +++ b/agent/router/manager_test.go @@ -57,21 +57,21 @@ func (s *fauxSerf) NumNodes() int { func testManager(t testing.TB) (m *router.Manager) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, "") + m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, nil, "") return m } func testManagerFailProb(t testing.TB, failPct float64) (m *router.Manager) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, "") + m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, nil, "") return m } func testManagerFailAddr(t testing.TB, failAddr net.Addr) (m *router.Manager) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failAddr: failAddr}, "") + m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failAddr: failAddr}, nil, "") return m } @@ -195,7 +195,7 @@ func TestServers_FindServer(t *testing.T) { func TestServers_New(t *testing.T) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m := router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, "") + m := router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, nil, "") if m == nil { t.Fatalf("Manager nil") } diff --git a/agent/router/router.go b/agent/router/router.go index 027303bea2..63f6e88ee1 100644 --- a/agent/router/router.go +++ b/agent/router/router.go @@ -41,6 +41,10 @@ type Router struct { // routeFn is a hook to actually do the routing. routeFn func(datacenter string) (*Manager, *metadata.Server, bool) + // grpcServerTracker is used to balance grpc connections across servers, + // and has callbacks for adding or removing a server. + grpcServerTracker ServerTracker + // isShutdown prevents adding new routes to a router after it is shut // down. isShutdown bool @@ -87,17 +91,21 @@ type areaInfo struct { } // NewRouter returns a new Router with the given configuration. 
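+// A nil tracker is replaced with NoOpServerTracker, so callers that do not
+// enable gRPC can simply pass nil.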
-func NewRouter(logger hclog.Logger, localDatacenter, serverName string) *Router { +func NewRouter(logger hclog.Logger, localDatacenter, serverName string, tracker ServerTracker) *Router { if logger == nil { logger = hclog.New(&hclog.LoggerOptions{}) } + if tracker == nil { + tracker = NoOpServerTracker{} + } router := &Router{ - logger: logger.Named(logging.Router), - localDatacenter: localDatacenter, - serverName: serverName, - areas: make(map[types.AreaID]*areaInfo), - managers: make(map[string][]*Manager), + logger: logger.Named(logging.Router), + localDatacenter: localDatacenter, + serverName: serverName, + areas: make(map[types.AreaID]*areaInfo), + managers: make(map[string][]*Manager), + grpcServerTracker: tracker, } // Hook the direct route lookup by default. @@ -251,7 +259,7 @@ func (r *Router) maybeInitializeManager(area *areaInfo, dc string) *Manager { } shutdownCh := make(chan struct{}) - manager := New(r.logger, shutdownCh, area.cluster, area.pinger, r.serverName) + manager := New(r.logger, shutdownCh, area.cluster, area.pinger, nil, r.serverName) info = &managerInfo{ manager: manager, shutdownCh: shutdownCh, diff --git a/agent/router/router_test.go b/agent/router/router_test.go index ae1beefaf4..83de54fed4 100644 --- a/agent/router/router_test.go +++ b/agent/router/router_test.go @@ -117,7 +117,7 @@ func testCluster(self string) *mockCluster { func testRouter(t testing.TB, dc string) *Router { logger := testutil.Logger(t) - return NewRouter(logger, dc, "") + return NewRouter(logger, dc, "", nil) } func TestRouter_Shutdown(t *testing.T) { From 25f47b46e1aa60180fc2a21d3ddc6ea2ec17efff Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Wed, 9 Sep 2020 17:51:51 -0400 Subject: [PATCH 2/8] grpc: move client conn pool to grpc package --- .../{consul/grpc_client.go => grpc/client.go} | 18 +++---- .../resolver/resolver.go} | 53 +++++++++++-------- 2 files changed, 38 insertions(+), 33 deletions(-) rename agent/{consul/grpc_client.go => grpc/client.go} (84%) rename agent/{consul/grpc_resolver.go => grpc/resolver/resolver.go} (84%) diff --git a/agent/consul/grpc_client.go b/agent/grpc/client.go similarity index 84% rename from agent/consul/grpc_client.go rename to agent/grpc/client.go index 6e6d3df115..d2f9f32b27 100644 --- a/agent/consul/grpc_client.go +++ b/agent/grpc/client.go @@ -1,4 +1,4 @@ -package consul +package grpc import ( "context" @@ -6,7 +6,6 @@ import ( "net" "sync" - "github.com/hashicorp/go-hclog" "google.golang.org/grpc" "github.com/hashicorp/consul/agent/metadata" @@ -18,26 +17,24 @@ type ServerProvider interface { Servers() []*metadata.Server } -type GRPCClient struct { - scheme string +type Client struct { serverProvider ServerProvider tlsConfigurator *tlsutil.Configurator grpcConns map[string]*grpc.ClientConn grpcConnLock sync.Mutex } -func NewGRPCClient(logger hclog.Logger, serverProvider ServerProvider, tlsConfigurator *tlsutil.Configurator, scheme string) *GRPCClient { +func NewGRPCClient(serverProvider ServerProvider, tlsConfigurator *tlsutil.Configurator) *Client { // Note we don't actually use the logger anywhere yet but I guess it was added // for future compatibility... 
- return &GRPCClient{ - scheme: scheme, + return &Client{ serverProvider: serverProvider, tlsConfigurator: tlsConfigurator, grpcConns: make(map[string]*grpc.ClientConn), } } -func (c *GRPCClient) GRPCConn(datacenter string) (*grpc.ClientConn, error) { +func (c *Client) GRPCConn(datacenter string) (*grpc.ClientConn, error) { c.grpcConnLock.Lock() defer c.grpcConnLock.Unlock() @@ -47,13 +44,14 @@ func (c *GRPCClient) GRPCConn(datacenter string) (*grpc.ClientConn, error) { } dialer := newDialer(c.serverProvider, c.tlsConfigurator.OutgoingRPCWrapper()) - conn, err := grpc.Dial(fmt.Sprintf("%s:///server.%s", c.scheme, datacenter), + conn, err := grpc.Dial(fmt.Sprintf("%s:///server.%s", scheme, datacenter), // use WithInsecure mode here because we handle the TLS wrapping in the // custom dialer based on logic around whether the server has TLS enabled. grpc.WithInsecure(), grpc.WithContextDialer(dialer), grpc.WithDisableRetry(), - grpc.WithStatsHandler(grpcStatsHandler), + // TODO: previously this handler was shared with the Handler. Is that necessary? + grpc.WithStatsHandler(&statsHandler{}), grpc.WithBalancerName("pick_first")) if err != nil { return nil, err diff --git a/agent/consul/grpc_resolver.go b/agent/grpc/resolver/resolver.go similarity index 84% rename from agent/consul/grpc_resolver.go rename to agent/grpc/resolver/resolver.go index 883bf4e24d..fa52af5d28 100644 --- a/agent/consul/grpc_resolver.go +++ b/agent/grpc/resolver/resolver.go @@ -1,6 +1,7 @@ -package consul +package grpc import ( + "context" "math/rand" "strings" "sync" @@ -8,19 +9,22 @@ import ( "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/router" - "github.com/hashicorp/serf/serf" "google.golang.org/grpc/resolver" ) -var registerLock sync.Mutex +//var registerLock sync.Mutex +// +//// registerResolverBuilder registers our custom grpc resolver with the given scheme. +//func registerResolverBuilder(datacenter string) *ServerResolverBuilder { +// registerLock.Lock() +// defer registerLock.Unlock() +// grpcResolverBuilder := NewServerResolverBuilder(datacenter) +// resolver.Register(grpcResolverBuilder) +// return grpcResolverBuilder +//} -// registerResolverBuilder registers our custom grpc resolver with the given scheme. -func registerResolverBuilder(scheme, datacenter string, shutdownCh <-chan struct{}) *ServerResolverBuilder { - registerLock.Lock() - defer registerLock.Unlock() - grpcResolverBuilder := NewServerResolverBuilder(scheme, datacenter, shutdownCh) - resolver.Register(grpcResolverBuilder) - return grpcResolverBuilder +type Nodes interface { + NumNodes() int } // ServerResolverBuilder tracks the current server list and keeps any @@ -32,24 +36,24 @@ type ServerResolverBuilder struct { datacenter string servers map[string]*metadata.Server resolvers map[resolver.ClientConn]*ServerResolver - shutdownCh <-chan struct{} + nodes Nodes lock sync.Mutex } -func NewServerResolverBuilder(scheme, datacenter string, shutdownCh <-chan struct{}) *ServerResolverBuilder { +func NewServerResolverBuilder(nodes Nodes, datacenter string) *ServerResolverBuilder { return &ServerResolverBuilder{ - scheme: scheme, datacenter: datacenter, + nodes: nodes, servers: make(map[string]*metadata.Server), resolvers: make(map[resolver.ClientConn]*ServerResolver), } } -// periodicServerRebalance periodically reshuffles the order of server addresses +// Run periodically reshuffles the order of server addresses // within the resolvers to ensure the load is balanced across servers. 
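+// Run blocks until ctx is cancelled, so it is expected to be started in its
+// own goroutine by the caller.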
-func (s *ServerResolverBuilder) periodicServerRebalance(serf *serf.Serf) { +func (s *ServerResolverBuilder) Run(ctx context.Context) { // Compute the rebalance timer based on the number of local servers and nodes. - rebalanceDuration := router.ComputeRebalanceTimer(s.serversInDC(s.datacenter), serf.NumNodes()) + rebalanceDuration := router.ComputeRebalanceTimer(s.serversInDC(s.datacenter), s.nodes.NumNodes()) timer := time.NewTimer(rebalanceDuration) for { @@ -58,9 +62,9 @@ func (s *ServerResolverBuilder) periodicServerRebalance(serf *serf.Serf) { s.rebalanceResolvers() // Re-compute the wait duration. - newTimerDuration := router.ComputeRebalanceTimer(s.serversInDC(s.datacenter), serf.NumNodes()) + newTimerDuration := router.ComputeRebalanceTimer(s.serversInDC(s.datacenter), s.nodes.NumNodes()) timer.Reset(newTimerDuration) - case <-s.shutdownCh: + case <-ctx.Done(): timer.Stop() return } @@ -115,7 +119,7 @@ func (s *ServerResolverBuilder) Servers() []*metadata.Server { // Build returns a new ServerResolver for the given ClientConn. The resolver // will keep the ClientConn's state updated based on updates from Serf. -func (s *ServerResolverBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) { +func (s *ServerResolverBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOption) (resolver.Resolver, error) { s.lock.Lock() defer s.lock.Unlock() @@ -142,7 +146,10 @@ func (s *ServerResolverBuilder) Build(target resolver.Target, cc resolver.Client return resolver, nil } -func (s *ServerResolverBuilder) Scheme() string { return s.scheme } +// scheme is the URL scheme used to dial the Consul Server rpc endpoint. +var scheme = "consul" + +func (s *ServerResolverBuilder) Scheme() string { return scheme } // AddServer updates the resolvers' states to include the new server's address. func (s *ServerResolverBuilder) AddServer(server *metadata.Server) { @@ -232,9 +239,9 @@ func (r *ServerResolver) updateAddrsLocked(addrs []resolver.Address) { r.lastAddrs = addrs } -func (s *ServerResolver) Close() { - s.closeCallback() +func (r *ServerResolver) Close() { + r.closeCallback() } // Unneeded since we only update the ClientConn when our server list changes. -func (*ServerResolver) ResolveNow(o resolver.ResolveNowOption) {} +func (*ServerResolver) ResolveNow(_ resolver.ResolveNowOption) {} From bad4d3ff7cfb3e09d7004bc835318049d91da666 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Wed, 9 Sep 2020 18:46:58 -0400 Subject: [PATCH 3/8] grpc: redeuce dependencies, unexport, and add godoc Rename GRPCClient to ClientConnPool. This type appears to be more of a conn pool than a client. The clients receive the connections from this pool. Reduce some dependencies by adjusting the interface baoundaries. Remove the need to create a second slice of Servers, just to pick one and throw the rest away. Unexport serverResolver, it is not used outside the package. Use a RWMutex for ServerResolverBuilder, some locking is read-only. Add more godoc. 
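A rough usage sketch (not part of this change) of how the renamed pieces fit
together, assuming nodes is any value with a NumNodes() int method (for
example the LAN serf pool), "dc1" is a placeholder datacenter, and a nil
TLSWrapper is passed for brevity:

    builder := resolver.NewServerResolverBuilder(resolver.Config{Datacenter: "dc1"}, nodes)
    resolver.RegisterWithGRPC(builder)

    // ServerResolverBuilder satisfies ServerLocator (ServerForAddr, Scheme).
    pool := grpc.NewClientConnPool(builder, nil)

    conn, err := pool.ClientConn("dc1")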
--- agent/consul/client_test.go | 2 +- agent/grpc/client.go | 82 ++++++++------- agent/grpc/resolver/resolver.go | 174 +++++++++++++++++++------------- agent/setup.go | 3 +- 4 files changed, 152 insertions(+), 109 deletions(-) diff --git a/agent/consul/client_test.go b/agent/consul/client_test.go index ea0250454f..6fe7266819 100644 --- a/agent/consul/client_test.go +++ b/agent/consul/client_test.go @@ -480,7 +480,7 @@ func newDefaultDeps(t *testing.T, c *Config) Deps { tls, err := tlsutil.NewConfigurator(c.ToTLSUtilConfig(), logger) require.NoError(t, err, "failed to create tls configuration") - r := router.NewRouter(logger, c.Datacenter, fmt.Sprintf("%s.%s", c.NodeName, c.Datacenter)) + r := router.NewRouter(logger, c.Datacenter, fmt.Sprintf("%s.%s", c.NodeName, c.Datacenter), nil) connPool := &pool.ConnPool{ Server: false, diff --git a/agent/grpc/client.go b/agent/grpc/client.go index d2f9f32b27..e65e95a13c 100644 --- a/agent/grpc/client.go +++ b/agent/grpc/client.go @@ -10,61 +10,71 @@ import ( "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/pool" - "github.com/hashicorp/consul/tlsutil" ) -type ServerProvider interface { - Servers() []*metadata.Server +// ClientConnPool creates and stores a connection for each datacenter. +type ClientConnPool struct { + dialer dialer + servers ServerLocator + conns map[string]*grpc.ClientConn + connsLock sync.Mutex } -type Client struct { - serverProvider ServerProvider - tlsConfigurator *tlsutil.Configurator - grpcConns map[string]*grpc.ClientConn - grpcConnLock sync.Mutex +type ServerLocator interface { + // ServerForAddr is used to look up server metadata from an address. + ServerForAddr(addr string) (*metadata.Server, error) + // Scheme returns the url scheme to use to dial the server. This is primarily + // needed for testing multiple agents in parallel, because gRPC requires the + // resolver to be registered globally. + Scheme() string } -func NewGRPCClient(serverProvider ServerProvider, tlsConfigurator *tlsutil.Configurator) *Client { - // Note we don't actually use the logger anywhere yet but I guess it was added - // for future compatibility... - return &Client{ - serverProvider: serverProvider, - tlsConfigurator: tlsConfigurator, - grpcConns: make(map[string]*grpc.ClientConn), +// TLSWrapper wraps a non-TLS connection and returns a connection with TLS +// enabled. +type TLSWrapper func(dc string, conn net.Conn) (net.Conn, error) + +type dialer func(context.Context, string) (net.Conn, error) + +func NewClientConnPool(servers ServerLocator, tls TLSWrapper) *ClientConnPool { + return &ClientConnPool{ + dialer: newDialer(servers, tls), + servers: servers, + conns: make(map[string]*grpc.ClientConn), } } -func (c *Client) GRPCConn(datacenter string) (*grpc.ClientConn, error) { - c.grpcConnLock.Lock() - defer c.grpcConnLock.Unlock() +// ClientConn returns a grpc.ClientConn for the datacenter. If there are no +// existing connections in the pool, a new one will be created, stored in the pool, +// then returned. +func (c *ClientConnPool) ClientConn(datacenter string) (*grpc.ClientConn, error) { + c.connsLock.Lock() + defer c.connsLock.Unlock() - // If there's an existing ClientConn for the given DC, return it. 
- if conn, ok := c.grpcConns[datacenter]; ok { + if conn, ok := c.conns[datacenter]; ok { return conn, nil } - dialer := newDialer(c.serverProvider, c.tlsConfigurator.OutgoingRPCWrapper()) - conn, err := grpc.Dial(fmt.Sprintf("%s:///server.%s", scheme, datacenter), + conn, err := grpc.Dial(fmt.Sprintf("%s:///server.%s", c.servers.Scheme(), datacenter), // use WithInsecure mode here because we handle the TLS wrapping in the // custom dialer based on logic around whether the server has TLS enabled. grpc.WithInsecure(), - grpc.WithContextDialer(dialer), + grpc.WithContextDialer(c.dialer), grpc.WithDisableRetry(), - // TODO: previously this handler was shared with the Handler. Is that necessary? + // TODO: previously this statsHandler was shared with the Handler. Is that necessary? grpc.WithStatsHandler(&statsHandler{}), + // nolint:staticcheck // there is no other supported alternative to WithBalancerName grpc.WithBalancerName("pick_first")) if err != nil { return nil, err } - c.grpcConns[datacenter] = conn - + c.conns[datacenter] = conn return conn, nil } // newDialer returns a gRPC dialer function that conditionally wraps the connection -// with TLS depending on the given useTLS value. -func newDialer(serverProvider ServerProvider, wrapper tlsutil.DCWrapper) func(context.Context, string) (net.Conn, error) { +// with TLS based on the Server.useTLS value. +func newDialer(servers ServerLocator, wrapper TLSWrapper) func(context.Context, string) (net.Conn, error) { return func(ctx context.Context, addr string) (net.Conn, error) { d := net.Dialer{} conn, err := d.DialContext(ctx, "tcp", addr) @@ -72,17 +82,10 @@ func newDialer(serverProvider ServerProvider, wrapper tlsutil.DCWrapper) func(co return nil, err } - // Check if TLS is enabled for the server. - var found bool - var server *metadata.Server - for _, s := range serverProvider.Servers() { - if s.Addr.String() == addr { - found = true - server = s - } - } - if !found { - return nil, fmt.Errorf("could not find Consul server for address %q", addr) + server, err := servers.ServerForAddr(addr) + if err != nil { + // TODO: should conn be closed in this case, as it is in other error cases? + return nil, err } if server.UseTLS { @@ -107,6 +110,7 @@ func newDialer(serverProvider ServerProvider, wrapper tlsutil.DCWrapper) func(co _, err = conn.Write([]byte{pool.RPCGRPC}) if err != nil { + // TODO: should conn be closed in this case, as it is in other error cases? return nil, err } diff --git a/agent/grpc/resolver/resolver.go b/agent/grpc/resolver/resolver.go index fa52af5d28..82e814ae09 100644 --- a/agent/grpc/resolver/resolver.go +++ b/agent/grpc/resolver/resolver.go @@ -1,7 +1,8 @@ -package grpc +package resolver import ( "context" + "fmt" "math/rand" "strings" "sync" @@ -12,17 +13,21 @@ import ( "google.golang.org/grpc/resolver" ) -//var registerLock sync.Mutex -// -//// registerResolverBuilder registers our custom grpc resolver with the given scheme. -//func registerResolverBuilder(datacenter string) *ServerResolverBuilder { -// registerLock.Lock() -// defer registerLock.Unlock() -// grpcResolverBuilder := NewServerResolverBuilder(datacenter) -// resolver.Register(grpcResolverBuilder) -// return grpcResolverBuilder -//} +var registerLock sync.Mutex +// RegisterWithGRPC registers the ServerResolverBuilder as a grpc/resolver. +// This function exists to synchronize registrations with a lock. +// grpc/resolver.Register expects all registration to happen at init and does +// not allow for concurrent registration. 
This function exists to support +// parallel testing. +func RegisterWithGRPC(b *ServerResolverBuilder) { + registerLock.Lock() + defer registerLock.Unlock() + resolver.Register(b) +} + +// Nodes provides a count of the number of nodes in the cluster. It is very +// likely implemented by serf to return the number of LAN members. type Nodes interface { NumNodes() int } @@ -30,27 +35,52 @@ type Nodes interface { // ServerResolverBuilder tracks the current server list and keeps any // ServerResolvers updated when changes occur. type ServerResolverBuilder struct { - // Allow overriding the scheme to support parallel tests, since - // the resolver builder is registered globally. - scheme string + // datacenter of the local agent. datacenter string - servers map[string]*metadata.Server - resolvers map[resolver.ClientConn]*ServerResolver - nodes Nodes - lock sync.Mutex + // scheme used to query the server. Defaults to consul. Used to support + // parallel testing because gRPC registers resolvers globally. + scheme string + // servers is an index of Servers by Server.ID + servers map[string]*metadata.Server + // resolvers is an index of connections to the serverResolver which manages + // addresses of servers for that connection. + resolvers map[resolver.ClientConn]*serverResolver + // nodes provides the number of nodes in the cluster. + nodes Nodes + // lock for servers and resolvers. + lock sync.RWMutex } -func NewServerResolverBuilder(nodes Nodes, datacenter string) *ServerResolverBuilder { +var _ resolver.Builder = (*ServerResolverBuilder)(nil) + +type Config struct { + // Datacenter of the local agent. + Datacenter string + // Scheme used to connect to the server. Defaults to consul. + Scheme string +} + +func NewServerResolverBuilder(cfg Config, nodes Nodes) *ServerResolverBuilder { + if cfg.Scheme == "" { + cfg.Scheme = "consul" + } return &ServerResolverBuilder{ - datacenter: datacenter, + scheme: cfg.Scheme, + datacenter: cfg.Datacenter, nodes: nodes, servers: make(map[string]*metadata.Server), - resolvers: make(map[resolver.ClientConn]*ServerResolver), + resolvers: make(map[resolver.ClientConn]*serverResolver), } } -// Run periodically reshuffles the order of server addresses -// within the resolvers to ensure the load is balanced across servers. +// Run periodically reshuffles the order of server addresses within the +// resolvers to ensure the load is balanced across servers. +// +// TODO: this looks very similar to agent/router.Manager.Start, which is the +// only other caller of ComputeRebalanceTimer. Are the values passed to these +// two functions different enough that we need separate goroutines to rebalance? +// or could we have a single thing handle the timers, and call both rebalance +// functions? func (s *ServerResolverBuilder) Run(ctx context.Context) { // Compute the rebalance timer based on the number of local servers and nodes. rebalanceDuration := router.ComputeRebalanceTimer(s.serversInDC(s.datacenter), s.nodes.NumNodes()) @@ -73,13 +103,13 @@ func (s *ServerResolverBuilder) Run(ctx context.Context) { // rebalanceResolvers shuffles the server list for resolvers in all datacenters. func (s *ServerResolverBuilder) rebalanceResolvers() { - s.lock.Lock() - defer s.lock.Unlock() + s.lock.RLock() + defer s.lock.RUnlock() for _, resolver := range s.resolvers { // Shuffle the list of addresses using the last list given to the resolver. 
resolver.addrLock.Lock() - addrs := resolver.lastAddrs + addrs := resolver.addrs rand.Shuffle(len(addrs), func(i, j int) { addrs[i], addrs[j] = addrs[j], addrs[i] }) @@ -91,8 +121,8 @@ func (s *ServerResolverBuilder) rebalanceResolvers() { // serversInDC returns the number of servers in the given datacenter. func (s *ServerResolverBuilder) serversInDC(dc string) int { - s.lock.Lock() - defer s.lock.Unlock() + s.lock.RLock() + defer s.lock.RUnlock() var serverCount int for _, server := range s.servers { @@ -104,52 +134,49 @@ func (s *ServerResolverBuilder) serversInDC(dc string) int { return serverCount } -// Servers returns metadata for all currently known servers. This is used -// by grpc.ClientConn through our custom dialer. -func (s *ServerResolverBuilder) Servers() []*metadata.Server { - s.lock.Lock() - defer s.lock.Unlock() +// ServerForAddr returns server metadata for a server with the specified address. +func (s *ServerResolverBuilder) ServerForAddr(addr string) (*metadata.Server, error) { + s.lock.RLock() + defer s.lock.RUnlock() - servers := make([]*metadata.Server, 0, len(s.servers)) for _, server := range s.servers { - servers = append(servers, server) + if server.Addr.String() == addr { + return server, nil + } } - return servers + return nil, fmt.Errorf("failed to find Consul server for address %q", addr) } -// Build returns a new ServerResolver for the given ClientConn. The resolver +// Build returns a new serverResolver for the given ClientConn. The resolver // will keep the ClientConn's state updated based on updates from Serf. func (s *ServerResolverBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOption) (resolver.Resolver, error) { s.lock.Lock() defer s.lock.Unlock() - // If there's already a resolver for this datacenter, return it. - datacenter := strings.TrimPrefix(target.Endpoint, "server.") + // If there's already a resolver for this connection, return it. + // TODO(streaming): how would this happen since we already cache connections in ClientConnPool? if resolver, ok := s.resolvers[cc]; ok { return resolver, nil } // Make a new resolver for the dc and add it to the list of active ones. - resolver := &ServerResolver{ + datacenter := strings.TrimPrefix(target.Endpoint, "server.") + resolver := &serverResolver{ datacenter: datacenter, clientConn: cc, + close: func() { + s.lock.Lock() + defer s.lock.Unlock() + delete(s.resolvers, cc) + }, } resolver.updateAddrs(s.getDCAddrs(datacenter)) - resolver.closeCallback = func() { - s.lock.Lock() - defer s.lock.Unlock() - delete(s.resolvers, cc) - } s.resolvers[cc] = resolver - return resolver, nil } -// scheme is the URL scheme used to dial the Consul Server rpc endpoint. -var scheme = "consul" - -func (s *ServerResolverBuilder) Scheme() string { return scheme } +func (s *ServerResolverBuilder) Scheme() string { return s.scheme } // AddServer updates the resolvers' states to include the new server's address. func (s *ServerResolverBuilder) AddServer(server *metadata.Server) { @@ -182,7 +209,7 @@ func (s *ServerResolverBuilder) RemoveServer(server *metadata.Server) { } // getDCAddrs returns a list of the server addresses for the given datacenter. -// This method assumes the lock is held. +// This method requires that lock is held for reads. 
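+// The returned slice is freshly allocated on every call, so callers may hand
+// it to a resolver (which may later shuffle it) without copying.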
func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address { var addrs []resolver.Address for _, server := range s.servers { @@ -199,28 +226,39 @@ func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address { return addrs } -// ServerResolver is a grpc Resolver that will keep a grpc.ClientConn up to date +// serverResolver is a grpc Resolver that will keep a grpc.ClientConn up to date // on the list of server addresses to use. -type ServerResolver struct { - datacenter string - clientConn resolver.ClientConn - closeCallback func() +type serverResolver struct { + // datacenter that can be reached by the clientConn. Used by ServerResolverBuilder + // to filter resolvers for those in a specific datacenter. + datacenter string - lastAddrs []resolver.Address - addrLock sync.Mutex + // clientConn that this resolver is providing addresses for. + clientConn resolver.ClientConn + + // close is used by ServerResolverBuilder to remove this resolver from the + // index of resolvers. It is called by grpc when the connection is closed. + close func() + + // addrs stores the list of addresses passed to updateAddrs, so that they + // can be rebalanced periodically by ServerResolverBuilder. + addrs []resolver.Address + addrLock sync.Mutex } -// updateAddrs updates this ServerResolver's ClientConn to use the given set of +var _ resolver.Resolver = (*serverResolver)(nil) + +// updateAddrs updates this serverResolver's ClientConn to use the given set of // addrs. -func (r *ServerResolver) updateAddrs(addrs []resolver.Address) { +func (r *serverResolver) updateAddrs(addrs []resolver.Address) { r.addrLock.Lock() defer r.addrLock.Unlock() r.updateAddrsLocked(addrs) } -// updateAddrsLocked updates this ServerResolver's ClientConn to use the given -// set of addrs. addrLock must be held by calleer. -func (r *ServerResolver) updateAddrsLocked(addrs []resolver.Address) { +// updateAddrsLocked updates this serverResolver's ClientConn to use the given +// set of addrs. addrLock must be held by caller. +func (r *serverResolver) updateAddrsLocked(addrs []resolver.Address) { // Only pass the first address initially, which will cause the // balancer to spin down the connection for its previous first address // if it is different. If we don't do this, it will keep using the old @@ -236,12 +274,12 @@ func (r *ServerResolver) updateAddrsLocked(addrs []resolver.Address) { // for failover. r.clientConn.UpdateState(resolver.State{Addresses: addrs}) - r.lastAddrs = addrs + r.addrs = addrs } -func (r *ServerResolver) Close() { - r.closeCallback() +func (r *serverResolver) Close() { + r.close() } -// Unneeded since we only update the ClientConn when our server list changes. -func (*ServerResolver) ResolveNow(_ resolver.ResolveNowOption) {} +// ResolveNow is not used +func (*serverResolver) ResolveNow(_ resolver.ResolveNowOption) {} diff --git a/agent/setup.go b/agent/setup.go index 18a0be0c38..d56419680c 100644 --- a/agent/setup.go +++ b/agent/setup.go @@ -82,7 +82,8 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer) (BaseDeps, error) d.Cache = cache.New(cfg.Cache) d.ConnPool = newConnPool(cfg, d.Logger, d.TLSConfigurator) - d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter)) + // TODO: set grpcServerTracker, requires serf to be setup before this. 
+ d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), nil) acConf := autoconf.Config{ DirectRPC: d.ConnPool, From 07b4507f1e185d513b72ccb2140adfe17ba8c0a0 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Fri, 11 Sep 2020 12:19:52 -0400 Subject: [PATCH 4/8] router: remove grpcServerTracker from managers It only needs to be refereced from the Router, because there is only 1 instance, and the Router can call AddServer/RemoveServer like it does on the Manager. --- agent/router/manager.go | 17 +---------------- agent/router/manager_internal_test.go | 6 +++--- agent/router/manager_test.go | 8 ++++---- agent/router/router.go | 4 +++- 4 files changed, 11 insertions(+), 24 deletions(-) diff --git a/agent/router/manager.go b/agent/router/manager.go index 0e64c5c904..2052eb02d7 100644 --- a/agent/router/manager.go +++ b/agent/router/manager.go @@ -98,10 +98,6 @@ type Manager struct { // client.ConnPool. connPoolPinger Pinger - // grpcServerTracker is used to balance grpc connections across servers, - // and has callbacks for adding or removing a server. - grpcServerTracker ServerTracker - // serverName has the name of the managers's server. This is used to // short-circuit pinging to itself. serverName string @@ -123,7 +119,6 @@ type Manager struct { func (m *Manager) AddServer(s *metadata.Server) { m.listLock.Lock() defer m.listLock.Unlock() - m.grpcServerTracker.AddServer(s) l := m.getServerList() // Check if this server is known @@ -256,11 +251,6 @@ func (m *Manager) CheckServers(fn func(srv *metadata.Server) bool) { _ = m.checkServers(fn) } -// Servers returns the current list of servers. -func (m *Manager) Servers() []*metadata.Server { - return m.getServerList().servers -} - // getServerList is a convenience method which hides the locking semantics // of atomic.Value from the caller. func (m *Manager) getServerList() serverList { @@ -277,19 +267,15 @@ func (m *Manager) saveServerList(l serverList) { } // New is the only way to safely create a new Manager struct. 
-func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfCluster, connPoolPinger Pinger, tracker ServerTracker, serverName string) (m *Manager) { +func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfCluster, connPoolPinger Pinger, serverName string) (m *Manager) { if logger == nil { logger = hclog.New(&hclog.LoggerOptions{}) } - if tracker == nil { - tracker = NoOpServerTracker{} - } m = new(Manager) m.logger = logger.Named(logging.Manager) m.clusterInfo = clusterInfo // can't pass *consul.Client: import cycle m.connPoolPinger = connPoolPinger // can't pass *consul.ConnPool: import cycle - m.grpcServerTracker = tracker m.rebalanceTimer = time.NewTimer(clientRPCMinReuseDuration) m.shutdownCh = shutdownCh m.serverName = serverName @@ -492,7 +478,6 @@ func (m *Manager) reconcileServerList(l *serverList) bool { func (m *Manager) RemoveServer(s *metadata.Server) { m.listLock.Lock() defer m.listLock.Unlock() - m.grpcServerTracker.RemoveServer(s) l := m.getServerList() // Remove the server if known diff --git a/agent/router/manager_internal_test.go b/agent/router/manager_internal_test.go index 63838f4972..76d9512168 100644 --- a/agent/router/manager_internal_test.go +++ b/agent/router/manager_internal_test.go @@ -54,14 +54,14 @@ func (s *fauxSerf) NumNodes() int { func testManager() (m *Manager) { logger := GetBufferedLogger() shutdownCh := make(chan struct{}) - m = New(logger, shutdownCh, &fauxSerf{numNodes: 16384}, &fauxConnPool{}, nil, "") + m = New(logger, shutdownCh, &fauxSerf{numNodes: 16384}, &fauxConnPool{}, "") return m } func testManagerFailProb(failPct float64) (m *Manager) { logger := GetBufferedLogger() shutdownCh := make(chan struct{}) - m = New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, nil, "") + m = New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, "") return m } @@ -300,7 +300,7 @@ func TestManagerInternal_refreshServerRebalanceTimer(t *testing.T) { shutdownCh := make(chan struct{}) for _, s := range clusters { - m := New(logger, shutdownCh, &fauxSerf{numNodes: s.numNodes}, &fauxConnPool{}, nil, "") + m := New(logger, shutdownCh, &fauxSerf{numNodes: s.numNodes}, &fauxConnPool{}, "") for i := 0; i < s.numServers; i++ { nodeName := fmt.Sprintf("s%02d", i) m.AddServer(&metadata.Server{Name: nodeName}) diff --git a/agent/router/manager_test.go b/agent/router/manager_test.go index 6c3a83816c..c7e1f299ca 100644 --- a/agent/router/manager_test.go +++ b/agent/router/manager_test.go @@ -57,21 +57,21 @@ func (s *fauxSerf) NumNodes() int { func testManager(t testing.TB) (m *router.Manager) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, nil, "") + m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, "") return m } func testManagerFailProb(t testing.TB, failPct float64) (m *router.Manager) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, nil, "") + m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, "") return m } func testManagerFailAddr(t testing.TB, failAddr net.Addr) (m *router.Manager) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failAddr: failAddr}, nil, "") + m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failAddr: failAddr}, "") return m } @@ -195,7 +195,7 @@ func 
TestServers_FindServer(t *testing.T) { func TestServers_New(t *testing.T) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m := router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, nil, "") + m := router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, "") if m == nil { t.Fatalf("Manager nil") } diff --git a/agent/router/router.go b/agent/router/router.go index 63f6e88ee1..9694e927db 100644 --- a/agent/router/router.go +++ b/agent/router/router.go @@ -259,7 +259,7 @@ func (r *Router) maybeInitializeManager(area *areaInfo, dc string) *Manager { } shutdownCh := make(chan struct{}) - manager := New(r.logger, shutdownCh, area.cluster, area.pinger, nil, r.serverName) + manager := New(r.logger, shutdownCh, area.cluster, area.pinger, r.serverName) info = &managerInfo{ manager: manager, shutdownCh: shutdownCh, @@ -286,6 +286,7 @@ func (r *Router) addServer(area *areaInfo, s *metadata.Server) error { } manager.AddServer(s) + r.grpcServerTracker.AddServer(s) return nil } @@ -321,6 +322,7 @@ func (r *Router) RemoveServer(areaID types.AreaID, s *metadata.Server) error { return nil } info.manager.RemoveServer(s) + r.grpcServerTracker.RemoveServer(s) // If this manager is empty then remove it so we don't accumulate cruft // and waste time during request routing. From 2273673500661cc706467c24b821c6d551d16a7e Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Fri, 11 Sep 2020 14:15:02 -0400 Subject: [PATCH 5/8] grpc: restore integration tests for grpc client conn pool Add a fake rpc Listener --- agent/grpc/client.go | 5 +- agent/grpc/client_test.go | 92 +++++++++++++++++++++++++++++++ agent/grpc/handler.go | 21 +++++-- agent/grpc/resolver/resolver.go | 3 +- agent/grpc/server_test.go | 97 +++++++++++++++++++++++++++++++++ 5 files changed, 209 insertions(+), 9 deletions(-) create mode 100644 agent/grpc/client_test.go diff --git a/agent/grpc/client.go b/agent/grpc/client.go index e65e95a13c..71f16c7c31 100644 --- a/agent/grpc/client.go +++ b/agent/grpc/client.go @@ -54,14 +54,15 @@ func (c *ClientConnPool) ClientConn(datacenter string) (*grpc.ClientConn, error) return conn, nil } - conn, err := grpc.Dial(fmt.Sprintf("%s:///server.%s", c.servers.Scheme(), datacenter), + conn, err := grpc.Dial( + fmt.Sprintf("%s:///server.%s", c.servers.Scheme(), datacenter), // use WithInsecure mode here because we handle the TLS wrapping in the // custom dialer based on logic around whether the server has TLS enabled. grpc.WithInsecure(), grpc.WithContextDialer(c.dialer), grpc.WithDisableRetry(), // TODO: previously this statsHandler was shared with the Handler. Is that necessary? 
- grpc.WithStatsHandler(&statsHandler{}), + grpc.WithStatsHandler(newStatsHandler()), // nolint:staticcheck // there is no other supported alternative to WithBalancerName grpc.WithBalancerName("pick_first")) if err != nil { diff --git a/agent/grpc/client_test.go b/agent/grpc/client_test.go new file mode 100644 index 0000000000..d8ea50dd8b --- /dev/null +++ b/agent/grpc/client_test.go @@ -0,0 +1,92 @@ +package grpc + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/hashicorp/consul/agent/grpc/internal/testservice" + "github.com/hashicorp/consul/agent/grpc/resolver" + "github.com/hashicorp/consul/agent/metadata" + "github.com/stretchr/testify/require" +) + +func TestNewDialer(t *testing.T) { + // TODO: conn is closed on errors + // TODO: with TLS enabled +} + +func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) { + count := 4 + cfg := resolver.Config{Datacenter: "dc1", Scheme: newScheme(t.Name())} + res := resolver.NewServerResolverBuilder(cfg, fakeNodes{num: count}) + resolver.RegisterWithGRPC(res) + pool := NewClientConnPool(res, nil) + + for i := 0; i < count; i++ { + name := fmt.Sprintf("server-%d", i) + srv := newTestServer(t, name, "dc1") + res.AddServer(srv.Metadata()) + t.Cleanup(srv.shutdown) + } + + conn, err := pool.ClientConn("dc1") + require.NoError(t, err) + client := testservice.NewSimpleClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + t.Cleanup(cancel) + + first, err := client.Something(ctx, &testservice.Req{}) + require.NoError(t, err) + res.RemoveServer(&metadata.Server{ID: first.ServerName, Datacenter: "dc1"}) + + resp, err := client.Something(ctx, &testservice.Req{}) + require.NoError(t, err) + require.NotEqual(t, resp.ServerName, first.ServerName) +} + +func newScheme(n string) string { + s := strings.Replace(n, "/", "", -1) + s = strings.Replace(s, "_", "", -1) + return strings.ToLower(s) +} + +type fakeNodes struct { + num int +} + +func (n fakeNodes) NumNodes() int { + return n.num +} + +func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) { + dcs := []string{"dc1", "dc2", "dc3"} + + cfg := resolver.Config{Datacenter: "dc1", Scheme: newScheme(t.Name())} + res := resolver.NewServerResolverBuilder(cfg, fakeNodes{num: 1}) + resolver.RegisterWithGRPC(res) + pool := NewClientConnPool(res, nil) + + for _, dc := range dcs { + name := "server-0-" + dc + srv := newTestServer(t, name, dc) + res.AddServer(srv.Metadata()) + t.Cleanup(srv.shutdown) + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + t.Cleanup(cancel) + + for _, dc := range dcs { + conn, err := pool.ClientConn(dc) + require.NoError(t, err) + client := testservice.NewSimpleClient(conn) + + resp, err := client.Something(ctx, &testservice.Req{}) + require.NoError(t, err) + require.Equal(t, resp.Datacenter, dc) + } +} diff --git a/agent/grpc/handler.go b/agent/grpc/handler.go index c3af7f38c4..c43c1ba1e2 100644 --- a/agent/grpc/handler.go +++ b/agent/grpc/handler.go @@ -21,10 +21,8 @@ func NewHandler(addr net.Addr) *Handler { // TODO(streaming): add gRPC services to srv here - return &Handler{ - srv: srv, - listener: &chanListener{addr: addr, conns: make(chan net.Conn)}, - } + lis := &chanListener{addr: addr, conns: make(chan net.Conn)} + return &Handler{srv: srv, listener: lis} } // Handler implements a handler for the rpc server listener, and the @@ -57,15 +55,26 @@ type chanListener struct { // Accept blocks until a connection is received from Handle, and then returns the 
// connection. Accept implements part of the net.Listener interface for grpc.Server. func (l *chanListener) Accept() (net.Conn, error) { - return <-l.conns, nil + select { + case c, ok := <-l.conns: + if !ok { + return nil, &net.OpError{ + Op: "accept", + Net: l.addr.Network(), + Addr: l.addr, + Err: fmt.Errorf("listener closed"), + } + } + return c, nil + } } func (l *chanListener) Addr() net.Addr { return l.addr } -// Close does nothing. The connections are managed by the caller. func (l *chanListener) Close() error { + close(l.conns) return nil } diff --git a/agent/grpc/resolver/resolver.go b/agent/grpc/resolver/resolver.go index 82e814ae09..3bf66b74c1 100644 --- a/agent/grpc/resolver/resolver.go +++ b/agent/grpc/resolver/resolver.go @@ -40,7 +40,8 @@ type ServerResolverBuilder struct { // scheme used to query the server. Defaults to consul. Used to support // parallel testing because gRPC registers resolvers globally. scheme string - // servers is an index of Servers by Server.ID + // servers is an index of Servers by Server.ID. The map contains server IDs + // for all datacenters, so it assumes the ID is globally unique. servers map[string]*metadata.Server // resolvers is an index of connections to the serverResolver which manages // addresses of servers for that connection. diff --git a/agent/grpc/server_test.go b/agent/grpc/server_test.go index b7843ff011..b4cb9c7834 100644 --- a/agent/grpc/server_test.go +++ b/agent/grpc/server_test.go @@ -2,11 +2,66 @@ package grpc import ( "context" + "fmt" + "io" + "net" + "testing" "time" "github.com/hashicorp/consul/agent/grpc/internal/testservice" + "github.com/hashicorp/consul/agent/metadata" + "github.com/hashicorp/consul/agent/pool" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" ) +type testServer struct { + addr net.Addr + name string + dc string + shutdown func() +} + +func (s testServer) Metadata() *metadata.Server { + return &metadata.Server{ID: s.name, Datacenter: s.dc, Addr: s.addr} +} + +func newTestServer(t *testing.T, name string, dc string) testServer { + addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} + handler := NewHandler(addr) + + testservice.RegisterSimpleServer(handler.srv, &simple{name: name, dc: dc}) + + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + rpc := &fakeRPCListener{t: t, handler: handler} + + g := errgroup.Group{} + g.Go(func() error { + return rpc.listen(lis) + }) + g.Go(func() error { + return handler.Run() + }) + return testServer{ + addr: lis.Addr(), + name: name, + dc: dc, + shutdown: func() { + if err := lis.Close(); err != nil { + t.Logf("listener closed with error: %v", err) + } + if err := handler.Shutdown(); err != nil { + t.Logf("grpc server shutdown: %v", err) + } + if err := g.Wait(); err != nil { + t.Logf("grpc server error: %v", err) + } + }, + } +} + type simple struct { name string dc string @@ -26,3 +81,45 @@ func (s *simple) Flow(_ *testservice.Req, flow testservice.Simple_FlowServer) er func (s *simple) Something(_ context.Context, _ *testservice.Req) (*testservice.Resp, error) { return &testservice.Resp{ServerName: s.name, Datacenter: s.dc}, nil } + +// fakeRPCListener mimics agent/consul.Server.listen to handle the RPCType byte. +// In the future we should be able to refactor Server and extract this RPC +// handling logic so that we don't need to use a fake. +// For now, since this logic is in agent/consul, we can't easily use Server.listen +// so we fake it. 
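For reference, the client half of this type-byte handshake is newDialer in agent/grpc/client.go: dial TCP, write the pool.RPCGRPC byte, then let gRPC take over the connection. A minimal sketch of just that handshake (dialGRPCConn is an assumed name, not part of the patch; imports for context, net, and agent/pool are implied, error handling trimmed):

func dialGRPCConn(ctx context.Context, addr string) (net.Conn, error) {
	d := net.Dialer{}
	conn, err := d.DialContext(ctx, "tcp", addr)
	if err != nil {
		return nil, err
	}
	// Announce the connection type so the multiplexed RPC port routes this
	// conn to the gRPC handler; fakeRPCListener below plays that server role
	// in these tests.
	if _, err := conn.Write([]byte{byte(pool.RPCGRPC)}); err != nil {
		conn.Close()
		return nil, err
	}
	return conn, nil
}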
+type fakeRPCListener struct { + t *testing.T + handler *Handler +} + +func (f *fakeRPCListener) listen(listener net.Listener) error { + for { + conn, err := listener.Accept() + if err != nil { + return err + } + + go f.handleConn(conn) + } +} + +func (f *fakeRPCListener) handleConn(conn net.Conn) { + buf := make([]byte, 1) + + if _, err := conn.Read(buf); err != nil { + if err != io.EOF { + fmt.Println("ERROR", err.Error()) + } + conn.Close() + return + } + typ := pool.RPCType(buf[0]) + + if typ == pool.RPCGRPC { + f.handler.Handle(conn) + return + } + + fmt.Println("ERROR: unexpected byte", typ) + conn.Close() +} From 229479335724ec3dcb54cca25380d7f76b0eda21 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Mon, 14 Sep 2020 16:16:44 -0400 Subject: [PATCH 6/8] agent/grpc: use router.Manager to handle the rebalance The router.Manager is already rebalancing servers for other connection pools, so it can call into our resolver to do the same. This change allows us to remove the serf dependency from resolverBuilder, and remove Datacenter from the config. Also revert the change to refreshServerRebalanceTimer --- agent/grpc/client_test.go | 57 ++++++++++++--- agent/grpc/resolver/resolver.go | 100 ++++++-------------------- agent/router/grpc.go | 18 +++-- agent/router/manager.go | 21 +++--- agent/router/manager_internal_test.go | 8 ++- agent/router/manager_test.go | 10 +-- agent/router/router.go | 3 +- agent/setup.go | 7 +- 8 files changed, 113 insertions(+), 111 deletions(-) diff --git a/agent/grpc/client_test.go b/agent/grpc/client_test.go index d8ea50dd8b..0cc8cb20ff 100644 --- a/agent/grpc/client_test.go +++ b/agent/grpc/client_test.go @@ -10,6 +10,7 @@ import ( "github.com/hashicorp/consul/agent/grpc/internal/testservice" "github.com/hashicorp/consul/agent/grpc/resolver" "github.com/hashicorp/consul/agent/metadata" + "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/stretchr/testify/require" ) @@ -20,8 +21,8 @@ func TestNewDialer(t *testing.T) { func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) { count := 4 - cfg := resolver.Config{Datacenter: "dc1", Scheme: newScheme(t.Name())} - res := resolver.NewServerResolverBuilder(cfg, fakeNodes{num: count}) + cfg := resolver.Config{Scheme: newScheme(t.Name())} + res := resolver.NewServerResolverBuilder(cfg) resolver.RegisterWithGRPC(res) pool := NewClientConnPool(res, nil) @@ -41,6 +42,7 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) { first, err := client.Something(ctx, &testservice.Req{}) require.NoError(t, err) + res.RemoveServer(&metadata.Server{ID: first.ServerName, Datacenter: "dc1"}) resp, err := client.Something(ctx, &testservice.Req{}) @@ -54,19 +56,56 @@ func newScheme(n string) string { return strings.ToLower(s) } -type fakeNodes struct { - num int -} +func TestClientConnPool_IntegrationWithGRPCResolver_Rebalance(t *testing.T) { + count := 4 + cfg := resolver.Config{Scheme: newScheme(t.Name())} + res := resolver.NewServerResolverBuilder(cfg) + resolver.RegisterWithGRPC(res) + pool := NewClientConnPool(res, nil) -func (n fakeNodes) NumNodes() int { - return n.num + for i := 0; i < count; i++ { + name := fmt.Sprintf("server-%d", i) + srv := newTestServer(t, name, "dc1") + res.AddServer(srv.Metadata()) + t.Cleanup(srv.shutdown) + } + + conn, err := pool.ClientConn("dc1") + require.NoError(t, err) + client := testservice.NewSimpleClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + t.Cleanup(cancel) + + first, err := client.Something(ctx, 
&testservice.Req{}) + require.NoError(t, err) + + t.Run("rebalance a different DC, does nothing", func(t *testing.T) { + res.NewRebalancer("dc-other")() + + resp, err := client.Something(ctx, &testservice.Req{}) + require.NoError(t, err) + require.Equal(t, resp.ServerName, first.ServerName) + }) + + t.Run("rebalance the dc", func(t *testing.T) { + // Rebalance is random, but if we repeat it a few times it should give us a + // new server. + retry.RunWith(fastRetry, t, func(r *retry.R) { + res.NewRebalancer("dc1")() + + resp, err := client.Something(ctx, &testservice.Req{}) + require.NoError(r, err) + require.NotEqual(r, resp.ServerName, first.ServerName) + }) + }) } func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) { dcs := []string{"dc1", "dc2", "dc3"} - cfg := resolver.Config{Datacenter: "dc1", Scheme: newScheme(t.Name())} - res := resolver.NewServerResolverBuilder(cfg, fakeNodes{num: 1}) + cfg := resolver.Config{Scheme: newScheme(t.Name())} + res := resolver.NewServerResolverBuilder(cfg) resolver.RegisterWithGRPC(res) pool := NewClientConnPool(res, nil) diff --git a/agent/grpc/resolver/resolver.go b/agent/grpc/resolver/resolver.go index 3bf66b74c1..d35def0d21 100644 --- a/agent/grpc/resolver/resolver.go +++ b/agent/grpc/resolver/resolver.go @@ -1,15 +1,12 @@ package resolver import ( - "context" "fmt" "math/rand" "strings" "sync" - "time" "github.com/hashicorp/consul/agent/metadata" - "github.com/hashicorp/consul/agent/router" "google.golang.org/grpc/resolver" ) @@ -26,17 +23,9 @@ func RegisterWithGRPC(b *ServerResolverBuilder) { resolver.Register(b) } -// Nodes provides a count of the number of nodes in the cluster. It is very -// likely implemented by serf to return the number of LAN members. -type Nodes interface { - NumNodes() int -} - // ServerResolverBuilder tracks the current server list and keeps any // ServerResolvers updated when changes occur. type ServerResolverBuilder struct { - // datacenter of the local agent. - datacenter string // scheme used to query the server. Defaults to consul. Used to support // parallel testing because gRPC registers resolvers globally. scheme string @@ -46,8 +35,6 @@ type ServerResolverBuilder struct { // resolvers is an index of connections to the serverResolver which manages // addresses of servers for that connection. resolvers map[resolver.ClientConn]*serverResolver - // nodes provides the number of nodes in the cluster. - nodes Nodes // lock for servers and resolvers. lock sync.RWMutex } @@ -55,86 +42,45 @@ type ServerResolverBuilder struct { var _ resolver.Builder = (*ServerResolverBuilder)(nil) type Config struct { - // Datacenter of the local agent. - Datacenter string // Scheme used to connect to the server. Defaults to consul. Scheme string } -func NewServerResolverBuilder(cfg Config, nodes Nodes) *ServerResolverBuilder { +func NewServerResolverBuilder(cfg Config) *ServerResolverBuilder { if cfg.Scheme == "" { cfg.Scheme = "consul" } return &ServerResolverBuilder{ - scheme: cfg.Scheme, - datacenter: cfg.Datacenter, - nodes: nodes, - servers: make(map[string]*metadata.Server), - resolvers: make(map[resolver.ClientConn]*serverResolver), + scheme: cfg.Scheme, + servers: make(map[string]*metadata.Server), + resolvers: make(map[resolver.ClientConn]*serverResolver), } } -// Run periodically reshuffles the order of server addresses within the -// resolvers to ensure the load is balanced across servers. 
-// -// TODO: this looks very similar to agent/router.Manager.Start, which is the -// only other caller of ComputeRebalanceTimer. Are the values passed to these -// two functions different enough that we need separate goroutines to rebalance? -// or could we have a single thing handle the timers, and call both rebalance -// functions? -func (s *ServerResolverBuilder) Run(ctx context.Context) { - // Compute the rebalance timer based on the number of local servers and nodes. - rebalanceDuration := router.ComputeRebalanceTimer(s.serversInDC(s.datacenter), s.nodes.NumNodes()) - timer := time.NewTimer(rebalanceDuration) +// Rebalance shuffles the server list for resolvers in all datacenters. +func (s *ServerResolverBuilder) NewRebalancer(dc string) func() { + return func() { + s.lock.RLock() + defer s.lock.RUnlock() - for { - select { - case <-timer.C: - s.rebalanceResolvers() - - // Re-compute the wait duration. - newTimerDuration := router.ComputeRebalanceTimer(s.serversInDC(s.datacenter), s.nodes.NumNodes()) - timer.Reset(newTimerDuration) - case <-ctx.Done(): - timer.Stop() - return + for _, resolver := range s.resolvers { + if resolver.datacenter != dc { + continue + } + // Shuffle the list of addresses using the last list given to the resolver. + resolver.addrLock.Lock() + addrs := resolver.addrs + // TODO: seed this rand, so it is a little more random-like + rand.Shuffle(len(addrs), func(i, j int) { + addrs[i], addrs[j] = addrs[j], addrs[i] + }) + // Pass the shuffled list to the resolver. + resolver.updateAddrsLocked(addrs) + resolver.addrLock.Unlock() } } } -// rebalanceResolvers shuffles the server list for resolvers in all datacenters. -func (s *ServerResolverBuilder) rebalanceResolvers() { - s.lock.RLock() - defer s.lock.RUnlock() - - for _, resolver := range s.resolvers { - // Shuffle the list of addresses using the last list given to the resolver. - resolver.addrLock.Lock() - addrs := resolver.addrs - rand.Shuffle(len(addrs), func(i, j int) { - addrs[i], addrs[j] = addrs[j], addrs[i] - }) - // Pass the shuffled list to the resolver. - resolver.updateAddrsLocked(addrs) - resolver.addrLock.Unlock() - } -} - -// serversInDC returns the number of servers in the given datacenter. -func (s *ServerResolverBuilder) serversInDC(dc string) int { - s.lock.RLock() - defer s.lock.RUnlock() - - var serverCount int - for _, server := range s.servers { - if server.Datacenter == dc { - serverCount++ - } - } - - return serverCount -} - // ServerForAddr returns server metadata for a server with the specified address. func (s *ServerResolverBuilder) ServerForAddr(addr string) (*metadata.Server, error) { s.lock.RLock() diff --git a/agent/router/grpc.go b/agent/router/grpc.go index 0a50992811..c4fe96d25f 100644 --- a/agent/router/grpc.go +++ b/agent/router/grpc.go @@ -2,19 +2,29 @@ package router import "github.com/hashicorp/consul/agent/metadata" -// ServerTracker is a wrapper around consul.ServerResolverBuilder to prevent a -// cyclic import dependency. +// ServerTracker is called when Router is notified of a server being added or +// removed. type ServerTracker interface { + NewRebalancer(dc string) func() AddServer(*metadata.Server) RemoveServer(*metadata.Server) } +// Rebalancer is called periodically to re-order the servers so that the load on the +// servers is evenly balanced. +type Rebalancer func() + // NoOpServerTracker is a ServerTracker that does nothing. Used when gRPC is not // enabled. 
type NoOpServerTracker struct{} -// AddServer implements ServerTracker +// Rebalance does nothing +func (NoOpServerTracker) NewRebalancer(string) func() { + return func() {} +} + +// AddServer does nothing func (NoOpServerTracker) AddServer(*metadata.Server) {} -// RemoveServer implements ServerTracker +// RemoveServer does nothing func (NoOpServerTracker) RemoveServer(*metadata.Server) {} diff --git a/agent/router/manager.go b/agent/router/manager.go index 2052eb02d7..4aaab97597 100644 --- a/agent/router/manager.go +++ b/agent/router/manager.go @@ -98,6 +98,8 @@ type Manager struct { // client.ConnPool. connPoolPinger Pinger + rebalancer Rebalancer + // serverName has the name of the managers's server. This is used to // short-circuit pinging to itself. serverName string @@ -267,7 +269,7 @@ func (m *Manager) saveServerList(l serverList) { } // New is the only way to safely create a new Manager struct. -func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfCluster, connPoolPinger Pinger, serverName string) (m *Manager) { +func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfCluster, connPoolPinger Pinger, serverName string, rb Rebalancer) (m *Manager) { if logger == nil { logger = hclog.New(&hclog.LoggerOptions{}) } @@ -278,6 +280,7 @@ func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfC m.connPoolPinger = connPoolPinger // can't pass *consul.ConnPool: import cycle m.rebalanceTimer = time.NewTimer(clientRPCMinReuseDuration) m.shutdownCh = shutdownCh + m.rebalancer = rb m.serverName = serverName atomic.StoreInt32(&m.offline, 1) @@ -498,22 +501,17 @@ func (m *Manager) RemoveServer(s *metadata.Server) { func (m *Manager) refreshServerRebalanceTimer() time.Duration { l := m.getServerList() numServers := len(l.servers) - connRebalanceTimeout := ComputeRebalanceTimer(numServers, m.clusterInfo.NumNodes()) - - m.rebalanceTimer.Reset(connRebalanceTimeout) - return connRebalanceTimeout -} - -// ComputeRebalanceTimer returns a time to wait before rebalancing connections given -// a number of servers and LAN nodes. -func ComputeRebalanceTimer(numServers, numLANMembers int) time.Duration { // Limit this connection's life based on the size (and health) of the // cluster. Never rebalance a connection more frequently than // connReuseLowWatermarkDuration, and make sure we never exceed // clusterWideRebalanceConnsPerSec operations/s across numLANMembers. clusterWideRebalanceConnsPerSec := float64(numServers * newRebalanceConnsPerSecPerServer) connReuseLowWatermarkDuration := clientRPCMinReuseDuration + lib.RandomStagger(clientRPCMinReuseDuration/clientRPCJitterFraction) - return lib.RateScaledInterval(clusterWideRebalanceConnsPerSec, connReuseLowWatermarkDuration, numLANMembers) + numLANMembers := m.clusterInfo.NumNodes() + connRebalanceTimeout := lib.RateScaledInterval(clusterWideRebalanceConnsPerSec, connReuseLowWatermarkDuration, numLANMembers) + + m.rebalanceTimer.Reset(connRebalanceTimeout) + return connRebalanceTimeout } // ResetRebalanceTimer resets the rebalance timer. 
This method exists for @@ -534,6 +532,7 @@ func (m *Manager) Start() { for { select { case <-m.rebalanceTimer.C: + m.rebalancer() m.RebalanceServers() m.refreshServerRebalanceTimer() diff --git a/agent/router/manager_internal_test.go b/agent/router/manager_internal_test.go index 76d9512168..05807e2070 100644 --- a/agent/router/manager_internal_test.go +++ b/agent/router/manager_internal_test.go @@ -54,14 +54,16 @@ func (s *fauxSerf) NumNodes() int { func testManager() (m *Manager) { logger := GetBufferedLogger() shutdownCh := make(chan struct{}) - m = New(logger, shutdownCh, &fauxSerf{numNodes: 16384}, &fauxConnPool{}, "") + m = New(logger, shutdownCh, &fauxSerf{numNodes: 16384}, &fauxConnPool{}, "", noopRebalancer) return m } +func noopRebalancer() {} + func testManagerFailProb(failPct float64) (m *Manager) { logger := GetBufferedLogger() shutdownCh := make(chan struct{}) - m = New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, "") + m = New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, "", noopRebalancer) return m } @@ -300,7 +302,7 @@ func TestManagerInternal_refreshServerRebalanceTimer(t *testing.T) { shutdownCh := make(chan struct{}) for _, s := range clusters { - m := New(logger, shutdownCh, &fauxSerf{numNodes: s.numNodes}, &fauxConnPool{}, "") + m := New(logger, shutdownCh, &fauxSerf{numNodes: s.numNodes}, &fauxConnPool{}, "", noopRebalancer) for i := 0; i < s.numServers; i++ { nodeName := fmt.Sprintf("s%02d", i) m.AddServer(&metadata.Server{Name: nodeName}) diff --git a/agent/router/manager_test.go b/agent/router/manager_test.go index c7e1f299ca..dc3628f1bd 100644 --- a/agent/router/manager_test.go +++ b/agent/router/manager_test.go @@ -57,21 +57,23 @@ func (s *fauxSerf) NumNodes() int { func testManager(t testing.TB) (m *router.Manager) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, "") + m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, "", noopRebalancer) return m } +func noopRebalancer() {} + func testManagerFailProb(t testing.TB, failPct float64) (m *router.Manager) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, "") + m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, "", noopRebalancer) return m } func testManagerFailAddr(t testing.TB, failAddr net.Addr) (m *router.Manager) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failAddr: failAddr}, "") + m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failAddr: failAddr}, "", noopRebalancer) return m } @@ -195,7 +197,7 @@ func TestServers_FindServer(t *testing.T) { func TestServers_New(t *testing.T) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m := router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, "") + m := router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, "", noopRebalancer) if m == nil { t.Fatalf("Manager nil") } diff --git a/agent/router/router.go b/agent/router/router.go index 9694e927db..8244745c3b 100644 --- a/agent/router/router.go +++ b/agent/router/router.go @@ -259,7 +259,8 @@ func (r *Router) maybeInitializeManager(area *areaInfo, dc string) *Manager { } shutdownCh := make(chan struct{}) - manager := New(r.logger, shutdownCh, area.cluster, area.pinger, r.serverName) + rb := r.grpcServerTracker.NewRebalancer(dc) + 
manager := New(r.logger, shutdownCh, area.cluster, area.pinger, r.serverName, rb) info = &managerInfo{ manager: manager, shutdownCh: shutdownCh, diff --git a/agent/setup.go b/agent/setup.go index d56419680c..454bfa510d 100644 --- a/agent/setup.go +++ b/agent/setup.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/consul" + "github.com/hashicorp/consul/agent/grpc/resolver" "github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/router" "github.com/hashicorp/consul/agent/token" @@ -82,8 +83,10 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer) (BaseDeps, error) d.Cache = cache.New(cfg.Cache) d.ConnPool = newConnPool(cfg, d.Logger, d.TLSConfigurator) - // TODO: set grpcServerTracker, requires serf to be setup before this. - d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), nil) + // TODO(streaming): setConfig.Scheme name for tests + builder := resolver.NewServerResolverBuilder(resolver.Config{}) + resolver.RegisterWithGRPC(builder) + d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), builder) acConf := autoconf.Config{ DirectRPC: d.ConnPool, From e6ffd987a3bd359fdb3246a90027706543489039 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Tue, 15 Sep 2020 13:51:25 -0400 Subject: [PATCH 7/8] agent/grpc: seed the rand for shuffling servers --- agent/grpc/resolver/resolver.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/agent/grpc/resolver/resolver.go b/agent/grpc/resolver/resolver.go index d35def0d21..b34aad72f4 100644 --- a/agent/grpc/resolver/resolver.go +++ b/agent/grpc/resolver/resolver.go @@ -5,6 +5,7 @@ import ( "math/rand" "strings" "sync" + "time" "github.com/hashicorp/consul/agent/metadata" "google.golang.org/grpc/resolver" @@ -59,6 +60,7 @@ func NewServerResolverBuilder(cfg Config) *ServerResolverBuilder { // Rebalance shuffles the server list for resolvers in all datacenters. func (s *ServerResolverBuilder) NewRebalancer(dc string) func() { + shuffler := rand.New(rand.NewSource(time.Now().UnixNano())) return func() { s.lock.RLock() defer s.lock.RUnlock() @@ -70,8 +72,7 @@ func (s *ServerResolverBuilder) NewRebalancer(dc string) func() { // Shuffle the list of addresses using the last list given to the resolver. resolver.addrLock.Lock() addrs := resolver.addrs - // TODO: seed this rand, so it is a little more random-like - rand.Shuffle(len(addrs), func(i, j int) { + shuffler.Shuffle(len(addrs), func(i, j int) { addrs[i], addrs[j] = addrs[j], addrs[i] }) // Pass the shuffled list to the resolver. From f14145e6d94541a46966f57e913d874ba80d0ef5 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Tue, 15 Sep 2020 14:11:48 -0400 Subject: [PATCH 8/8] agent/grpc: always close the conn when dialing fails. --- agent/grpc/client.go | 5 +++-- agent/grpc/client_test.go | 32 +++++++++++++++++++++++++++++--- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/agent/grpc/client.go b/agent/grpc/client.go index 71f16c7c31..783cbae36e 100644 --- a/agent/grpc/client.go +++ b/agent/grpc/client.go @@ -85,12 +85,13 @@ func newDialer(servers ServerLocator, wrapper TLSWrapper) func(context.Context, server, err := servers.ServerForAddr(addr) if err != nil { - // TODO: should conn be closed in this case, as it is in other error cases? 
+ conn.Close() return nil, err } if server.UseTLS { if wrapper == nil { + conn.Close() return nil, fmt.Errorf("TLS enabled but got nil TLS wrapper") } @@ -111,7 +112,7 @@ func newDialer(servers ServerLocator, wrapper TLSWrapper) func(context.Context, _, err = conn.Write([]byte{pool.RPCGRPC}) if err != nil { - // TODO: should conn be closed in this case, as it is in other error cases? + conn.Close() return nil, err } diff --git a/agent/grpc/client_test.go b/agent/grpc/client_test.go index 0cc8cb20ff..400e0a815d 100644 --- a/agent/grpc/client_test.go +++ b/agent/grpc/client_test.go @@ -3,6 +3,7 @@ package grpc import ( "context" "fmt" + "net" "strings" "testing" "time" @@ -14,11 +15,36 @@ import ( "github.com/stretchr/testify/require" ) -func TestNewDialer(t *testing.T) { - // TODO: conn is closed on errors - // TODO: with TLS enabled +func TestNewDialer_WithTLSWrapper(t *testing.T) { + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(logError(t, lis.Close)) + + builder := resolver.NewServerResolverBuilder(resolver.Config{}) + builder.AddServer(&metadata.Server{ + Name: "server-1", + ID: "ID1", + Datacenter: "dc1", + Addr: lis.Addr(), + UseTLS: true, + }) + + var called bool + wrapper := func(_ string, conn net.Conn) (net.Conn, error) { + called = true + return conn, nil + } + dial := newDialer(builder, wrapper) + ctx := context.Background() + conn, err := dial(ctx, lis.Addr().String()) + require.NoError(t, err) + require.NoError(t, conn.Close()) + require.True(t, called, "expected TLSWrapper to be called") } +// TODO: integration test TestNewDialer with TLS and rcp server, when the rpc +// exists as an isolated component. + func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) { count := 4 cfg := resolver.Config{Scheme: newScheme(t.Name())}
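// The production wiring in agent/setup.go follows the same shape as these
// tests; a minimal sketch, assuming a server address addr learned from serf
// and leaving the TLS wrapper nil so connections stay plaintext apart from
// the RPC type byte:
//
//	builder := resolver.NewServerResolverBuilder(resolver.Config{})
//	resolver.RegisterWithGRPC(builder)
//	pool := NewClientConnPool(builder, nil)
//	builder.AddServer(&metadata.Server{ID: "server-1", Datacenter: "dc1", Addr: addr})
//	conn, err := pool.ClientConn("dc1")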