diff --git a/agent/consul/grpc_client.go b/agent/consul/grpc_client.go new file mode 100644 index 0000000000..6e6d3df115 --- /dev/null +++ b/agent/consul/grpc_client.go @@ -0,0 +1,117 @@ +package consul + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/hashicorp/go-hclog" + "google.golang.org/grpc" + + "github.com/hashicorp/consul/agent/metadata" + "github.com/hashicorp/consul/agent/pool" + "github.com/hashicorp/consul/tlsutil" +) + +type ServerProvider interface { + Servers() []*metadata.Server +} + +type GRPCClient struct { + scheme string + serverProvider ServerProvider + tlsConfigurator *tlsutil.Configurator + grpcConns map[string]*grpc.ClientConn + grpcConnLock sync.Mutex +} + +func NewGRPCClient(logger hclog.Logger, serverProvider ServerProvider, tlsConfigurator *tlsutil.Configurator, scheme string) *GRPCClient { + // Note we don't actually use the logger anywhere yet but I guess it was added + // for future compatibility... + return &GRPCClient{ + scheme: scheme, + serverProvider: serverProvider, + tlsConfigurator: tlsConfigurator, + grpcConns: make(map[string]*grpc.ClientConn), + } +} + +func (c *GRPCClient) GRPCConn(datacenter string) (*grpc.ClientConn, error) { + c.grpcConnLock.Lock() + defer c.grpcConnLock.Unlock() + + // If there's an existing ClientConn for the given DC, return it. + if conn, ok := c.grpcConns[datacenter]; ok { + return conn, nil + } + + dialer := newDialer(c.serverProvider, c.tlsConfigurator.OutgoingRPCWrapper()) + conn, err := grpc.Dial(fmt.Sprintf("%s:///server.%s", c.scheme, datacenter), + // use WithInsecure mode here because we handle the TLS wrapping in the + // custom dialer based on logic around whether the server has TLS enabled. 
+ grpc.WithInsecure(), + grpc.WithContextDialer(dialer), + grpc.WithDisableRetry(), + grpc.WithStatsHandler(grpcStatsHandler), + grpc.WithBalancerName("pick_first")) + if err != nil { + return nil, err + } + + c.grpcConns[datacenter] = conn + + return conn, nil +} + +// newDialer returns a gRPC dialer function that conditionally wraps the connection +// with TLS depending on the given useTLS value. +func newDialer(serverProvider ServerProvider, wrapper tlsutil.DCWrapper) func(context.Context, string) (net.Conn, error) { + return func(ctx context.Context, addr string) (net.Conn, error) { + d := net.Dialer{} + conn, err := d.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + + // Check if TLS is enabled for the server. + var found bool + var server *metadata.Server + for _, s := range serverProvider.Servers() { + if s.Addr.String() == addr { + found = true + server = s + } + } + if !found { + return nil, fmt.Errorf("could not find Consul server for address %q", addr) + } + + if server.UseTLS { + if wrapper == nil { + return nil, fmt.Errorf("TLS enabled but got nil TLS wrapper") + } + + // Switch the connection into TLS mode + if _, err := conn.Write([]byte{byte(pool.RPCTLS)}); err != nil { + conn.Close() + return nil, err + } + + // Wrap the connection in a TLS client + tlsConn, err := wrapper(server.Datacenter, conn) + if err != nil { + conn.Close() + return nil, err + } + conn = tlsConn + } + + _, err = conn.Write([]byte{pool.RPCGRPC}) + if err != nil { + return nil, err + } + + return conn, nil + } +} diff --git a/agent/consul/grpc_resolver.go b/agent/consul/grpc_resolver.go new file mode 100644 index 0000000000..883bf4e24d --- /dev/null +++ b/agent/consul/grpc_resolver.go @@ -0,0 +1,240 @@ +package consul + +import ( + "math/rand" + "strings" + "sync" + "time" + + "github.com/hashicorp/consul/agent/metadata" + "github.com/hashicorp/consul/agent/router" + "github.com/hashicorp/serf/serf" + "google.golang.org/grpc/resolver" +) + +var 
registerLock sync.Mutex + +// registerResolverBuilder registers our custom grpc resolver with the given scheme. +func registerResolverBuilder(scheme, datacenter string, shutdownCh <-chan struct{}) *ServerResolverBuilder { + registerLock.Lock() + defer registerLock.Unlock() + grpcResolverBuilder := NewServerResolverBuilder(scheme, datacenter, shutdownCh) + resolver.Register(grpcResolverBuilder) + return grpcResolverBuilder +} + +// ServerResolverBuilder tracks the current server list and keeps any +// ServerResolvers updated when changes occur. +type ServerResolverBuilder struct { + // Allow overriding the scheme to support parallel tests, since + // the resolver builder is registered globally. + scheme string + datacenter string + servers map[string]*metadata.Server + resolvers map[resolver.ClientConn]*ServerResolver + shutdownCh <-chan struct{} + lock sync.Mutex +} + +func NewServerResolverBuilder(scheme, datacenter string, shutdownCh <-chan struct{}) *ServerResolverBuilder { + return &ServerResolverBuilder{ + scheme: scheme, + datacenter: datacenter, + servers: make(map[string]*metadata.Server), + resolvers: make(map[resolver.ClientConn]*ServerResolver), + } +} + +// periodicServerRebalance periodically reshuffles the order of server addresses +// within the resolvers to ensure the load is balanced across servers. +func (s *ServerResolverBuilder) periodicServerRebalance(serf *serf.Serf) { + // Compute the rebalance timer based on the number of local servers and nodes. + rebalanceDuration := router.ComputeRebalanceTimer(s.serversInDC(s.datacenter), serf.NumNodes()) + timer := time.NewTimer(rebalanceDuration) + + for { + select { + case <-timer.C: + s.rebalanceResolvers() + + // Re-compute the wait duration. 
+ newTimerDuration := router.ComputeRebalanceTimer(s.serversInDC(s.datacenter), serf.NumNodes()) + timer.Reset(newTimerDuration) + case <-s.shutdownCh: + timer.Stop() + return + } + } +} + +// rebalanceResolvers shuffles the server list for resolvers in all datacenters. +func (s *ServerResolverBuilder) rebalanceResolvers() { + s.lock.Lock() + defer s.lock.Unlock() + + for _, resolver := range s.resolvers { + // Shuffle the list of addresses using the last list given to the resolver. + resolver.addrLock.Lock() + addrs := resolver.lastAddrs + rand.Shuffle(len(addrs), func(i, j int) { + addrs[i], addrs[j] = addrs[j], addrs[i] + }) + // Pass the shuffled list to the resolver. + resolver.updateAddrsLocked(addrs) + resolver.addrLock.Unlock() + } +} + +// serversInDC returns the number of servers in the given datacenter. +func (s *ServerResolverBuilder) serversInDC(dc string) int { + s.lock.Lock() + defer s.lock.Unlock() + + var serverCount int + for _, server := range s.servers { + if server.Datacenter == dc { + serverCount++ + } + } + + return serverCount +} + +// Servers returns metadata for all currently known servers. This is used +// by grpc.ClientConn through our custom dialer. +func (s *ServerResolverBuilder) Servers() []*metadata.Server { + s.lock.Lock() + defer s.lock.Unlock() + + servers := make([]*metadata.Server, 0, len(s.servers)) + for _, server := range s.servers { + servers = append(servers, server) + } + return servers +} + +// Build returns a new ServerResolver for the given ClientConn. The resolver +// will keep the ClientConn's state updated based on updates from Serf. +func (s *ServerResolverBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) { + s.lock.Lock() + defer s.lock.Unlock() + + // If there's already a resolver for this ClientConn, return it. 
+ datacenter := strings.TrimPrefix(target.Endpoint, "server.") + if resolver, ok := s.resolvers[cc]; ok { + return resolver, nil + } + + // Make a new resolver for the dc and add it to the list of active ones. + resolver := &ServerResolver{ + datacenter: datacenter, + clientConn: cc, + } + resolver.updateAddrs(s.getDCAddrs(datacenter)) + resolver.closeCallback = func() { + s.lock.Lock() + defer s.lock.Unlock() + delete(s.resolvers, cc) + } + + s.resolvers[cc] = resolver + + return resolver, nil +} + +func (s *ServerResolverBuilder) Scheme() string { return s.scheme } + +// AddServer updates the resolvers' states to include the new server's address. +func (s *ServerResolverBuilder) AddServer(server *metadata.Server) { + s.lock.Lock() + defer s.lock.Unlock() + + s.servers[server.ID] = server + + addrs := s.getDCAddrs(server.Datacenter) + for _, resolver := range s.resolvers { + if resolver.datacenter == server.Datacenter { + resolver.updateAddrs(addrs) + } + } +} + +// RemoveServer updates the resolvers' states with the given server removed. +func (s *ServerResolverBuilder) RemoveServer(server *metadata.Server) { + s.lock.Lock() + defer s.lock.Unlock() + + delete(s.servers, server.ID) + + addrs := s.getDCAddrs(server.Datacenter) + for _, resolver := range s.resolvers { + if resolver.datacenter == server.Datacenter { + resolver.updateAddrs(addrs) + } + } +} + +// getDCAddrs returns a list of the server addresses for the given datacenter. +// This method assumes the lock is held. +func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address { + var addrs []resolver.Address + for _, server := range s.servers { + if server.Datacenter != dc { + continue + } + + addrs = append(addrs, resolver.Address{ + Addr: server.Addr.String(), + Type: resolver.Backend, + ServerName: server.Name, + }) + } + return addrs +} + +// ServerResolver is a grpc Resolver that will keep a grpc.ClientConn up to date +// on the list of server addresses to use. 
+type ServerResolver struct { + datacenter string + clientConn resolver.ClientConn + closeCallback func() + + lastAddrs []resolver.Address + addrLock sync.Mutex +} + +// updateAddrs updates this ServerResolver's ClientConn to use the given set of +// addrs. +func (r *ServerResolver) updateAddrs(addrs []resolver.Address) { + r.addrLock.Lock() + defer r.addrLock.Unlock() + r.updateAddrsLocked(addrs) +} + +// updateAddrsLocked updates this ServerResolver's ClientConn to use the given +// set of addrs. addrLock must be held by caller. +func (r *ServerResolver) updateAddrsLocked(addrs []resolver.Address) { + // Only pass the first address initially, which will cause the + // balancer to spin down the connection for its previous first address + // if it is different. If we don't do this, it will keep using the old + // first address as long as it is still in the list, making it impossible to + // rebalance until that address is removed. + var firstAddr []resolver.Address + if len(addrs) > 0 { + firstAddr = []resolver.Address{addrs[0]} + } + r.clientConn.UpdateState(resolver.State{Addresses: firstAddr}) + + // Call UpdateState again with the entire list of addrs in case we need them + // for failover. + r.clientConn.UpdateState(resolver.State{Addresses: addrs}) + + r.lastAddrs = addrs +} + +func (s *ServerResolver) Close() { + s.closeCallback() +} + +// Unneeded since we only update the ClientConn when our server list changes. +func (*ServerResolver) ResolveNow(o resolver.ResolveNowOption) {} diff --git a/agent/router/grpc.go b/agent/router/grpc.go new file mode 100644 index 0000000000..0a50992811 --- /dev/null +++ b/agent/router/grpc.go @@ -0,0 +1,20 @@ +package router + +import "github.com/hashicorp/consul/agent/metadata" + +// ServerTracker is implemented by consul.ServerResolverBuilder; the interface +// lives here to prevent a cyclic import dependency. 
+type ServerTracker interface { + AddServer(*metadata.Server) + RemoveServer(*metadata.Server) +} + +// NoOpServerTracker is a ServerTracker that does nothing. Used when gRPC is not +// enabled. +type NoOpServerTracker struct{} + +// AddServer implements ServerTracker +func (NoOpServerTracker) AddServer(*metadata.Server) {} + +// RemoveServer implements ServerTracker +func (NoOpServerTracker) RemoveServer(*metadata.Server) {} diff --git a/agent/router/manager.go b/agent/router/manager.go index 9c7d805976..0e64c5c904 100644 --- a/agent/router/manager.go +++ b/agent/router/manager.go @@ -98,6 +98,10 @@ type Manager struct { // client.ConnPool. connPoolPinger Pinger + // grpcServerTracker is used to balance grpc connections across servers, + // and has callbacks for adding or removing a server. + grpcServerTracker ServerTracker + // serverName has the name of the managers's server. This is used to // short-circuit pinging to itself. serverName string @@ -119,6 +123,7 @@ type Manager struct { func (m *Manager) AddServer(s *metadata.Server) { m.listLock.Lock() defer m.listLock.Unlock() + m.grpcServerTracker.AddServer(s) l := m.getServerList() // Check if this server is known @@ -251,6 +256,11 @@ func (m *Manager) CheckServers(fn func(srv *metadata.Server) bool) { _ = m.checkServers(fn) } +// Servers returns the current list of servers. +func (m *Manager) Servers() []*metadata.Server { + return m.getServerList().servers +} + // getServerList is a convenience method which hides the locking semantics // of atomic.Value from the caller. func (m *Manager) getServerList() serverList { @@ -267,15 +277,19 @@ func (m *Manager) saveServerList(l serverList) { } // New is the only way to safely create a new Manager struct. 
-func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfCluster, connPoolPinger Pinger, serverName string) (m *Manager) { +func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfCluster, connPoolPinger Pinger, tracker ServerTracker, serverName string) (m *Manager) { if logger == nil { logger = hclog.New(&hclog.LoggerOptions{}) } + if tracker == nil { + tracker = NoOpServerTracker{} + } m = new(Manager) m.logger = logger.Named(logging.Manager) m.clusterInfo = clusterInfo // can't pass *consul.Client: import cycle m.connPoolPinger = connPoolPinger // can't pass *consul.ConnPool: import cycle + m.grpcServerTracker = tracker m.rebalanceTimer = time.NewTimer(clientRPCMinReuseDuration) m.shutdownCh = shutdownCh m.serverName = serverName @@ -478,6 +492,7 @@ func (m *Manager) reconcileServerList(l *serverList) bool { func (m *Manager) RemoveServer(s *metadata.Server) { m.listLock.Lock() defer m.listLock.Unlock() + m.grpcServerTracker.RemoveServer(s) l := m.getServerList() // Remove the server if known @@ -498,17 +513,22 @@ func (m *Manager) RemoveServer(s *metadata.Server) { func (m *Manager) refreshServerRebalanceTimer() time.Duration { l := m.getServerList() numServers := len(l.servers) + connRebalanceTimeout := ComputeRebalanceTimer(numServers, m.clusterInfo.NumNodes()) + + m.rebalanceTimer.Reset(connRebalanceTimeout) + return connRebalanceTimeout +} + +// ComputeRebalanceTimer returns a time to wait before rebalancing connections given +// a number of servers and LAN nodes. +func ComputeRebalanceTimer(numServers, numLANMembers int) time.Duration { // Limit this connection's life based on the size (and health) of the // cluster. Never rebalance a connection more frequently than // connReuseLowWatermarkDuration, and make sure we never exceed // clusterWideRebalanceConnsPerSec operations/s across numLANMembers. 
clusterWideRebalanceConnsPerSec := float64(numServers * newRebalanceConnsPerSecPerServer) connReuseLowWatermarkDuration := clientRPCMinReuseDuration + lib.RandomStagger(clientRPCMinReuseDuration/clientRPCJitterFraction) - numLANMembers := m.clusterInfo.NumNodes() - connRebalanceTimeout := lib.RateScaledInterval(clusterWideRebalanceConnsPerSec, connReuseLowWatermarkDuration, numLANMembers) - - m.rebalanceTimer.Reset(connRebalanceTimeout) - return connRebalanceTimeout + return lib.RateScaledInterval(clusterWideRebalanceConnsPerSec, connReuseLowWatermarkDuration, numLANMembers) } // ResetRebalanceTimer resets the rebalance timer. This method exists for diff --git a/agent/router/manager_internal_test.go b/agent/router/manager_internal_test.go index 76d9512168..63838f4972 100644 --- a/agent/router/manager_internal_test.go +++ b/agent/router/manager_internal_test.go @@ -54,14 +54,14 @@ func (s *fauxSerf) NumNodes() int { func testManager() (m *Manager) { logger := GetBufferedLogger() shutdownCh := make(chan struct{}) - m = New(logger, shutdownCh, &fauxSerf{numNodes: 16384}, &fauxConnPool{}, "") + m = New(logger, shutdownCh, &fauxSerf{numNodes: 16384}, &fauxConnPool{}, nil, "") return m } func testManagerFailProb(failPct float64) (m *Manager) { logger := GetBufferedLogger() shutdownCh := make(chan struct{}) - m = New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, "") + m = New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, nil, "") return m } @@ -300,7 +300,7 @@ func TestManagerInternal_refreshServerRebalanceTimer(t *testing.T) { shutdownCh := make(chan struct{}) for _, s := range clusters { - m := New(logger, shutdownCh, &fauxSerf{numNodes: s.numNodes}, &fauxConnPool{}, "") + m := New(logger, shutdownCh, &fauxSerf{numNodes: s.numNodes}, &fauxConnPool{}, nil, "") for i := 0; i < s.numServers; i++ { nodeName := fmt.Sprintf("s%02d", i) m.AddServer(&metadata.Server{Name: nodeName}) diff --git a/agent/router/manager_test.go 
b/agent/router/manager_test.go index c7e1f299ca..6c3a83816c 100644 --- a/agent/router/manager_test.go +++ b/agent/router/manager_test.go @@ -57,21 +57,21 @@ func (s *fauxSerf) NumNodes() int { func testManager(t testing.TB) (m *router.Manager) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, "") + m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, nil, "") return m } func testManagerFailProb(t testing.TB, failPct float64) (m *router.Manager) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, "") + m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, nil, "") return m } func testManagerFailAddr(t testing.TB, failAddr net.Addr) (m *router.Manager) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failAddr: failAddr}, "") + m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failAddr: failAddr}, nil, "") return m } @@ -195,7 +195,7 @@ func TestServers_FindServer(t *testing.T) { func TestServers_New(t *testing.T) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m := router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, "") + m := router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, nil, "") if m == nil { t.Fatalf("Manager nil") } diff --git a/agent/router/router.go b/agent/router/router.go index 027303bea2..63f6e88ee1 100644 --- a/agent/router/router.go +++ b/agent/router/router.go @@ -41,6 +41,10 @@ type Router struct { // routeFn is a hook to actually do the routing. routeFn func(datacenter string) (*Manager, *metadata.Server, bool) + // grpcServerTracker is used to balance grpc connections across servers, + // and has callbacks for adding or removing a server. 
+ grpcServerTracker ServerTracker + // isShutdown prevents adding new routes to a router after it is shut // down. isShutdown bool @@ -87,17 +91,21 @@ type areaInfo struct { } // NewRouter returns a new Router with the given configuration. -func NewRouter(logger hclog.Logger, localDatacenter, serverName string) *Router { +func NewRouter(logger hclog.Logger, localDatacenter, serverName string, tracker ServerTracker) *Router { if logger == nil { logger = hclog.New(&hclog.LoggerOptions{}) } + if tracker == nil { + tracker = NoOpServerTracker{} + } router := &Router{ - logger: logger.Named(logging.Router), - localDatacenter: localDatacenter, - serverName: serverName, - areas: make(map[types.AreaID]*areaInfo), - managers: make(map[string][]*Manager), + logger: logger.Named(logging.Router), + localDatacenter: localDatacenter, + serverName: serverName, + areas: make(map[types.AreaID]*areaInfo), + managers: make(map[string][]*Manager), + grpcServerTracker: tracker, } // Hook the direct route lookup by default. @@ -251,7 +259,7 @@ func (r *Router) maybeInitializeManager(area *areaInfo, dc string) *Manager { } shutdownCh := make(chan struct{}) - manager := New(r.logger, shutdownCh, area.cluster, area.pinger, r.serverName) + manager := New(r.logger, shutdownCh, area.cluster, area.pinger, nil, r.serverName) info = &managerInfo{ manager: manager, shutdownCh: shutdownCh, diff --git a/agent/router/router_test.go b/agent/router/router_test.go index ae1beefaf4..83de54fed4 100644 --- a/agent/router/router_test.go +++ b/agent/router/router_test.go @@ -117,7 +117,7 @@ func testCluster(self string) *mockCluster { func testRouter(t testing.TB, dc string) *Router { logger := testutil.Logger(t) - return NewRouter(logger, dc, "") + return NewRouter(logger, dc, "", nil) } func TestRouter_Shutdown(t *testing.T) {