diff --git a/agent/agent.go b/agent/agent.go index 33dc577c27..a963f4e32f 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -567,7 +567,7 @@ func (a *Agent) Start(ctx context.Context) error { // gRPC calls are only rate-limited on server, not client agents. var grpcRateLimiter rpcRate.RequestLimitsHandler - grpcRateLimiter = rpcRate.NullRateLimiter() + grpcRateLimiter = rpcRate.NullRequestLimitsHandler() if s, ok := a.delegate.(*consul.Server); ok { grpcRateLimiter = s.IncomingRPCLimiter() } diff --git a/agent/consul/rate/handler.go b/agent/consul/rate/handler.go index 0c9b0ccd25..165c04edb8 100644 --- a/agent/consul/rate/handler.go +++ b/agent/consul/rate/handler.go @@ -175,11 +175,15 @@ func (h *Handler) Run(ctx context.Context) { // Allow returns an error if the given operation is not allowed to proceed // because of an exhausted rate-limit. func (h *Handler) Allow(op Operation) error { - // TODO(NET-1383): actually implement the rate limiting logic. - // - // Example: - // if !h.limiter.Allow(globalWrite) { - // } + cfg := h.cfg.Load() + if cfg.GlobalMode == ModeDisabled { + return nil + } + + if !h.limiter.Allow(globalWrite) { + // TODO(NET-1383): actually implement the rate limiting logic and replace this returned nil. + return nil + } return nil } @@ -214,15 +218,15 @@ func (prefix globalLimit) Key() multilimiter.KeyType { return multilimiter.Key(prefix, nil) } -// NullRateLimiter returns a RateLimiter that allows every operation. -func NullRateLimiter() RequestLimitsHandler { - return nullRateLimiter{} +// NullRequestLimitsHandler returns a RequestLimitsHandler that allows every operation. +func NullRequestLimitsHandler() RequestLimitsHandler { + return nullRequestLimitsHandler{} } -type nullRateLimiter struct{} +type nullRequestLimitsHandler struct{} -func (nullRateLimiter) Allow(Operation) error { return nil } +func (nullRequestLimitsHandler) Allow(Operation) error { return nil } -func (nullRateLimiter) Run(ctx context.Context) {} +func (nullRequestLimitsHandler) Run(ctx context.Context) {} -func (nullRateLimiter) UpdateConfig(cfg HandlerConfig) {} +func (nullRequestLimitsHandler) UpdateConfig(cfg HandlerConfig) {} diff --git a/agent/consul/rate/handler_test.go b/agent/consul/rate/handler_test.go index 76fefe818e..29ce0ec6d1 100644 --- a/agent/consul/rate/handler_test.go +++ b/agent/consul/rate/handler_test.go @@ -1,6 +1,8 @@ package rate import ( + "net" + "net/netip" "testing" "github.com/hashicorp/consul/agent/consul/multilimiter" @@ -18,6 +20,7 @@ func TestNewHandlerWithLimiter_CallsUpdateConfig(t *testing.T) { GlobalWriteConfig: writeCfg, GlobalMode: ModeEnforcing, } + logger := hclog.NewNullLogger() NewHandlerWithLimiter(*cfg, nil, mockRateLimiter, logger) mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 2) @@ -88,3 +91,59 @@ func TestUpdateConfig(t *testing.T) { }) } } + +func TestAllow(t *testing.T) { + readCfg := multilimiter.LimiterConfig{Rate: 100, Burst: 100} + writeCfg := multilimiter.LimiterConfig{Rate: 99, Burst: 99} + + type testCase struct { + description string + cfg *HandlerConfig + expectedAllowCalls int + } + testCases := []testCase{ + { + description: "RateLimiter does not get called when mode is disabled.", + cfg: &HandlerConfig{ + GlobalReadConfig: readCfg, + GlobalWriteConfig: writeCfg, + GlobalMode: ModeDisabled, + }, + expectedAllowCalls: 0, + }, + { + description: "RateLimiter gets called when mode is permissive.", + cfg: &HandlerConfig{ + GlobalReadConfig: readCfg, + GlobalWriteConfig: writeCfg, + GlobalMode: ModePermissive, + }, + expectedAllowCalls: 1, + }, + { + description: "RateLimiter gets called when mode is enforcing.", + cfg: &HandlerConfig{ + GlobalReadConfig: readCfg, + GlobalWriteConfig: writeCfg, + GlobalMode: ModeEnforcing, + }, + expectedAllowCalls: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + mockRateLimiter := multilimiter.NewMockRateLimiter(t) + if tc.expectedAllowCalls > 0 { + mockRateLimiter.On("Allow", mock.Anything).Return(func(entity multilimiter.LimitedEntity) bool { return true }) + } + mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return() + logger := hclog.NewNullLogger() + handler := NewHandlerWithLimiter(*tc.cfg, nil, mockRateLimiter, logger) + addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234")) + mockRateLimiter.Calls = nil + handler.Allow(Operation{Name: "test", SourceAddr: addr}) + mockRateLimiter.AssertNumberOfCalls(t, "Allow", tc.expectedAllowCalls) + }) + } +} diff --git a/agent/consul/server_test.go b/agent/consul/server_test.go index 759a682f28..a4f8a90792 100644 --- a/agent/consul/server_test.go +++ b/agent/consul/server_test.go @@ -333,7 +333,7 @@ func newServerWithDeps(t *testing.T, c *Config, deps Deps) (*Server, error) { oldNotify() } } - grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator, rpcRate.NullRateLimiter()) + grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator, rpcRate.NullRequestLimitsHandler()) srv, err := NewServer(c, deps, grpcServer) if err != nil { return nil, err diff --git a/agent/grpc-external/stats_test.go b/agent/grpc-external/stats_test.go index 40d8509969..afe4ddfd0e 100644 --- a/agent/grpc-external/stats_test.go +++ b/agent/grpc-external/stats_test.go @@ -24,7 +24,7 @@ import ( func TestServer_EmitsStats(t *testing.T) { sink, metricsObj := testutil.NewFakeSink(t) - srv := NewServer(hclog.Default(), metricsObj, nil, rate.NullRateLimiter()) + srv := NewServer(hclog.Default(), metricsObj, nil, rate.NullRequestLimitsHandler()) testservice.RegisterSimpleServer(srv, &testservice.Simple{}) diff --git a/agent/grpc-internal/server_test.go b/agent/grpc-internal/server_test.go index 4e76ac5940..26166c1e05 100644 --- a/agent/grpc-internal/server_test.go +++ b/agent/grpc-internal/server_test.go @@ -55,7 +55,7 @@ func newPanicTestServer(t *testing.T, logger hclog.Logger, name, dc string, tlsC func newTestServer(t *testing.T, logger hclog.Logger, name, dc string, tlsConf *tlsutil.Configurator, register func(server *grpc.Server)) testServer { addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} - handler := NewHandler(logger, addr, register, nil, rate.NullRateLimiter()) + handler := NewHandler(logger, addr, register, nil, rate.NullRequestLimitsHandler()) lis, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) diff --git a/agent/grpc-internal/services/subscribe/subscribe_test.go b/agent/grpc-internal/services/subscribe/subscribe_test.go index 2ef6a41e82..72c8b5e1bf 100644 --- a/agent/grpc-internal/services/subscribe/subscribe_test.go +++ b/agent/grpc-internal/services/subscribe/subscribe_test.go @@ -381,7 +381,7 @@ func runTestServer(t *testing.T, server *Server) net.Addr { pbsubscribe.RegisterStateChangeSubscriptionServer(srv, server) }, nil, - rate.NullRateLimiter(), + rate.NullRequestLimitsHandler(), ) lis, err := net.Listen("tcp", "127.0.0.1:0") diff --git a/agent/grpc-internal/stats_test.go b/agent/grpc-internal/stats_test.go index 13f71b79b6..49059eaf4a 100644 --- a/agent/grpc-internal/stats_test.go +++ b/agent/grpc-internal/stats_test.go @@ -26,7 +26,7 @@ func TestHandler_EmitsStats(t *testing.T) { sink, metricsObj := testutil.NewFakeSink(t) addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} - handler := NewHandler(hclog.Default(), addr, noopRegister, metricsObj, rate.NullRateLimiter()) + handler := NewHandler(hclog.Default(), addr, noopRegister, metricsObj, rate.NullRequestLimitsHandler()) testservice.RegisterSimpleServer(handler.srv, &testservice.Simple{}) diff --git a/agent/rpc/peering/service_test.go b/agent/rpc/peering/service_test.go index 029bfa2c50..4bd177a108 100644 --- a/agent/rpc/peering/service_test.go +++ b/agent/rpc/peering/service_test.go @@ -1592,7 +1592,7 @@ func newTestServer(t *testing.T, cb func(conf *consul.Config)) testingServer { conf.ACLResolverSettings.EnterpriseMeta = *conf.AgentEnterpriseMeta() deps := newDefaultDeps(t, conf) - externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator, rate.NullRateLimiter()) + externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator, rate.NullRequestLimitsHandler()) server, err := consul.NewServer(conf, deps, externalGRPCServer) require.NoError(t, err)