From 7747384f1f706746b42ef7edf0402b564b907b9a Mon Sep 17 00:00:00 2001 From: Dan Upton Date: Fri, 23 Dec 2022 19:42:16 +0000 Subject: [PATCH] Wire in rate limiter to handle internal and external gRPC calls (#15857) --- .gitignore | 1 + agent/agent.go | 36 +-- agent/consul/leader_connect_ca_test.go | 2 +- agent/consul/leader_test.go | 2 +- agent/consul/rate/handler.go | 110 ++++++-- agent/consul/rate/handler_test.go | 252 +++++++++++++++++- .../rate/mock_LeaderStatusProvider_test.go | 38 +++ agent/consul/server.go | 48 ++-- agent/consul/server_test.go | 10 +- agent/grpc-middleware/rate.go | 2 +- agent/rpc/peering/service_test.go | 2 +- 11 files changed, 435 insertions(+), 68 deletions(-) create mode 100644 agent/consul/rate/mock_LeaderStatusProvider_test.go diff --git a/.gitignore b/.gitignore index ade8cd97d1..5fbf456f38 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ Thumbs.db .idea .vscode __debug_bin +coverage.out # MacOS .DS_Store diff --git a/agent/agent.go b/agent/agent.go index a963f4e32f..0d581355d9 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -565,22 +565,6 @@ func (a *Agent) Start(ctx context.Context) error { return fmt.Errorf("Failed to load TLS configurations after applying auto-config settings: %w", err) } - // gRPC calls are only rate-limited on server, not client agents. - var grpcRateLimiter rpcRate.RequestLimitsHandler - grpcRateLimiter = rpcRate.NullRequestLimitsHandler() - if s, ok := a.delegate.(*consul.Server); ok { - grpcRateLimiter = s.IncomingRPCLimiter() - } - - // This needs to happen after the initial auto-config is loaded, because TLS - // can only be configured on the gRPC server at the point of creation. - a.externalGRPCServer = external.NewServer( - a.logger.Named("grpc.external"), - metrics.Default(), - a.tlsConfigurator, - grpcRateLimiter, - ) - if err := a.startLicenseManager(ctx); err != nil { return err } @@ -618,10 +602,21 @@ func (a *Agent) Start(ctx context.Context) error { // Setup either the client or the server. if c.ServerMode { - server, err := consul.NewServer(consulCfg, a.baseDeps.Deps, a.externalGRPCServer) + serverLogger := a.baseDeps.Logger.NamedIntercept(logging.ConsulServer) + incomingRPCLimiter := consul.ConfiguredIncomingRPCLimiter(serverLogger, consulCfg) + + a.externalGRPCServer = external.NewServer( + a.logger.Named("grpc.external"), + metrics.Default(), + a.tlsConfigurator, + incomingRPCLimiter, + ) + + server, err := consul.NewServer(consulCfg, a.baseDeps.Deps, a.externalGRPCServer, incomingRPCLimiter, serverLogger) if err != nil { return fmt.Errorf("Failed to start Consul server: %v", err) } + incomingRPCLimiter.Register(server) a.delegate = server if a.config.PeeringEnabled && a.config.ConnectEnabled { @@ -642,6 +637,13 @@ func (a *Agent) Start(ctx context.Context) error { } } else { + a.externalGRPCServer = external.NewServer( + a.logger.Named("grpc.external"), + metrics.Default(), + a.tlsConfigurator, + rpcRate.NullRequestLimitsHandler(), + ) + client, err := consul.NewClient(consulCfg, a.baseDeps.Deps) if err != nil { return fmt.Errorf("Failed to start Consul client: %v", err) diff --git a/agent/consul/leader_connect_ca_test.go b/agent/consul/leader_connect_ca_test.go index 8ffee0b67c..7e84a87b19 100644 --- a/agent/consul/leader_connect_ca_test.go +++ b/agent/consul/leader_connect_ca_test.go @@ -563,7 +563,7 @@ func TestCAManager_Initialize_Logging(t *testing.T) { deps := newDefaultDeps(t, conf1) deps.Logger = logger - s1, err := NewServer(conf1, deps, grpc.NewServer()) + s1, err := NewServer(conf1, deps, grpc.NewServer(), nil, logger) require.NoError(t, err) defer s1.Shutdown() testrpc.WaitForLeader(t, s1.RPC, "dc1") diff --git a/agent/consul/leader_test.go b/agent/consul/leader_test.go index 33094f59d5..0eaa33946d 100644 --- a/agent/consul/leader_test.go +++ b/agent/consul/leader_test.go @@ -1556,7 +1556,7 @@ func TestLeader_ConfigEntryBootstrap_Fail(t *testing.T) { deps := newDefaultDeps(t, config) deps.Logger = logger - srv, err := NewServer(config, deps, grpc.NewServer()) + srv, err := NewServer(config, deps, grpc.NewServer(), nil, logger) require.NoError(t, err) defer srv.Shutdown() diff --git a/agent/consul/rate/handler.go b/agent/consul/rate/handler.go index 165c04edb8..2e0708c767 100644 --- a/agent/consul/rate/handler.go +++ b/agent/consul/rate/handler.go @@ -4,6 +4,7 @@ package rate import ( "context" "errors" + "fmt" "net" "reflect" "sync/atomic" @@ -112,11 +113,14 @@ type RequestLimitsHandler interface { // Handler enforces rate limits for incoming RPCs. type Handler struct { - cfg *atomic.Pointer[HandlerConfig] - delegate HandlerDelegate + cfg *atomic.Pointer[HandlerConfig] + leaderStatusProvider LeaderStatusProvider limiter multilimiter.RateLimiter - logger hclog.Logger + + // TODO: replace this with the real logger. + // https://github.com/hashicorp/consul/pull/15822 + logger hclog.Logger } type HandlerConfig struct { @@ -135,7 +139,8 @@ type HandlerConfig struct { GlobalReadConfig multilimiter.LimiterConfig } -type HandlerDelegate interface { +//go:generate mockery --name LeaderStatusProvider --inpackage --filename mock_LeaderStatusProvider_test.go +type LeaderStatusProvider interface { // IsLeader is used to determine whether the operation is being performed // against the cluster leader, such that if it can _only_ be performed by // the leader (e.g. write operations) we don't tell clients to retry against @@ -143,16 +148,18 @@ type HandlerDelegate interface { IsLeader() bool } -func NewHandlerWithLimiter(cfg HandlerConfig, delegate HandlerDelegate, - limiter multilimiter.RateLimiter, logger hclog.Logger) *Handler { +func NewHandlerWithLimiter( + cfg HandlerConfig, + limiter multilimiter.RateLimiter, + logger hclog.Logger) *Handler { + limiter.UpdateConfig(cfg.GlobalWriteConfig, globalWrite) limiter.UpdateConfig(cfg.GlobalReadConfig, globalRead) h := &Handler{ - cfg: new(atomic.Pointer[HandlerConfig]), - delegate: delegate, - limiter: limiter, - logger: logger, + cfg: new(atomic.Pointer[HandlerConfig]), + limiter: limiter, + logger: logger, } h.cfg.Store(&cfg) @@ -160,9 +167,9 @@ func NewHandlerWithLimiter(cfg HandlerConfig, delegate HandlerDelegate, } // NewHandler creates a new RPC rate limit handler. -func NewHandler(cfg HandlerConfig, delegate HandlerDelegate, logger hclog.Logger) *Handler { +func NewHandler(cfg HandlerConfig, logger hclog.Logger) *Handler { limiter := multilimiter.NewMultiLimiter(cfg.Config) - return NewHandlerWithLimiter(cfg, delegate, limiter, logger) + return NewHandlerWithLimiter(cfg, limiter, logger) } // Run the limiter cleanup routine until the given context is canceled. @@ -175,14 +182,45 @@ func (h *Handler) Run(ctx context.Context) { // Allow returns an error if the given operation is not allowed to proceed // because of an exhausted rate-limit. func (h *Handler) Allow(op Operation) error { + + if h.leaderStatusProvider == nil { + h.logger.Error("leaderStatusProvider required to be set via Register(). bailing on rate limiter") + return nil + // TODO: panic and make sure to use the server's recovery handler + // panic("leaderStatusProvider required to be set via Register(..)") + } + cfg := h.cfg.Load() if cfg.GlobalMode == ModeDisabled { return nil } - if !h.limiter.Allow(globalWrite) { - // TODO(NET-1383): actually implement the rate limiting logic and replace this returned nil. - return nil + for _, l := range h.limits(op) { + if l.mode == ModeDisabled { + continue + } + + if h.limiter.Allow(l.ent) { + continue + } + + // TODO: metrics. + // TODO: is this the correct log-level? + + enforced := l.mode == ModeEnforcing + h.logger.Trace("RPC exceeded allowed rate limit", + "rpc", op.Name, + "source_addr", op.SourceAddr.String(), + "limit_type", l.desc, + "limit_enforced", enforced, + ) + + if enforced { + if h.leaderStatusProvider.IsLeader() && op.Type == OperationTypeWrite { + return ErrRetryLater + } + return ErrRetryElsewhere + } } return nil } @@ -202,6 +240,48 @@ func (h *Handler) UpdateConfig(cfg HandlerConfig) { } } +func (h *Handler) Register(leaderStatusProvider LeaderStatusProvider) { + h.leaderStatusProvider = leaderStatusProvider +} + +type limit struct { + mode Mode + ent multilimiter.LimitedEntity + desc string +} + +// limits returns the limits to check for the given operation (e.g. global + +// ip-based + tenant-based). +func (h *Handler) limits(op Operation) []limit { + limits := make([]limit, 0) + + if global := h.globalLimit(op); global != nil { + limits = append(limits, *global) + } + + return limits +} + +func (h *Handler) globalLimit(op Operation) *limit { + if op.Type == OperationTypeExempt { + return nil + } + cfg := h.cfg.Load() + + lim := &limit{mode: cfg.GlobalMode} + switch op.Type { + case OperationTypeRead: + lim.desc = "global/read" + lim.ent = globalRead + case OperationTypeWrite: + lim.desc = "global/write" + lim.ent = globalWrite + default: + panic(fmt.Sprintf("unknown operation type %d", op.Type)) + } + return lim +} + var ( // globalWrite identifies the global rate limit applied to write operations. globalWrite = globalLimit("global.write") diff --git a/agent/consul/rate/handler_test.go b/agent/consul/rate/handler_test.go index 29ce0ec6d1..318907331f 100644 --- a/agent/consul/rate/handler_test.go +++ b/agent/consul/rate/handler_test.go @@ -1,15 +1,233 @@ package rate import ( + "bytes" + "context" "net" "net/netip" "testing" - "github.com/hashicorp/consul/agent/consul/multilimiter" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/hashicorp/go-hclog" - mock "github.com/stretchr/testify/mock" + + "github.com/hashicorp/consul/agent/consul/multilimiter" ) +// +// Revisit test when handler.go:189 TODO implemented +// +// func TestHandler_Allow_PanicsWhenLeaderStatusProviderNotRegistered(t *testing.T) { +// defer func() { +// err := recover() +// if err == nil { +// t.Fatal("Run should panic") +// } +// }() + +// handler := NewHandler(HandlerConfig{}, hclog.NewNullLogger()) +// handler.Allow(Operation{}) +// // intentionally skip handler.Register(...) +// } + +func TestHandler(t *testing.T) { + var ( + rpcName = "Foo.Bar" + sourceAddr = net.TCPAddrFromAddrPort(netip.MustParseAddrPort("1.2.3.4:5678")) + ) + + type limitCheck struct { + limit multilimiter.LimitedEntity + allow bool + } + testCases := map[string]struct { + op Operation + globalMode Mode + checks []limitCheck + isLeader bool + expectErr error + expectLog bool + }{ + "operation exempt from limiting": { + op: Operation{ + Type: OperationTypeExempt, + Name: rpcName, + SourceAddr: sourceAddr, + }, + globalMode: ModeEnforcing, + checks: []limitCheck{}, + expectErr: nil, + expectLog: false, + }, + "global write limit disabled": { + op: Operation{ + Type: OperationTypeWrite, + Name: rpcName, + SourceAddr: sourceAddr, + }, + globalMode: ModeDisabled, + checks: []limitCheck{}, + expectErr: nil, + expectLog: false, + }, + "global write limit within allowance": { + op: Operation{ + Type: OperationTypeWrite, + Name: rpcName, + SourceAddr: sourceAddr, + }, + globalMode: ModeEnforcing, + checks: []limitCheck{ + {limit: globalWrite, allow: true}, + }, + expectErr: nil, + expectLog: false, + }, + "global write limit exceeded (permissive)": { + op: Operation{ + Type: OperationTypeWrite, + Name: rpcName, + SourceAddr: sourceAddr, + }, + globalMode: ModePermissive, + checks: []limitCheck{ + {limit: globalWrite, allow: false}, + }, + expectErr: nil, + expectLog: true, + }, + "global write limit exceeded (enforcing, leader)": { + op: Operation{ + Type: OperationTypeWrite, + Name: rpcName, + SourceAddr: sourceAddr, + }, + globalMode: ModeEnforcing, + checks: []limitCheck{ + {limit: globalWrite, allow: false}, + }, + isLeader: true, + expectErr: ErrRetryLater, + expectLog: true, + }, + "global write limit exceeded (enforcing, follower)": { + op: Operation{ + Type: OperationTypeWrite, + Name: rpcName, + SourceAddr: sourceAddr, + }, + globalMode: ModeEnforcing, + checks: []limitCheck{ + {limit: globalWrite, allow: false}, + }, + isLeader: false, + expectErr: ErrRetryElsewhere, + expectLog: true, + }, + "global read limit disabled": { + op: Operation{ + Type: OperationTypeRead, + Name: rpcName, + SourceAddr: sourceAddr, + }, + globalMode: ModeDisabled, + checks: []limitCheck{}, + expectErr: nil, + expectLog: false, + }, + "global read limit within allowance": { + op: Operation{ + Type: OperationTypeRead, + Name: rpcName, + SourceAddr: sourceAddr, + }, + globalMode: ModeEnforcing, + checks: []limitCheck{ + {limit: globalRead, allow: true}, + }, + expectErr: nil, + expectLog: false, + }, + "global read limit exceeded (permissive)": { + op: Operation{ + Type: OperationTypeRead, + Name: rpcName, + SourceAddr: sourceAddr, + }, + globalMode: ModePermissive, + checks: []limitCheck{ + {limit: globalRead, allow: false}, + }, + expectErr: nil, + expectLog: true, + }, + "global read limit exceeded (enforcing, leader)": { + op: Operation{ + Type: OperationTypeRead, + Name: rpcName, + SourceAddr: sourceAddr, + }, + globalMode: ModeEnforcing, + checks: []limitCheck{ + {limit: globalRead, allow: false}, + }, + isLeader: true, + expectErr: ErrRetryElsewhere, + expectLog: true, + }, + "global read limit exceeded (enforcing, follower)": { + op: Operation{ + Type: OperationTypeRead, + Name: rpcName, + SourceAddr: sourceAddr, + }, + globalMode: ModeEnforcing, + checks: []limitCheck{ + {limit: globalRead, allow: false}, + }, + isLeader: false, + expectErr: ErrRetryElsewhere, + expectLog: true, + }, + } + for desc, tc := range testCases { + t.Run(desc, func(t *testing.T) { + limiter := newMockLimiter(t) + limiter.On("UpdateConfig", mock.Anything, mock.Anything).Return() + for _, c := range tc.checks { + limiter.On("Allow", c.limit).Return(c.allow) + } + + leaderStatusProvider := NewMockLeaderStatusProvider(t) + leaderStatusProvider.On("IsLeader").Return(tc.isLeader).Maybe() + + var output bytes.Buffer + logger := hclog.NewInterceptLogger(&hclog.LoggerOptions{ + Level: hclog.Trace, + Output: &output, + }) + + handler := NewHandlerWithLimiter( + HandlerConfig{ + GlobalMode: tc.globalMode, + }, + limiter, + logger, + ) + handler.Register(leaderStatusProvider) + + require.Equal(t, tc.expectErr, handler.Allow(tc.op)) + + if tc.expectLog { + require.Contains(t, output.String(), "RPC exceeded allowed rate limit") + } else { + require.Zero(t, output.Len(), "expected no logs to be emitted") + } + }) + } +} + func TestNewHandlerWithLimiter_CallsUpdateConfig(t *testing.T) { mockRateLimiter := multilimiter.NewMockRateLimiter(t) mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return() @@ -22,7 +240,7 @@ func TestNewHandlerWithLimiter_CallsUpdateConfig(t *testing.T) { } logger := hclog.NewNullLogger() - NewHandlerWithLimiter(*cfg, nil, mockRateLimiter, logger) + NewHandlerWithLimiter(*cfg, mockRateLimiter, logger) mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 2) } @@ -83,7 +301,7 @@ func TestUpdateConfig(t *testing.T) { mockRateLimiter := multilimiter.NewMockRateLimiter(t) mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return() logger := hclog.NewNullLogger() - handler := NewHandlerWithLimiter(*cfg, nil, mockRateLimiter, logger) + handler := NewHandlerWithLimiter(*cfg, mockRateLimiter, logger) mockRateLimiter.Calls = nil tc.configModFunc(cfg) handler.UpdateConfig(*cfg) @@ -139,7 +357,10 @@ func TestAllow(t *testing.T) { } mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return() logger := hclog.NewNullLogger() - handler := NewHandlerWithLimiter(*tc.cfg, nil, mockRateLimiter, logger) + delegate := NewMockLeaderStatusProvider(t) + delegate.On("IsLeader").Return(true).Maybe() + handler := NewHandlerWithLimiter(*tc.cfg, mockRateLimiter, logger) + handler.Register(delegate) addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234")) mockRateLimiter.Calls = nil handler.Allow(Operation{Name: "test", SourceAddr: addr}) @@ -147,3 +368,24 @@ func TestAllow(t *testing.T) { }) } } + +var _ multilimiter.RateLimiter = (*mockLimiter)(nil) + +func newMockLimiter(t *testing.T) *mockLimiter { + l := &mockLimiter{} + l.Mock.Test(t) + + t.Cleanup(func() { l.AssertExpectations(t) }) + + return l +} + +type mockLimiter struct { + mock.Mock +} + +func (m *mockLimiter) Allow(v multilimiter.LimitedEntity) bool { return m.Called(v).Bool(0) } +func (m *mockLimiter) Run(ctx context.Context) { m.Called(ctx) } +func (m *mockLimiter) UpdateConfig(cfg multilimiter.LimiterConfig, prefix []byte) { + m.Called(cfg, prefix) +} diff --git a/agent/consul/rate/mock_LeaderStatusProvider_test.go b/agent/consul/rate/mock_LeaderStatusProvider_test.go new file mode 100644 index 0000000000..2c7f1b6cb8 --- /dev/null +++ b/agent/consul/rate/mock_LeaderStatusProvider_test.go @@ -0,0 +1,38 @@ +// Code generated by mockery v2.12.2. DO NOT EDIT. + +package rate + +import ( + testing "testing" + + mock "github.com/stretchr/testify/mock" +) + +// MockLeaderStatusProvider is an autogenerated mock type for the LeaderStatusProvider type +type MockLeaderStatusProvider struct { + mock.Mock +} + +// IsLeader provides a mock function with given fields: +func (_m *MockLeaderStatusProvider) IsLeader() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// NewMockLeaderStatusProvider creates a new instance of MockLeaderStatusProvider. It also registers the testing.TB interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockLeaderStatusProvider(t testing.TB) *MockLeaderStatusProvider { + mock := &MockLeaderStatusProvider{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/agent/consul/server.go b/agent/consul/server.go index 0aea2782ca..8580de9ba4 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -407,7 +407,7 @@ type connHandler interface { // NewServer is used to construct a new Consul server from the configuration // and extra options, potentially returning an error. -func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Server, error) { +func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server, incomingRPCLimiter rpcRate.RequestLimitsHandler, serverLogger hclog.InterceptLogger) (*Server, error) { logger := flat.Logger if err := config.CheckProtocolVersion(); err != nil { return nil, err @@ -428,7 +428,6 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Ser // Create the shutdown channel - this is closed but never written to. shutdownCh := make(chan struct{}) - serverLogger := flat.Logger.NamedIntercept(logging.ConsulServer) loggers := newLoggerStore(serverLogger) fsmDeps := fsm.Deps{ @@ -439,6 +438,10 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Ser Publisher: flat.EventPublisher, } + if incomingRPCLimiter == nil { + incomingRPCLimiter = rpcRate.NullRequestLimitsHandler() + } + // Create server. s := &Server{ config: config, @@ -463,6 +466,7 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Ser aclAuthMethodValidators: authmethod.NewCache(), fsm: fsm.NewFromDeps(fsmDeps), publisher: flat.EventPublisher, + incomingRPCLimiter: incomingRPCLimiter, } s.hcpManager = hcp.NewManager(hcp.ManagerConfig{ @@ -471,17 +475,6 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Ser Logger: logger.Named("hcp_manager"), }) - // TODO(NET-1380, NET-1381): thread this into the net/rpc and gRPC interceptors. - if s.incomingRPCLimiter == nil { - mlCfg := &multilimiter.Config{ReconcileCheckLimit: 30 * time.Second, ReconcileCheckInterval: time.Second} - limitsConfig := &RequestLimits{ - Mode: rpcRate.RequestLimitsModeFromNameWithDefault(config.RequestLimitsMode), - ReadRate: config.RequestLimitsReadRate, - WriteRate: config.RequestLimitsWriteRate, - } - - s.incomingRPCLimiter = rpcRate.NewHandler(*s.convertConsulConfigToRateLimitHandlerConfig(*limitsConfig, mlCfg), s, s.logger) - } s.incomingRPCLimiter.Run(&lib.StopChannelContext{StopCh: s.shutdownCh}) var recorder *middleware.RequestRecorder @@ -1696,7 +1689,7 @@ func (s *Server) ReloadConfig(config ReloadableConfig) error { s.rpcLimiter.Store(rate.NewLimiter(config.RPCRateLimit, config.RPCMaxBurst)) if config.RequestLimits != nil { - s.incomingRPCLimiter.UpdateConfig(*s.convertConsulConfigToRateLimitHandlerConfig(*config.RequestLimits, nil)) + s.incomingRPCLimiter.UpdateConfig(*convertConsulConfigToRateLimitHandlerConfig(*config.RequestLimits, nil)) } s.rpcConnLimiter.SetConfig(connlimit.Config{ @@ -1849,9 +1842,25 @@ func (s *Server) hcpServerStatus(deps Deps) hcp.StatusCallback { } } -// convertConsulConfigToRateLimitHandlerConfig creates a rate limite handler config -// from the relevant fields in the consul runtime config. -func (s *Server) convertConsulConfigToRateLimitHandlerConfig(limitsConfig RequestLimits, multilimiterConfig *multilimiter.Config) *rpcRate.HandlerConfig { +func ConfiguredIncomingRPCLimiter(serverLogger hclog.InterceptLogger, consulCfg *Config) *rpcRate.Handler { + mlCfg := &multilimiter.Config{ReconcileCheckLimit: 30 * time.Second, ReconcileCheckInterval: time.Second} + limitsConfig := &RequestLimits{ + Mode: rpcRate.RequestLimitsModeFromNameWithDefault(consulCfg.RequestLimitsMode), + ReadRate: consulCfg.RequestLimitsReadRate, + WriteRate: consulCfg.RequestLimitsWriteRate, + } + + rateLimiterConfig := convertConsulConfigToRateLimitHandlerConfig(*limitsConfig, mlCfg) + + incomingRPCLimiter := rpcRate.NewHandler( + *rateLimiterConfig, + serverLogger.Named("rpc-rate-limit"), + ) + + return incomingRPCLimiter +} + +func convertConsulConfigToRateLimitHandlerConfig(limitsConfig RequestLimits, multilimiterConfig *multilimiter.Config) *rpcRate.HandlerConfig { hc := &rpcRate.HandlerConfig{ GlobalMode: limitsConfig.Mode, GlobalReadConfig: multilimiter.LimiterConfig{ @@ -1870,11 +1879,6 @@ func (s *Server) convertConsulConfigToRateLimitHandlerConfig(limitsConfig Reques return hc } -// IncomingRPCLimiter returns the server's configured rate limit handler for -// incoming RPCs. This is necessary because the external gRPC server is created -// by the agent (as it is also used for xDS). -func (s *Server) IncomingRPCLimiter() rpcRate.RequestLimitsHandler { return s.incomingRPCLimiter } - // peersInfoContent is used to help operators understand what happened to the // peers.json file. This is written to a file called peers.info in the same // location. diff --git a/agent/consul/server_test.go b/agent/consul/server_test.go index a4f8a90792..6084847fc8 100644 --- a/agent/consul/server_test.go +++ b/agent/consul/server_test.go @@ -334,7 +334,7 @@ func newServerWithDeps(t *testing.T, c *Config, deps Deps) (*Server, error) { } } grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator, rpcRate.NullRequestLimitsHandler()) - srv, err := NewServer(c, deps, grpcServer) + srv, err := NewServer(c, deps, grpcServer, nil, deps.Logger) if err != nil { return nil, err } @@ -1241,7 +1241,7 @@ func TestServer_RPC_MetricsIntercept_Off(t *testing.T) { } } - s1, err := NewServer(conf, deps, grpc.NewServer()) + s1, err := NewServer(conf, deps, grpc.NewServer(), nil, deps.Logger) if err != nil { t.Fatalf("err: %v", err) } @@ -1279,7 +1279,7 @@ func TestServer_RPC_MetricsIntercept_Off(t *testing.T) { return nil } - s2, err := NewServer(conf, deps, grpc.NewServer()) + s2, err := NewServer(conf, deps, grpc.NewServer(), nil, deps.Logger) if err != nil { t.Fatalf("err: %v", err) } @@ -1313,7 +1313,7 @@ func TestServer_RPC_RequestRecorder(t *testing.T) { deps := newDefaultDeps(t, conf) deps.NewRequestRecorderFunc = nil - s1, err := NewServer(conf, deps, grpc.NewServer()) + s1, err := NewServer(conf, deps, grpc.NewServer(), nil, deps.Logger) require.Error(t, err, "need err when provider func is nil") require.Equal(t, err.Error(), "cannot initialize server without an RPC request recorder provider") @@ -1332,7 +1332,7 @@ func TestServer_RPC_RequestRecorder(t *testing.T) { return nil } - s2, err := NewServer(conf, deps, grpc.NewServer()) + s2, err := NewServer(conf, deps, grpc.NewServer(), nil, deps.Logger) require.Error(t, err, "need err when RequestRecorder is nil") require.Equal(t, err.Error(), "cannot initialize server with a nil RPC request recorder") diff --git a/agent/grpc-middleware/rate.go b/agent/grpc-middleware/rate.go index d2254000f9..149796cd9c 100644 --- a/agent/grpc-middleware/rate.go +++ b/agent/grpc-middleware/rate.go @@ -41,7 +41,7 @@ func ServerRateLimiterMiddleware(limiter rate.RequestLimitsHandler, panicHandler err := limiter.Allow(rate.Operation{ Name: info.FullMethodName, SourceAddr: peer.Addr, - // TODO: operation type. + // TODO: add operation type from https://github.com/hashicorp/consul/pull/15564 }) switch { diff --git a/agent/rpc/peering/service_test.go b/agent/rpc/peering/service_test.go index 4bd177a108..7d7cccc3b2 100644 --- a/agent/rpc/peering/service_test.go +++ b/agent/rpc/peering/service_test.go @@ -1594,7 +1594,7 @@ func newTestServer(t *testing.T, cb func(conf *consul.Config)) testingServer { deps := newDefaultDeps(t, conf) externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator, rate.NullRequestLimitsHandler()) - server, err := consul.NewServer(conf, deps, externalGRPCServer) + server, err := consul.NewServer(conf, deps, externalGRPCServer, nil, deps.Logger) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, server.Shutdown())