grpc: add rate-limiting middleware (#15550)

Implements the gRPC middleware for rate-limiting as a tap.ServerInHandle
function (executed before the request is unmarshaled).

Mappings between gRPC methods and their operation type are generated by
a protoc plugin introduced by #15564.
This commit is contained in:
Dan Upton 2022-12-13 15:01:56 +00:00 committed by GitHub
parent eef38c2199
commit c692802dec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 251 additions and 9 deletions

View File

@ -563,12 +563,19 @@ func (a *Agent) Start(ctx context.Context) error {
return fmt.Errorf("Failed to load TLS configurations after applying auto-config settings: %w", err) return fmt.Errorf("Failed to load TLS configurations after applying auto-config settings: %w", err)
} }
// gRPC calls are only rate-limited on server, not client agents.
grpcRateLimiter := middleware.NullRateLimiter()
if s, ok := a.delegate.(*consul.Server); ok {
grpcRateLimiter = s.IncomingRPCLimiter()
}
// This needs to happen after the initial auto-config is loaded, because TLS // This needs to happen after the initial auto-config is loaded, because TLS
// can only be configured on the gRPC server at the point of creation. // can only be configured on the gRPC server at the point of creation.
a.externalGRPCServer = external.NewServer( a.externalGRPCServer = external.NewServer(
a.logger.Named("grpc.external"), a.logger.Named("grpc.external"),
metrics.Default(), metrics.Default(),
a.tlsConfigurator, a.tlsConfigurator,
grpcRateLimiter,
) )
if err := a.startLicenseManager(ctx); err != nil { if err := a.startLicenseManager(ctx); err != nil {

View File

@ -18,7 +18,6 @@ import (
"time" "time"
"github.com/armon/go-metrics" "github.com/armon/go-metrics"
"github.com/hashicorp/consul-net-rpc/net/rpc"
"github.com/hashicorp/go-connlimit" "github.com/hashicorp/go-connlimit"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb" "github.com/hashicorp/go-memdb"
@ -30,6 +29,8 @@ import (
"golang.org/x/time/rate" "golang.org/x/time/rate"
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/hashicorp/consul-net-rpc/net/rpc"
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/authmethod" "github.com/hashicorp/consul/agent/consul/authmethod"
"github.com/hashicorp/consul/agent/consul/authmethod/ssoauth" "github.com/hashicorp/consul/agent/consul/authmethod/ssoauth"
@ -876,7 +877,7 @@ func newGRPCHandlerFromConfig(deps Deps, config *Config, s *Server) connHandler
s.externalConnectCAServer.Register(srv) s.externalConnectCAServer.Register(srv)
} }
return agentgrpc.NewHandler(deps.Logger, config.RPCAddr, register, nil) return agentgrpc.NewHandler(deps.Logger, config.RPCAddr, register, nil, s.incomingRPCLimiter)
} }
func (s *Server) connectCARootsMonitor(ctx context.Context) { func (s *Server) connectCARootsMonitor(ctx context.Context) {
@ -1829,6 +1830,11 @@ func (s *Server) hcpServerStatus(deps Deps) hcp.StatusCallback {
} }
} }
// IncomingRPCLimiter returns the server's configured rate limit handler for
// incoming RPCs. This is necessary because the external gRPC server is created
// by the agent (as it is also used for xDS).
func (s *Server) IncomingRPCLimiter() *rpcRate.Handler { return s.incomingRPCLimiter }
// peersInfoContent is used to help operators understand what happened to the // peersInfoContent is used to help operators understand what happened to the
// peers.json file. This is written to a file called peers.info in the same // peers.json file. This is written to a file called peers.info in the same
// location. // location.

View File

@ -331,7 +331,7 @@ func newServerWithDeps(t *testing.T, c *Config, deps Deps) (*Server, error) {
oldNotify() oldNotify()
} }
} }
grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator) grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator, grpcmiddleware.NullRateLimiter())
srv, err := NewServer(c, deps, grpcServer) srv, err := NewServer(c, deps, grpcServer)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -23,7 +23,7 @@ var (
// NewServer constructs a gRPC server for the external gRPC port, to which // NewServer constructs a gRPC server for the external gRPC port, to which
// handlers can be registered. // handlers can be registered.
func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics, tls *tlsutil.Configurator) *grpc.Server { func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics, tls *tlsutil.Configurator, limiter agentmiddleware.RateLimiter) *grpc.Server {
if metricsObj == nil { if metricsObj == nil {
metricsObj = metrics.Default() metricsObj = metrics.Default()
} }
@ -48,6 +48,7 @@ func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics, tls *
opts := []grpc.ServerOption{ opts := []grpc.ServerOption{
grpc.MaxConcurrentStreams(2048), grpc.MaxConcurrentStreams(2048),
grpc.MaxRecvMsgSize(50 * 1024 * 1024), grpc.MaxRecvMsgSize(50 * 1024 * 1024),
grpc.InTapHandle(agentmiddleware.ServerRateLimiterMiddleware(limiter, agentmiddleware.NewPanicHandler(logger))),
grpc.StatsHandler(agentmiddleware.NewStatsHandler(metricsObj, metricsLabels)), grpc.StatsHandler(agentmiddleware.NewStatsHandler(metricsObj, metricsLabels)),
middleware.WithUnaryServerChain(unaryInterceptors...), middleware.WithUnaryServerChain(unaryInterceptors...),
middleware.WithStreamServerChain(streamInterceptors...), middleware.WithStreamServerChain(streamInterceptors...),

View File

@ -23,7 +23,7 @@ import (
func TestServer_EmitsStats(t *testing.T) { func TestServer_EmitsStats(t *testing.T) {
sink, metricsObj := testutil.NewFakeSink(t) sink, metricsObj := testutil.NewFakeSink(t)
srv := NewServer(hclog.Default(), metricsObj, nil) srv := NewServer(hclog.Default(), metricsObj, nil, grpcmiddleware.NullRateLimiter())
testservice.RegisterSimpleServer(srv, &testservice.Simple{}) testservice.RegisterSimpleServer(srv, &testservice.Simple{})

View File

@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/armon/go-metrics" "github.com/armon/go-metrics"
agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware" agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
middleware "github.com/grpc-ecosystem/go-grpc-middleware" middleware "github.com/grpc-ecosystem/go-grpc-middleware"
@ -24,7 +25,7 @@ var (
// NewHandler returns a gRPC server that accepts connections from Handle(conn). // NewHandler returns a gRPC server that accepts connections from Handle(conn).
// The register function will be called with the grpc.Server to register // The register function will be called with the grpc.Server to register
// gRPC services with the server. // gRPC services with the server.
func NewHandler(logger Logger, addr net.Addr, register func(server *grpc.Server), metricsObj *metrics.Metrics) *Handler { func NewHandler(logger Logger, addr net.Addr, register func(server *grpc.Server), metricsObj *metrics.Metrics, rateLimiter agentmiddleware.RateLimiter) *Handler {
if metricsObj == nil { if metricsObj == nil {
metricsObj = metrics.Default() metricsObj = metrics.Default()
} }
@ -34,6 +35,7 @@ func NewHandler(logger Logger, addr net.Addr, register func(server *grpc.Server)
recoveryOpts := agentmiddleware.PanicHandlerMiddlewareOpts(logger) recoveryOpts := agentmiddleware.PanicHandlerMiddlewareOpts(logger)
opts := []grpc.ServerOption{ opts := []grpc.ServerOption{
grpc.InTapHandle(agentmiddleware.ServerRateLimiterMiddleware(rateLimiter, agentmiddleware.NewPanicHandler(logger))),
grpc.StatsHandler(agentmiddleware.NewStatsHandler(metricsObj, metricsLabels)), grpc.StatsHandler(agentmiddleware.NewStatsHandler(metricsObj, metricsLabels)),
middleware.WithUnaryServerChain( middleware.WithUnaryServerChain(
// Add middlware interceptors to recover in case of panics. // Add middlware interceptors to recover in case of panics.

View File

@ -14,6 +14,7 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"google.golang.org/grpc" "google.golang.org/grpc"
middleware "github.com/hashicorp/consul/agent/grpc-middleware"
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice" "github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
"github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/pool"
@ -54,7 +55,7 @@ func newPanicTestServer(t *testing.T, logger hclog.Logger, name, dc string, tlsC
func newTestServer(t *testing.T, logger hclog.Logger, name, dc string, tlsConf *tlsutil.Configurator, register func(server *grpc.Server)) testServer { func newTestServer(t *testing.T, logger hclog.Logger, name, dc string, tlsConf *tlsutil.Configurator, register func(server *grpc.Server)) testServer {
addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")}
handler := NewHandler(logger, addr, register, nil) handler := NewHandler(logger, addr, register, nil, middleware.NullRateLimiter())
lis, err := net.Listen("tcp", "127.0.0.1:0") lis, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err) require.NoError(t, err)

View File

@ -22,6 +22,7 @@ import (
"github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/consul/stream" "github.com/hashicorp/consul/agent/consul/stream"
grpc "github.com/hashicorp/consul/agent/grpc-internal" grpc "github.com/hashicorp/consul/agent/grpc-internal"
middleware "github.com/hashicorp/consul/agent/grpc-middleware"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/proto/pbcommon" "github.com/hashicorp/consul/proto/pbcommon"
@ -380,6 +381,7 @@ func runTestServer(t *testing.T, server *Server) net.Addr {
pbsubscribe.RegisterStateChangeSubscriptionServer(srv, server) pbsubscribe.RegisterStateChangeSubscriptionServer(srv, server)
}, },
nil, nil,
middleware.NullRateLimiter(),
) )
lis, err := net.Listen("tcp", "127.0.0.1:0") lis, err := net.Listen("tcp", "127.0.0.1:0")

View File

@ -14,6 +14,7 @@ import (
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
middleware "github.com/hashicorp/consul/agent/grpc-middleware"
"github.com/hashicorp/consul/agent/grpc-middleware/testutil" "github.com/hashicorp/consul/agent/grpc-middleware/testutil"
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice" "github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
"github.com/hashicorp/consul/proto/prototest" "github.com/hashicorp/consul/proto/prototest"
@ -25,7 +26,7 @@ func TestHandler_EmitsStats(t *testing.T) {
sink, metricsObj := testutil.NewFakeSink(t) sink, metricsObj := testutil.NewFakeSink(t)
addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")}
handler := NewHandler(hclog.Default(), addr, noopRegister, metricsObj) handler := NewHandler(hclog.Default(), addr, noopRegister, metricsObj, middleware.NullRateLimiter())
testservice.RegisterSimpleServer(handler.srv, &testservice.Simple{}) testservice.RegisterSimpleServer(handler.srv, &testservice.Simple{})

View File

@ -0,0 +1,39 @@
// Code generated by mockery v2.12.0. DO NOT EDIT.
package middleware
import (
testing "testing"
rate "github.com/hashicorp/consul/agent/consul/rate"
mock "github.com/stretchr/testify/mock"
)
// MockRateLimiter is an autogenerated mock type for the RateLimiter type
type MockRateLimiter struct {
mock.Mock
}
// Allow provides a mock function with given fields: _a0
func (_m *MockRateLimiter) Allow(_a0 rate.Operation) error {
ret := _m.Called(_a0)
var r0 error
if rf, ok := ret.Get(0).(func(rate.Operation) error); ok {
r0 = rf(_a0)
} else {
r0 = ret.Error(0)
}
return r0
}
// NewMockRateLimiter creates a new instance of MockRateLimiter. It also registers the testing.TB interface on the mock and a cleanup function to assert the mocks expectations.
func NewMockRateLimiter(t testing.TB) *MockRateLimiter {
mock := &MockRateLimiter{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -0,0 +1,72 @@
package middleware
import (
"context"
"errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"google.golang.org/grpc/tap"
recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
"github.com/hashicorp/consul/agent/consul/rate"
)
// ServerRateLimiterMiddleware implements a ServerInHandle function to perform
// RPC rate limiting at the cheapest possible point (before the full request has
// been decoded).
func ServerRateLimiterMiddleware(limiter RateLimiter, panicHandler recovery.RecoveryHandlerFunc) tap.ServerInHandle {
return func(ctx context.Context, info *tap.Info) (_ context.Context, retErr error) {
// This function is called before unary and stream RPC interceptors, so we
// must handle our own panics here.
defer func() {
if r := recover(); r != nil {
retErr = panicHandler(r)
}
}()
// Do not rate-limit the xDS service, it handles its own limiting.
if info.FullMethodName == "/envoy.service.discovery.v3.AggregatedDiscoveryService/DeltaAggregatedResources" {
return ctx, nil
}
peer, ok := peer.FromContext(ctx)
if !ok {
// This should never happen!
return ctx, status.Error(codes.Internal, "gRPC rate limit middleware unable to read peer")
}
err := limiter.Allow(rate.Operation{
Name: info.FullMethodName,
SourceAddr: peer.Addr,
// TODO: operation type.
})
switch {
case err == nil:
return ctx, nil
case errors.Is(err, rate.ErrRetryElsewhere):
return ctx, status.Error(codes.ResourceExhausted, err.Error())
case errors.Is(err, rate.ErrRetryLater):
return ctx, status.Error(codes.Unavailable, err.Error())
default:
return ctx, status.Error(codes.Internal, err.Error())
}
}
}
//go:generate mockery --name RateLimiter --inpackage
type RateLimiter interface {
Allow(rate.Operation) error
}
// NullRateLimiter returns a RateLimiter that allows every operation.
func NullRateLimiter() RateLimiter {
return nullRateLimiter{}
}
type nullRateLimiter struct{}
func (nullRateLimiter) Allow(rate.Operation) error { return nil }

View File

@ -0,0 +1,111 @@
package middleware
import (
"context"
"errors"
"net"
"testing"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/health"
healthpb "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/status"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/consul/agent/consul/rate"
)
func TestServerRateLimiterMiddleware_Integration(t *testing.T) {
limiter := NewMockRateLimiter(t)
server := grpc.NewServer(
grpc.InTapHandle(ServerRateLimiterMiddleware(limiter, NewPanicHandler(hclog.NewNullLogger()))),
)
server.RegisterService(&healthpb.Health_ServiceDesc, health.NewServer())
lis, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() {
if err := lis.Close(); err != nil {
t.Logf("failed to close listener: %v", err)
}
})
go server.Serve(lis)
t.Cleanup(server.Stop)
conn, err := grpc.Dial(
lis.Addr().String(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
require.NoError(t, err)
t.Cleanup(func() {
if err := conn.Close(); err != nil {
t.Logf("failed to close client connection: %v", err)
}
})
client := healthpb.NewHealthClient(conn)
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
t.Run("ErrRetryElsewhere = ResourceExhausted", func(t *testing.T) {
limiter.On("Allow", mock.Anything).
Run(func(args mock.Arguments) {
op := args.Get(0).(rate.Operation)
require.Equal(t, "/grpc.health.v1.Health/Check", op.Name)
addr := op.SourceAddr.(*net.TCPAddr)
require.True(t, addr.IP.IsLoopback())
}).
Return(rate.ErrRetryElsewhere).
Once()
_, err = client.Check(ctx, &healthpb.HealthCheckRequest{})
require.Error(t, err)
require.Equal(t, codes.ResourceExhausted.String(), status.Code(err).String())
})
t.Run("ErrRetryLater = Unavailable", func(t *testing.T) {
limiter.On("Allow", mock.Anything).
Return(rate.ErrRetryLater).
Once()
_, err = client.Check(ctx, &healthpb.HealthCheckRequest{})
require.Error(t, err)
require.Equal(t, codes.Unavailable.String(), status.Code(err).String())
})
t.Run("unexpected error", func(t *testing.T) {
limiter.On("Allow", mock.Anything).
Return(errors.New("uh oh")).
Once()
_, err = client.Check(ctx, &healthpb.HealthCheckRequest{})
require.Error(t, err)
require.Equal(t, codes.Internal.String(), status.Code(err).String())
})
t.Run("operation allowed", func(t *testing.T) {
limiter.On("Allow", mock.Anything).
Return(nil).
Once()
_, err = client.Check(ctx, &healthpb.HealthCheckRequest{})
require.NoError(t, err)
})
t.Run("Allow panics", func(t *testing.T) {
limiter.On("Allow", mock.Anything).
Panic("uh oh").
Once()
_, err = client.Check(ctx, &healthpb.HealthCheckRequest{})
require.Error(t, err)
require.Equal(t, codes.Internal.String(), status.Code(err).String())
})
}

View File

@ -1591,7 +1591,7 @@ func newTestServer(t *testing.T, cb func(conf *consul.Config)) testingServer {
conf.ACLResolverSettings.EnterpriseMeta = *conf.AgentEnterpriseMeta() conf.ACLResolverSettings.EnterpriseMeta = *conf.AgentEnterpriseMeta()
deps := newDefaultDeps(t, conf) deps := newDefaultDeps(t, conf)
externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator) externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator, agentmiddleware.NullRateLimiter())
server, err := consul.NewServer(conf, deps, externalGRPCServer) server, err := consul.NewServer(conf, deps, externalGRPCServer)
require.NoError(t, err) require.NoError(t, err)