Rate Limit Handler - ensure rate limiting is not in the code path when not configured (#15819)

* Rate limiting handler - ensure configuration has changed before modifying limiters

* Updating test to validate arguments to UpdateConfig

* Removing duplicate test.  Updating mock.

* Renaming NullRateLimiter to NullRequestLimitsHandler

* Rate Limit Handler - ensure rate limiting is not in the code path when not configured

* Update agent/consul/rate/handler.go

Co-authored-by: Dhia Ayachi <dhia@hashicorp.com>

* formatting handler.go

* Rate limiting handler - ensure configuration has changed before modifying limiters

* Updating test to validate arguments to UpdateConfig

* Removing duplicate test.  Updating mock.

* adding logging for when UpdateConfig is called but the config has not changed.

* Update agent/consul/rate/handler.go

Co-authored-by: Dhia Ayachi <dhia@hashicorp.com>

* Update agent/consul/rate/handler_test.go

Co-authored-by: Dan Upton <daniel@floppy.co>

* modifying existing variable name based on pr feedback

* updating a broken merge conflict;

Co-authored-by: Dhia Ayachi <dhia@hashicorp.com>
Co-authored-by: Dan Upton <daniel@floppy.co>
This commit is contained in:
John Murret 2022-12-20 15:00:22 -07:00 committed by GitHub
parent aba43d85d9
commit f5e01f8c6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 82 additions and 19 deletions

View File

@ -567,7 +567,7 @@ func (a *Agent) Start(ctx context.Context) error {
// gRPC calls are only rate-limited on server, not client agents. // gRPC calls are only rate-limited on server, not client agents.
var grpcRateLimiter rpcRate.RequestLimitsHandler var grpcRateLimiter rpcRate.RequestLimitsHandler
grpcRateLimiter = rpcRate.NullRateLimiter() grpcRateLimiter = rpcRate.NullRequestLimitsHandler()
if s, ok := a.delegate.(*consul.Server); ok { if s, ok := a.delegate.(*consul.Server); ok {
grpcRateLimiter = s.IncomingRPCLimiter() grpcRateLimiter = s.IncomingRPCLimiter()
} }

View File

@ -175,11 +175,15 @@ func (h *Handler) Run(ctx context.Context) {
// Allow returns an error if the given operation is not allowed to proceed // Allow returns an error if the given operation is not allowed to proceed
// because of an exhausted rate-limit. // because of an exhausted rate-limit.
func (h *Handler) Allow(op Operation) error { func (h *Handler) Allow(op Operation) error {
// TODO(NET-1383): actually implement the rate limiting logic. cfg := h.cfg.Load()
// if cfg.GlobalMode == ModeDisabled {
// Example: return nil
// if !h.limiter.Allow(globalWrite) { }
// }
if !h.limiter.Allow(globalWrite) {
// TODO(NET-1383): actually implement the rate limiting logic and replace this returned nil.
return nil
}
return nil return nil
} }
@ -214,15 +218,15 @@ func (prefix globalLimit) Key() multilimiter.KeyType {
return multilimiter.Key(prefix, nil) return multilimiter.Key(prefix, nil)
} }
// NullRateLimiter returns a RateLimiter that allows every operation. // NullRequestLimitsHandler returns a RequestLimitsHandler that allows every operation.
func NullRateLimiter() RequestLimitsHandler { func NullRequestLimitsHandler() RequestLimitsHandler {
return nullRateLimiter{} return nullRequestLimitsHandler{}
} }
type nullRateLimiter struct{} type nullRequestLimitsHandler struct{}
func (nullRateLimiter) Allow(Operation) error { return nil } func (nullRequestLimitsHandler) Allow(Operation) error { return nil }
func (nullRateLimiter) Run(ctx context.Context) {} func (nullRequestLimitsHandler) Run(ctx context.Context) {}
func (nullRateLimiter) UpdateConfig(cfg HandlerConfig) {} func (nullRequestLimitsHandler) UpdateConfig(cfg HandlerConfig) {}

View File

@ -1,6 +1,8 @@
package rate package rate
import ( import (
"net"
"net/netip"
"testing" "testing"
"github.com/hashicorp/consul/agent/consul/multilimiter" "github.com/hashicorp/consul/agent/consul/multilimiter"
@ -18,6 +20,7 @@ func TestNewHandlerWithLimiter_CallsUpdateConfig(t *testing.T) {
GlobalWriteConfig: writeCfg, GlobalWriteConfig: writeCfg,
GlobalMode: ModeEnforcing, GlobalMode: ModeEnforcing,
} }
logger := hclog.NewNullLogger() logger := hclog.NewNullLogger()
NewHandlerWithLimiter(*cfg, nil, mockRateLimiter, logger) NewHandlerWithLimiter(*cfg, nil, mockRateLimiter, logger)
mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 2) mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 2)
@ -88,3 +91,59 @@ func TestUpdateConfig(t *testing.T) {
}) })
} }
} }
func TestAllow(t *testing.T) {
readCfg := multilimiter.LimiterConfig{Rate: 100, Burst: 100}
writeCfg := multilimiter.LimiterConfig{Rate: 99, Burst: 99}
type testCase struct {
description string
cfg *HandlerConfig
expectedAllowCalls int
}
testCases := []testCase{
{
description: "RateLimiter does not get called when mode is disabled.",
cfg: &HandlerConfig{
GlobalReadConfig: readCfg,
GlobalWriteConfig: writeCfg,
GlobalMode: ModeDisabled,
},
expectedAllowCalls: 0,
},
{
description: "RateLimiter gets called when mode is permissive.",
cfg: &HandlerConfig{
GlobalReadConfig: readCfg,
GlobalWriteConfig: writeCfg,
GlobalMode: ModePermissive,
},
expectedAllowCalls: 1,
},
{
description: "RateLimiter gets called when mode is enforcing.",
cfg: &HandlerConfig{
GlobalReadConfig: readCfg,
GlobalWriteConfig: writeCfg,
GlobalMode: ModeEnforcing,
},
expectedAllowCalls: 1,
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
mockRateLimiter := multilimiter.NewMockRateLimiter(t)
if tc.expectedAllowCalls > 0 {
mockRateLimiter.On("Allow", mock.Anything).Return(func(entity multilimiter.LimitedEntity) bool { return true })
}
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
logger := hclog.NewNullLogger()
handler := NewHandlerWithLimiter(*tc.cfg, nil, mockRateLimiter, logger)
addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234"))
mockRateLimiter.Calls = nil
handler.Allow(Operation{Name: "test", SourceAddr: addr})
mockRateLimiter.AssertNumberOfCalls(t, "Allow", tc.expectedAllowCalls)
})
}
}

View File

@ -333,7 +333,7 @@ func newServerWithDeps(t *testing.T, c *Config, deps Deps) (*Server, error) {
oldNotify() oldNotify()
} }
} }
grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator, rpcRate.NullRateLimiter()) grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator, rpcRate.NullRequestLimitsHandler())
srv, err := NewServer(c, deps, grpcServer) srv, err := NewServer(c, deps, grpcServer)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -24,7 +24,7 @@ import (
func TestServer_EmitsStats(t *testing.T) { func TestServer_EmitsStats(t *testing.T) {
sink, metricsObj := testutil.NewFakeSink(t) sink, metricsObj := testutil.NewFakeSink(t)
srv := NewServer(hclog.Default(), metricsObj, nil, rate.NullRateLimiter()) srv := NewServer(hclog.Default(), metricsObj, nil, rate.NullRequestLimitsHandler())
testservice.RegisterSimpleServer(srv, &testservice.Simple{}) testservice.RegisterSimpleServer(srv, &testservice.Simple{})

View File

@ -55,7 +55,7 @@ func newPanicTestServer(t *testing.T, logger hclog.Logger, name, dc string, tlsC
func newTestServer(t *testing.T, logger hclog.Logger, name, dc string, tlsConf *tlsutil.Configurator, register func(server *grpc.Server)) testServer { func newTestServer(t *testing.T, logger hclog.Logger, name, dc string, tlsConf *tlsutil.Configurator, register func(server *grpc.Server)) testServer {
addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")}
handler := NewHandler(logger, addr, register, nil, rate.NullRateLimiter()) handler := NewHandler(logger, addr, register, nil, rate.NullRequestLimitsHandler())
lis, err := net.Listen("tcp", "127.0.0.1:0") lis, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err) require.NoError(t, err)

View File

@ -381,7 +381,7 @@ func runTestServer(t *testing.T, server *Server) net.Addr {
pbsubscribe.RegisterStateChangeSubscriptionServer(srv, server) pbsubscribe.RegisterStateChangeSubscriptionServer(srv, server)
}, },
nil, nil,
rate.NullRateLimiter(), rate.NullRequestLimitsHandler(),
) )
lis, err := net.Listen("tcp", "127.0.0.1:0") lis, err := net.Listen("tcp", "127.0.0.1:0")

View File

@ -26,7 +26,7 @@ func TestHandler_EmitsStats(t *testing.T) {
sink, metricsObj := testutil.NewFakeSink(t) sink, metricsObj := testutil.NewFakeSink(t)
addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")}
handler := NewHandler(hclog.Default(), addr, noopRegister, metricsObj, rate.NullRateLimiter()) handler := NewHandler(hclog.Default(), addr, noopRegister, metricsObj, rate.NullRequestLimitsHandler())
testservice.RegisterSimpleServer(handler.srv, &testservice.Simple{}) testservice.RegisterSimpleServer(handler.srv, &testservice.Simple{})

View File

@ -1592,7 +1592,7 @@ func newTestServer(t *testing.T, cb func(conf *consul.Config)) testingServer {
conf.ACLResolverSettings.EnterpriseMeta = *conf.AgentEnterpriseMeta() conf.ACLResolverSettings.EnterpriseMeta = *conf.AgentEnterpriseMeta()
deps := newDefaultDeps(t, conf) deps := newDefaultDeps(t, conf)
externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator, rate.NullRateLimiter()) externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator, rate.NullRequestLimitsHandler())
server, err := consul.NewServer(conf, deps, externalGRPCServer) server, err := consul.NewServer(conf, deps, externalGRPCServer)
require.NoError(t, err) require.NoError(t, err)