diff --git a/agent/agent.go b/agent/agent.go index be65616368..670ec8591a 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -563,12 +563,19 @@ func (a *Agent) Start(ctx context.Context) error { return fmt.Errorf("Failed to load TLS configurations after applying auto-config settings: %w", err) } + // gRPC calls are only rate-limited on server, not client agents. + grpcRateLimiter := middleware.NullRateLimiter() + if s, ok := a.delegate.(*consul.Server); ok { + grpcRateLimiter = s.IncomingRPCLimiter() + } + // This needs to happen after the initial auto-config is loaded, because TLS // can only be configured on the gRPC server at the point of creation. a.externalGRPCServer = external.NewServer( a.logger.Named("grpc.external"), metrics.Default(), a.tlsConfigurator, + grpcRateLimiter, ) if err := a.startLicenseManager(ctx); err != nil { diff --git a/agent/consul/server.go b/agent/consul/server.go index 0d4a0f0f30..3ac9fbb254 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -18,7 +18,6 @@ import ( "time" "github.com/armon/go-metrics" - "github.com/hashicorp/consul-net-rpc/net/rpc" "github.com/hashicorp/go-connlimit" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-memdb" @@ -30,6 +29,8 @@ import ( "golang.org/x/time/rate" "google.golang.org/grpc" + "github.com/hashicorp/consul-net-rpc/net/rpc" + "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/consul/authmethod" "github.com/hashicorp/consul/agent/consul/authmethod/ssoauth" @@ -876,7 +877,7 @@ func newGRPCHandlerFromConfig(deps Deps, config *Config, s *Server) connHandler s.externalConnectCAServer.Register(srv) } - return agentgrpc.NewHandler(deps.Logger, config.RPCAddr, register, nil) + return agentgrpc.NewHandler(deps.Logger, config.RPCAddr, register, nil, s.incomingRPCLimiter) } func (s *Server) connectCARootsMonitor(ctx context.Context) { @@ -1829,6 +1830,11 @@ func (s *Server) hcpServerStatus(deps Deps) hcp.StatusCallback { } } +// IncomingRPCLimiter returns the server's configured rate limit handler for +// incoming RPCs. This is necessary because the external gRPC server is created +// by the agent (as it is also used for xDS). +func (s *Server) IncomingRPCLimiter() *rpcRate.Handler { return s.incomingRPCLimiter } + // peersInfoContent is used to help operators understand what happened to the // peers.json file. This is written to a file called peers.info in the same // location. diff --git a/agent/consul/server_test.go b/agent/consul/server_test.go index 299b2cc158..b1be492cd2 100644 --- a/agent/consul/server_test.go +++ b/agent/consul/server_test.go @@ -331,7 +331,7 @@ func newServerWithDeps(t *testing.T, c *Config, deps Deps) (*Server, error) { oldNotify() } } - grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator) + grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator, grpcmiddleware.NullRateLimiter()) srv, err := NewServer(c, deps, grpcServer) if err != nil { return nil, err diff --git a/agent/grpc-external/server.go b/agent/grpc-external/server.go index dd0186d480..98de599c8c 100644 --- a/agent/grpc-external/server.go +++ b/agent/grpc-external/server.go @@ -23,7 +23,7 @@ var ( // NewServer constructs a gRPC server for the external gRPC port, to which // handlers can be registered. -func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics, tls *tlsutil.Configurator) *grpc.Server { +func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics, tls *tlsutil.Configurator, limiter agentmiddleware.RateLimiter) *grpc.Server { if metricsObj == nil { metricsObj = metrics.Default() } @@ -48,6 +48,7 @@ func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics, tls * opts := []grpc.ServerOption{ grpc.MaxConcurrentStreams(2048), grpc.MaxRecvMsgSize(50 * 1024 * 1024), + grpc.InTapHandle(agentmiddleware.ServerRateLimiterMiddleware(limiter, agentmiddleware.NewPanicHandler(logger))), grpc.StatsHandler(agentmiddleware.NewStatsHandler(metricsObj, metricsLabels)), middleware.WithUnaryServerChain(unaryInterceptors...), middleware.WithStreamServerChain(streamInterceptors...), diff --git a/agent/grpc-external/stats_test.go b/agent/grpc-external/stats_test.go index c231a922d1..c62eb65593 100644 --- a/agent/grpc-external/stats_test.go +++ b/agent/grpc-external/stats_test.go @@ -23,7 +23,7 @@ import ( func TestServer_EmitsStats(t *testing.T) { sink, metricsObj := testutil.NewFakeSink(t) - srv := NewServer(hclog.Default(), metricsObj, nil) + srv := NewServer(hclog.Default(), metricsObj, nil, grpcmiddleware.NullRateLimiter()) testservice.RegisterSimpleServer(srv, &testservice.Simple{}) diff --git a/agent/grpc-internal/handler.go b/agent/grpc-internal/handler.go index a576d656e7..fc563df4ed 100644 --- a/agent/grpc-internal/handler.go +++ b/agent/grpc-internal/handler.go @@ -6,6 +6,7 @@ import ( "time" "github.com/armon/go-metrics" + agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware" middleware "github.com/grpc-ecosystem/go-grpc-middleware" @@ -24,7 +25,7 @@ var ( // NewHandler returns a gRPC server that accepts connections from Handle(conn). // The register function will be called with the grpc.Server to register // gRPC services with the server. -func NewHandler(logger Logger, addr net.Addr, register func(server *grpc.Server), metricsObj *metrics.Metrics) *Handler { +func NewHandler(logger Logger, addr net.Addr, register func(server *grpc.Server), metricsObj *metrics.Metrics, rateLimiter agentmiddleware.RateLimiter) *Handler { if metricsObj == nil { metricsObj = metrics.Default() } @@ -34,6 +35,7 @@ func NewHandler(logger Logger, addr net.Addr, register func(server *grpc.Server) recoveryOpts := agentmiddleware.PanicHandlerMiddlewareOpts(logger) opts := []grpc.ServerOption{ + grpc.InTapHandle(agentmiddleware.ServerRateLimiterMiddleware(rateLimiter, agentmiddleware.NewPanicHandler(logger))), grpc.StatsHandler(agentmiddleware.NewStatsHandler(metricsObj, metricsLabels)), middleware.WithUnaryServerChain( // Add middlware interceptors to recover in case of panics. diff --git a/agent/grpc-internal/server_test.go b/agent/grpc-internal/server_test.go index 56e18da1d6..1cd66ec0c0 100644 --- a/agent/grpc-internal/server_test.go +++ b/agent/grpc-internal/server_test.go @@ -14,6 +14,7 @@ import ( "golang.org/x/sync/errgroup" "google.golang.org/grpc" + middleware "github.com/hashicorp/consul/agent/grpc-middleware" "github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice" "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/pool" @@ -54,7 +55,7 @@ func newPanicTestServer(t *testing.T, logger hclog.Logger, name, dc string, tlsC func newTestServer(t *testing.T, logger hclog.Logger, name, dc string, tlsConf *tlsutil.Configurator, register func(server *grpc.Server)) testServer { addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} - handler := NewHandler(logger, addr, register, nil) + handler := NewHandler(logger, addr, register, nil, middleware.NullRateLimiter()) lis, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) diff --git a/agent/grpc-internal/services/subscribe/subscribe_test.go b/agent/grpc-internal/services/subscribe/subscribe_test.go index 065b735407..c304543f41 100644 --- a/agent/grpc-internal/services/subscribe/subscribe_test.go +++ b/agent/grpc-internal/services/subscribe/subscribe_test.go @@ -22,6 +22,7 @@ import ( "github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/consul/stream" grpc "github.com/hashicorp/consul/agent/grpc-internal" + middleware "github.com/hashicorp/consul/agent/grpc-middleware" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/proto/pbcommon" @@ -380,6 +381,7 @@ func runTestServer(t *testing.T, server *Server) net.Addr { pbsubscribe.RegisterStateChangeSubscriptionServer(srv, server) }, nil, + middleware.NullRateLimiter(), ) lis, err := net.Listen("tcp", "127.0.0.1:0") diff --git a/agent/grpc-internal/stats_test.go b/agent/grpc-internal/stats_test.go index a304de870e..2672beba46 100644 --- a/agent/grpc-internal/stats_test.go +++ b/agent/grpc-internal/stats_test.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/go-hclog" + middleware "github.com/hashicorp/consul/agent/grpc-middleware" "github.com/hashicorp/consul/agent/grpc-middleware/testutil" "github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice" "github.com/hashicorp/consul/proto/prototest" @@ -25,7 +26,7 @@ func TestHandler_EmitsStats(t *testing.T) { sink, metricsObj := testutil.NewFakeSink(t) addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} - handler := NewHandler(hclog.Default(), addr, noopRegister, metricsObj) + handler := NewHandler(hclog.Default(), addr, noopRegister, metricsObj, middleware.NullRateLimiter()) testservice.RegisterSimpleServer(handler.srv, &testservice.Simple{}) diff --git a/agent/grpc-middleware/mock_RateLimiter.go b/agent/grpc-middleware/mock_RateLimiter.go new file mode 100644 index 0000000000..9f427b7bc6 --- /dev/null +++ b/agent/grpc-middleware/mock_RateLimiter.go @@ -0,0 +1,39 @@ +// Code generated by mockery v2.12.0. DO NOT EDIT. + +package middleware + +import ( + testing "testing" + + rate "github.com/hashicorp/consul/agent/consul/rate" + mock "github.com/stretchr/testify/mock" +) + +// MockRateLimiter is an autogenerated mock type for the RateLimiter type +type MockRateLimiter struct { + mock.Mock +} + +// Allow provides a mock function with given fields: _a0 +func (_m *MockRateLimiter) Allow(_a0 rate.Operation) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(rate.Operation) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewMockRateLimiter creates a new instance of MockRateLimiter. It also registers the testing.TB interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockRateLimiter(t testing.TB) *MockRateLimiter { + mock := &MockRateLimiter{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/agent/grpc-middleware/rate.go b/agent/grpc-middleware/rate.go new file mode 100644 index 0000000000..5683d69404 --- /dev/null +++ b/agent/grpc-middleware/rate.go @@ -0,0 +1,72 @@ +package middleware + +import ( + "context" + "errors" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" + "google.golang.org/grpc/tap" + + recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery" + + "github.com/hashicorp/consul/agent/consul/rate" +) + +// ServerRateLimiterMiddleware implements a ServerInHandle function to perform +// RPC rate limiting at the cheapest possible point (before the full request has +// been decoded). +func ServerRateLimiterMiddleware(limiter RateLimiter, panicHandler recovery.RecoveryHandlerFunc) tap.ServerInHandle { + return func(ctx context.Context, info *tap.Info) (_ context.Context, retErr error) { + // This function is called before unary and stream RPC interceptors, so we + // must handle our own panics here. + defer func() { + if r := recover(); r != nil { + retErr = panicHandler(r) + } + }() + + // Do not rate-limit the xDS service, it handles its own limiting. + if info.FullMethodName == "/envoy.service.discovery.v3.AggregatedDiscoveryService/DeltaAggregatedResources" { + return ctx, nil + } + + peer, ok := peer.FromContext(ctx) + if !ok { + // This should never happen! + return ctx, status.Error(codes.Internal, "gRPC rate limit middleware unable to read peer") + } + + err := limiter.Allow(rate.Operation{ + Name: info.FullMethodName, + SourceAddr: peer.Addr, + // TODO: operation type. + }) + + switch { + case err == nil: + return ctx, nil + case errors.Is(err, rate.ErrRetryElsewhere): + return ctx, status.Error(codes.ResourceExhausted, err.Error()) + case errors.Is(err, rate.ErrRetryLater): + return ctx, status.Error(codes.Unavailable, err.Error()) + default: + return ctx, status.Error(codes.Internal, err.Error()) + } + } +} + +//go:generate mockery --name RateLimiter --inpackage +type RateLimiter interface { + Allow(rate.Operation) error +} + +// NullRateLimiter returns a RateLimiter that allows every operation. +func NullRateLimiter() RateLimiter { + return nullRateLimiter{} +} + +type nullRateLimiter struct{} + +func (nullRateLimiter) Allow(rate.Operation) error { return nil } diff --git a/agent/grpc-middleware/rate_test.go b/agent/grpc-middleware/rate_test.go new file mode 100644 index 0000000000..1c5d417047 --- /dev/null +++ b/agent/grpc-middleware/rate_test.go @@ -0,0 +1,111 @@ +package middleware + +import ( + "context" + "errors" + "net" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/health" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/status" + + "github.com/hashicorp/go-hclog" + + "github.com/hashicorp/consul/agent/consul/rate" +) + +func TestServerRateLimiterMiddleware_Integration(t *testing.T) { + limiter := NewMockRateLimiter(t) + + server := grpc.NewServer( + grpc.InTapHandle(ServerRateLimiterMiddleware(limiter, NewPanicHandler(hclog.NewNullLogger()))), + ) + server.RegisterService(&healthpb.Health_ServiceDesc, health.NewServer()) + + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { + if err := lis.Close(); err != nil { + t.Logf("failed to close listener: %v", err) + } + }) + go server.Serve(lis) + t.Cleanup(server.Stop) + + conn, err := grpc.Dial( + lis.Addr().String(), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + t.Cleanup(func() { + if err := conn.Close(); err != nil { + t.Logf("failed to close client connection: %v", err) + } + }) + client := healthpb.NewHealthClient(conn) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + t.Run("ErrRetryElsewhere = ResourceExhausted", func(t *testing.T) { + limiter.On("Allow", mock.Anything). + Run(func(args mock.Arguments) { + op := args.Get(0).(rate.Operation) + require.Equal(t, "/grpc.health.v1.Health/Check", op.Name) + + addr := op.SourceAddr.(*net.TCPAddr) + require.True(t, addr.IP.IsLoopback()) + }). + Return(rate.ErrRetryElsewhere). + Once() + + _, err = client.Check(ctx, &healthpb.HealthCheckRequest{}) + require.Error(t, err) + require.Equal(t, codes.ResourceExhausted.String(), status.Code(err).String()) + }) + + t.Run("ErrRetryLater = Unavailable", func(t *testing.T) { + limiter.On("Allow", mock.Anything). + Return(rate.ErrRetryLater). + Once() + + _, err = client.Check(ctx, &healthpb.HealthCheckRequest{}) + require.Error(t, err) + require.Equal(t, codes.Unavailable.String(), status.Code(err).String()) + }) + + t.Run("unexpected error", func(t *testing.T) { + limiter.On("Allow", mock.Anything). + Return(errors.New("uh oh")). + Once() + + _, err = client.Check(ctx, &healthpb.HealthCheckRequest{}) + require.Error(t, err) + require.Equal(t, codes.Internal.String(), status.Code(err).String()) + }) + + t.Run("operation allowed", func(t *testing.T) { + limiter.On("Allow", mock.Anything). + Return(nil). + Once() + + _, err = client.Check(ctx, &healthpb.HealthCheckRequest{}) + require.NoError(t, err) + }) + + t.Run("Allow panics", func(t *testing.T) { + limiter.On("Allow", mock.Anything). + Panic("uh oh"). + Once() + + _, err = client.Check(ctx, &healthpb.HealthCheckRequest{}) + require.Error(t, err) + require.Equal(t, codes.Internal.String(), status.Code(err).String()) + }) +} diff --git a/agent/rpc/peering/service_test.go b/agent/rpc/peering/service_test.go index faae32b632..c971ccf72c 100644 --- a/agent/rpc/peering/service_test.go +++ b/agent/rpc/peering/service_test.go @@ -1591,7 +1591,7 @@ func newTestServer(t *testing.T, cb func(conf *consul.Config)) testingServer { conf.ACLResolverSettings.EnterpriseMeta = *conf.AgentEnterpriseMeta() deps := newDefaultDeps(t, conf) - externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator) + externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator, agentmiddleware.NullRateLimiter()) server, err := consul.NewServer(conf, deps, externalGRPCServer) require.NoError(t, err)