mirror of https://github.com/status-im/consul.git
Rate limiting handler - ensure configuration has changed before modifying limiters (#15805)
* Rate limiting handler - ensure configuration has changed before modifying limiters * Updating test to validate arguments to UpdateConfig * Removing duplicate test. Updating mock. * adding logging for when UpdateConfig is called but the config has not changed. * Update agent/consul/rate/handler.go Co-authored-by: Dhia Ayachi <dhia@hashicorp.com> Co-authored-by: Dhia Ayachi <dhia@hashicorp.com>
This commit is contained in:
parent
629878a687
commit
aba43d85d9
|
@ -0,0 +1,53 @@
|
|||
// Code generated by mockery v2.15.0. DO NOT EDIT.
|
||||
|
||||
package multilimiter
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockRateLimiter is an autogenerated mock type for the RateLimiter type
|
||||
type MockRateLimiter struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Allow provides a mock function with given fields: entity
|
||||
func (_m *MockRateLimiter) Allow(entity LimitedEntity) bool {
|
||||
ret := _m.Called(entity)
|
||||
|
||||
var r0 bool
|
||||
if rf, ok := ret.Get(0).(func(LimitedEntity) bool); ok {
|
||||
r0 = rf(entity)
|
||||
} else {
|
||||
r0 = ret.Get(0).(bool)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Run provides a mock function with given fields: ctx
|
||||
func (_m *MockRateLimiter) Run(ctx context.Context) {
|
||||
_m.Called(ctx)
|
||||
}
|
||||
|
||||
// UpdateConfig provides a mock function with given fields: c, prefix
|
||||
func (_m *MockRateLimiter) UpdateConfig(c LimiterConfig, prefix []byte) {
|
||||
_m.Called(c, prefix)
|
||||
}
|
||||
|
||||
type mockConstructorTestingTNewMockRateLimiter interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}
|
||||
|
||||
// NewMockRateLimiter creates a new instance of MockRateLimiter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
func NewMockRateLimiter(t mockConstructorTestingTNewMockRateLimiter) *MockRateLimiter {
|
||||
mock := &MockRateLimiter{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
|
@ -3,11 +3,12 @@ package multilimiter
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
radix "github.com/hashicorp/go-immutable-radix"
|
||||
"golang.org/x/time/rate"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
radix "github.com/hashicorp/go-immutable-radix"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
var _ RateLimiter = &MultiLimiter{}
|
||||
|
@ -23,6 +24,8 @@ func Key(prefix, key []byte) KeyType {
|
|||
}
|
||||
|
||||
// RateLimiter is the interface implemented by MultiLimiter
|
||||
//
|
||||
//go:generate mockery --name RateLimiter --inpackage --filename mock_RateLimiter.go
|
||||
type RateLimiter interface {
|
||||
Run(ctx context.Context)
|
||||
Allow(entity LimitedEntity) bool
|
||||
|
|
|
@ -5,9 +5,11 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/hashicorp/consul/agent/consul/multilimiter"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -114,6 +116,7 @@ type Handler struct {
|
|||
delegate HandlerDelegate
|
||||
|
||||
limiter multilimiter.RateLimiter
|
||||
logger hclog.Logger
|
||||
}
|
||||
|
||||
type HandlerConfig struct {
|
||||
|
@ -140,9 +143,8 @@ type HandlerDelegate interface {
|
|||
IsLeader() bool
|
||||
}
|
||||
|
||||
// NewHandler creates a new RPC rate limit handler.
|
||||
func NewHandler(cfg HandlerConfig, delegate HandlerDelegate) *Handler {
|
||||
limiter := multilimiter.NewMultiLimiter(cfg.Config)
|
||||
func NewHandlerWithLimiter(cfg HandlerConfig, delegate HandlerDelegate,
|
||||
limiter multilimiter.RateLimiter, logger hclog.Logger) *Handler {
|
||||
limiter.UpdateConfig(cfg.GlobalWriteConfig, globalWrite)
|
||||
limiter.UpdateConfig(cfg.GlobalReadConfig, globalRead)
|
||||
|
||||
|
@ -150,12 +152,19 @@ func NewHandler(cfg HandlerConfig, delegate HandlerDelegate) *Handler {
|
|||
cfg: new(atomic.Pointer[HandlerConfig]),
|
||||
delegate: delegate,
|
||||
limiter: limiter,
|
||||
logger: logger,
|
||||
}
|
||||
h.cfg.Store(&cfg)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// NewHandler creates a new RPC rate limit handler.
|
||||
func NewHandler(cfg HandlerConfig, delegate HandlerDelegate, logger hclog.Logger) *Handler {
|
||||
limiter := multilimiter.NewMultiLimiter(cfg.Config)
|
||||
return NewHandlerWithLimiter(cfg, delegate, limiter, logger)
|
||||
}
|
||||
|
||||
// Run the limiter cleanup routine until the given context is canceled.
|
||||
//
|
||||
// Note: this starts a goroutine.
|
||||
|
@ -175,9 +184,18 @@ func (h *Handler) Allow(op Operation) error {
|
|||
}
|
||||
|
||||
func (h *Handler) UpdateConfig(cfg HandlerConfig) {
|
||||
existingCfg := h.cfg.Load()
|
||||
h.cfg.Store(&cfg)
|
||||
h.limiter.UpdateConfig(cfg.GlobalWriteConfig, globalWrite)
|
||||
h.limiter.UpdateConfig(cfg.GlobalReadConfig, globalRead)
|
||||
if reflect.DeepEqual(existingCfg, cfg) {
|
||||
h.logger.Warn("UpdateConfig called but configuration has not changed. Skipping updating the server rate limiter configuration.")
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(existingCfg.GlobalWriteConfig, cfg.GlobalWriteConfig) {
|
||||
h.limiter.UpdateConfig(cfg.GlobalWriteConfig, globalWrite)
|
||||
}
|
||||
if !reflect.DeepEqual(existingCfg.GlobalReadConfig, cfg.GlobalReadConfig) {
|
||||
h.limiter.UpdateConfig(cfg.GlobalReadConfig, globalRead)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
package rate
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent/consul/multilimiter"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestNewHandlerWithLimiter_CallsUpdateConfig(t *testing.T) {
|
||||
mockRateLimiter := multilimiter.NewMockRateLimiter(t)
|
||||
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
|
||||
readCfg := multilimiter.LimiterConfig{Rate: 100, Burst: 100}
|
||||
writeCfg := multilimiter.LimiterConfig{Rate: 99, Burst: 99}
|
||||
cfg := &HandlerConfig{
|
||||
GlobalReadConfig: readCfg,
|
||||
GlobalWriteConfig: writeCfg,
|
||||
GlobalMode: ModeEnforcing,
|
||||
}
|
||||
logger := hclog.NewNullLogger()
|
||||
NewHandlerWithLimiter(*cfg, nil, mockRateLimiter, logger)
|
||||
mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 2)
|
||||
}
|
||||
|
||||
func TestUpdateConfig(t *testing.T) {
|
||||
type testCase struct {
|
||||
description string
|
||||
configModFunc func(cfg *HandlerConfig)
|
||||
assertFunc func(mockRateLimiter *multilimiter.MockRateLimiter, cfg *HandlerConfig)
|
||||
}
|
||||
testCases := []testCase{
|
||||
{
|
||||
description: "RateLimiter does not get updated when config does not change.",
|
||||
configModFunc: func(cfg *HandlerConfig) {},
|
||||
assertFunc: func(mockRateLimiter *multilimiter.MockRateLimiter, cfg *HandlerConfig) {
|
||||
mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 0)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "RateLimiter gets updated when GlobalReadConfig changes.",
|
||||
configModFunc: func(cfg *HandlerConfig) {
|
||||
cfg.GlobalReadConfig.Burst++
|
||||
},
|
||||
assertFunc: func(mockRateLimiter *multilimiter.MockRateLimiter, cfg *HandlerConfig) {
|
||||
mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 1)
|
||||
mockRateLimiter.AssertCalled(t, "UpdateConfig", cfg.GlobalReadConfig, []byte("global.read"))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "RateLimiter gets updated when GlobalWriteConfig changes.",
|
||||
configModFunc: func(cfg *HandlerConfig) {
|
||||
cfg.GlobalWriteConfig.Burst++
|
||||
},
|
||||
assertFunc: func(mockRateLimiter *multilimiter.MockRateLimiter, cfg *HandlerConfig) {
|
||||
mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 1)
|
||||
mockRateLimiter.AssertCalled(t, "UpdateConfig", cfg.GlobalWriteConfig, []byte("global.write"))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "RateLimiter does not get updated when GlobalMode changes.",
|
||||
configModFunc: func(cfg *HandlerConfig) {
|
||||
cfg.GlobalMode = ModePermissive
|
||||
},
|
||||
assertFunc: func(mockRateLimiter *multilimiter.MockRateLimiter, cfg *HandlerConfig) {
|
||||
mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 0)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
readCfg := multilimiter.LimiterConfig{Rate: 100, Burst: 100}
|
||||
writeCfg := multilimiter.LimiterConfig{Rate: 99, Burst: 99}
|
||||
cfg := &HandlerConfig{
|
||||
GlobalReadConfig: readCfg,
|
||||
GlobalWriteConfig: writeCfg,
|
||||
GlobalMode: ModeEnforcing,
|
||||
}
|
||||
mockRateLimiter := multilimiter.NewMockRateLimiter(t)
|
||||
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
|
||||
logger := hclog.NewNullLogger()
|
||||
handler := NewHandlerWithLimiter(*cfg, nil, mockRateLimiter, logger)
|
||||
mockRateLimiter.Calls = nil
|
||||
tc.configModFunc(cfg)
|
||||
handler.UpdateConfig(*cfg)
|
||||
tc.assertFunc(mockRateLimiter, cfg)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -480,7 +480,7 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Ser
|
|||
WriteRate: config.RequestLimitsWriteRate,
|
||||
}
|
||||
|
||||
s.incomingRPCLimiter = rpcRate.NewHandler(*s.convertConsulConfigToRateLimitHandlerConfig(*limitsConfig, mlCfg), s)
|
||||
s.incomingRPCLimiter = rpcRate.NewHandler(*s.convertConsulConfigToRateLimitHandlerConfig(*limitsConfig, mlCfg), s, s.logger)
|
||||
}
|
||||
s.incomingRPCLimiter.Run(&lib.StopChannelContext{StopCh: s.shutdownCh})
|
||||
|
||||
|
|
Loading…
Reference in New Issue