mirror of https://github.com/status-im/consul.git
add necessary plumbing to implement per server ip based rate limiting (#17436)
This commit is contained in:
parent
304d641fb1
commit
f526dfd0ac
|
@ -8,6 +8,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/hashicorp/consul/agent/metadata"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync/atomic"
|
||||
|
@ -153,14 +154,14 @@ type RequestLimitsHandler interface {
|
|||
Allow(op Operation) error
|
||||
UpdateConfig(cfg HandlerConfig)
|
||||
UpdateIPConfig(cfg IPLimitConfig)
|
||||
Register(leaderStatusProvider LeaderStatusProvider)
|
||||
Register(serversStatusProvider ServersStatusProvider)
|
||||
}
|
||||
|
||||
// Handler enforces rate limits for incoming RPCs.
|
||||
type Handler struct {
|
||||
globalCfg *atomic.Pointer[HandlerConfig]
|
||||
ipCfg *atomic.Pointer[IPLimitConfig]
|
||||
leaderStatusProvider LeaderStatusProvider
|
||||
globalCfg *atomic.Pointer[HandlerConfig]
|
||||
ipCfg *atomic.Pointer[IPLimitConfig]
|
||||
serversStatusProvider ServersStatusProvider
|
||||
|
||||
limiter multilimiter.RateLimiter
|
||||
|
||||
|
@ -186,13 +187,14 @@ type HandlerConfig struct {
|
|||
GlobalLimitConfig GlobalLimitConfig
|
||||
}
|
||||
|
||||
//go:generate mockery --name LeaderStatusProvider --inpackage --filename mock_LeaderStatusProvider_test.go
|
||||
type LeaderStatusProvider interface {
|
||||
//go:generate mockery --name ServersStatusProvider --inpackage --filename mock_ServersStatusProvider_test.go
|
||||
type ServersStatusProvider interface {
|
||||
// IsLeader is used to determine whether the operation is being performed
|
||||
// against the cluster leader, such that if it can _only_ be performed by
|
||||
// the leader (e.g. write operations) we don't tell clients to retry against
|
||||
// a different server.
|
||||
IsLeader() bool
|
||||
IsServer(addr string) bool
|
||||
}
|
||||
|
||||
func isInfRate(cfg multilimiter.LimiterConfig) bool {
|
||||
|
@ -237,11 +239,11 @@ func (h *Handler) Run(ctx context.Context) {
|
|||
// because of an exhausted rate-limit.
|
||||
func (h *Handler) Allow(op Operation) error {
|
||||
|
||||
if h.leaderStatusProvider == nil {
|
||||
h.logger.Error("leaderStatusProvider required to be set via Register(). bailing on rate limiter")
|
||||
if h.serversStatusProvider == nil {
|
||||
h.logger.Error("serversStatusProvider required to be set via Register(). bailing on rate limiter")
|
||||
return nil
|
||||
// TODO: panic and make sure to use the server's recovery handler
|
||||
// panic("leaderStatusProvider required to be set via Register(..)")
|
||||
// panic("serversStatusProvider required to be set via Register(..)")
|
||||
}
|
||||
|
||||
cfg := h.globalCfg.Load()
|
||||
|
@ -249,7 +251,7 @@ func (h *Handler) Allow(op Operation) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
allow, throttledLimits := h.allowAllLimits(h.limits(op))
|
||||
allow, throttledLimits := h.allowAllLimits(h.limits(op), h.serversStatusProvider.IsServer(string(metadata.GetIP(op.SourceAddr))))
|
||||
|
||||
if !allow {
|
||||
for _, l := range throttledLimits {
|
||||
|
@ -277,7 +279,7 @@ func (h *Handler) Allow(op Operation) error {
|
|||
})
|
||||
|
||||
if enforced {
|
||||
if h.leaderStatusProvider.IsLeader() && op.Type == OperationTypeWrite {
|
||||
if h.serversStatusProvider.IsLeader() && op.Type == OperationTypeWrite {
|
||||
return ErrRetryLater
|
||||
}
|
||||
return ErrRetryElsewhere
|
||||
|
@ -305,17 +307,18 @@ func (h *Handler) UpdateConfig(cfg HandlerConfig) {
|
|||
|
||||
}
|
||||
|
||||
func (h *Handler) Register(leaderStatusProvider LeaderStatusProvider) {
|
||||
h.leaderStatusProvider = leaderStatusProvider
|
||||
func (h *Handler) Register(serversStatusProvider ServersStatusProvider) {
|
||||
h.serversStatusProvider = serversStatusProvider
|
||||
}
|
||||
|
||||
type limit struct {
|
||||
mode Mode
|
||||
ent multilimiter.LimitedEntity
|
||||
desc string
|
||||
mode Mode
|
||||
ent multilimiter.LimitedEntity
|
||||
desc string
|
||||
applyOnServer bool
|
||||
}
|
||||
|
||||
func (h *Handler) allowAllLimits(limits []limit) (bool, []limit) {
|
||||
func (h *Handler) allowAllLimits(limits []limit, isServer bool) (bool, []limit) {
|
||||
allow := true
|
||||
throttledLimits := make([]limit, 0)
|
||||
|
||||
|
@ -324,6 +327,10 @@ func (h *Handler) allowAllLimits(limits []limit) (bool, []limit) {
|
|||
continue
|
||||
}
|
||||
|
||||
if isServer && !l.applyOnServer {
|
||||
continue
|
||||
}
|
||||
|
||||
if !h.limiter.Allow(l.ent) {
|
||||
throttledLimits = append(throttledLimits, l)
|
||||
allow = false
|
||||
|
@ -358,7 +365,7 @@ func (h *Handler) globalLimit(op Operation) *limit {
|
|||
}
|
||||
cfg := h.globalCfg.Load()
|
||||
|
||||
lim := &limit{mode: cfg.GlobalLimitConfig.Mode}
|
||||
lim := &limit{mode: cfg.GlobalLimitConfig.Mode, applyOnServer: true}
|
||||
switch op.Type {
|
||||
case OperationTypeRead:
|
||||
lim.desc = "global/read"
|
||||
|
@ -409,4 +416,4 @@ func (nullRequestLimitsHandler) Run(_ context.Context) {}
|
|||
|
||||
func (nullRequestLimitsHandler) UpdateConfig(_ HandlerConfig) {}
|
||||
|
||||
func (nullRequestLimitsHandler) Register(_ LeaderStatusProvider) {}
|
||||
func (nullRequestLimitsHandler) Register(_ ServersStatusProvider) {}
|
||||
|
|
|
@ -19,22 +19,6 @@ import (
|
|||
"github.com/hashicorp/consul/agent/consul/multilimiter"
|
||||
)
|
||||
|
||||
//
|
||||
// Revisit test when handler.go:189 TODO implemented
|
||||
//
|
||||
// func TestHandler_Allow_PanicsWhenLeaderStatusProviderNotRegistered(t *testing.T) {
|
||||
// defer func() {
|
||||
// err := recover()
|
||||
// if err == nil {
|
||||
// t.Fatal("Run should panic")
|
||||
// }
|
||||
// }()
|
||||
|
||||
// handler := NewHandler(HandlerConfig{}, hclog.NewNullLogger())
|
||||
// handler.Allow(Operation{})
|
||||
// // intentionally skip handler.Register(...)
|
||||
// }
|
||||
|
||||
func TestHandler(t *testing.T) {
|
||||
var (
|
||||
rpcName = "Foo.Bar"
|
||||
|
@ -50,6 +34,7 @@ func TestHandler(t *testing.T) {
|
|||
globalMode Mode
|
||||
checks []limitCheck
|
||||
isLeader bool
|
||||
isServer bool
|
||||
expectErr error
|
||||
expectLog bool
|
||||
expectMetric bool
|
||||
|
@ -230,8 +215,9 @@ func TestHandler(t *testing.T) {
|
|||
limiter.On("Allow", mock.Anything).Return(c.allow)
|
||||
}
|
||||
|
||||
leaderStatusProvider := NewMockLeaderStatusProvider(t)
|
||||
leaderStatusProvider.On("IsLeader").Return(tc.isLeader).Maybe()
|
||||
serversStatusProvider := NewMockServersStatusProvider(t)
|
||||
serversStatusProvider.On("IsLeader").Return(tc.isLeader).Maybe()
|
||||
serversStatusProvider.On("IsServer", mock.Anything).Return(tc.isServer).Maybe()
|
||||
|
||||
var output bytes.Buffer
|
||||
logger := hclog.NewInterceptLogger(&hclog.LoggerOptions{
|
||||
|
@ -252,7 +238,7 @@ func TestHandler(t *testing.T) {
|
|||
limiter,
|
||||
logger,
|
||||
)
|
||||
handler.Register(leaderStatusProvider)
|
||||
handler.Register(serversStatusProvider)
|
||||
|
||||
require.Equal(t, tc.expectErr, handler.Allow(tc.op))
|
||||
|
||||
|
@ -426,8 +412,9 @@ func TestAllow(t *testing.T) {
|
|||
}
|
||||
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
|
||||
logger := hclog.NewNullLogger()
|
||||
delegate := NewMockLeaderStatusProvider(t)
|
||||
delegate := NewMockServersStatusProvider(t)
|
||||
delegate.On("IsLeader").Return(true).Maybe()
|
||||
delegate.On("IsServer", mock.Anything).Return(false).Maybe()
|
||||
handler := NewHandlerWithLimiter(*tc.cfg, mockRateLimiter, logger)
|
||||
handler.Register(delegate)
|
||||
addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234"))
|
||||
|
|
|
@ -1,39 +0,0 @@
|
|||
// Code generated by mockery v2.20.0. DO NOT EDIT.
|
||||
|
||||
package rate
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// MockLeaderStatusProvider is an autogenerated mock type for the LeaderStatusProvider type
|
||||
type MockLeaderStatusProvider struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// IsLeader provides a mock function with given fields:
|
||||
func (_m *MockLeaderStatusProvider) IsLeader() bool {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 bool
|
||||
if rf, ok := ret.Get(0).(func() bool); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(bool)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
type mockConstructorTestingTNewMockLeaderStatusProvider interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}
|
||||
|
||||
// NewMockLeaderStatusProvider creates a new instance of MockLeaderStatusProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
func NewMockLeaderStatusProvider(t mockConstructorTestingTNewMockLeaderStatusProvider) *MockLeaderStatusProvider {
|
||||
mock := &MockLeaderStatusProvider{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
|
@ -27,9 +27,9 @@ func (_m *MockRequestLimitsHandler) Allow(op Operation) error {
|
|||
return r0
|
||||
}
|
||||
|
||||
// Register provides a mock function with given fields: leaderStatusProvider
|
||||
func (_m *MockRequestLimitsHandler) Register(leaderStatusProvider LeaderStatusProvider) {
|
||||
_m.Called(leaderStatusProvider)
|
||||
// Register provides a mock function with given fields: serversStatusProvider
|
||||
func (_m *MockRequestLimitsHandler) Register(serversStatusProvider ServersStatusProvider) {
|
||||
_m.Called(serversStatusProvider)
|
||||
}
|
||||
|
||||
// Run provides a mock function with given fields: ctx
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
// Code generated by mockery v2.20.0. DO NOT EDIT.
|
||||
|
||||
package rate
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// MockServersStatusProvider is an autogenerated mock type for the ServersStatusProvider type
|
||||
type MockServersStatusProvider struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// IsLeader provides a mock function with given fields:
|
||||
func (_m *MockServersStatusProvider) IsLeader() bool {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 bool
|
||||
if rf, ok := ret.Get(0).(func() bool); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(bool)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// IsServer provides a mock function with given fields: addr
|
||||
func (_m *MockServersStatusProvider) IsServer(addr string) bool {
|
||||
ret := _m.Called(addr)
|
||||
|
||||
var r0 bool
|
||||
if rf, ok := ret.Get(0).(func(string) bool); ok {
|
||||
r0 = rf(addr)
|
||||
} else {
|
||||
r0 = ret.Get(0).(bool)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
type mockConstructorTestingTNewMockServersStatusProvider interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}
|
||||
|
||||
// NewMockServersStatusProvider creates a new instance of MockServersStatusProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
func NewMockServersStatusProvider(t mockConstructorTestingTNewMockServersStatusProvider) *MockServersStatusProvider {
|
||||
mock := &MockServersStatusProvider{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
|
@ -1660,6 +1660,20 @@ func (s *Server) IsLeader() bool {
|
|||
return s.raft.State() == raft.Leader
|
||||
}
|
||||
|
||||
// IsServer checks if this addr is of a server
|
||||
func (s *Server) IsServer(addr string) bool {
|
||||
for _, s := range s.raft.GetConfiguration().Configuration().Servers {
|
||||
a, err := net.ResolveTCPAddr("tcp", string(s.Address))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if string(metadata.GetIP(a)) == addr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// LeaderLastContact returns the time of last contact by a leader.
|
||||
// This only makes sense if we are currently a follower.
|
||||
func (s *Server) LeaderLastContact() time.Time {
|
||||
|
|
|
@ -221,3 +221,13 @@ func AddFeatureFlags(tags map[string]string, flags ...string) {
|
|||
tags[featureFlagPrefix+flag] = "1"
|
||||
}
|
||||
}
|
||||
|
||||
func GetIP(addr net.Addr) []byte {
|
||||
switch a := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
return []byte(a.IP.String())
|
||||
case *net.TCPAddr:
|
||||
return []byte(a.IP.String())
|
||||
}
|
||||
return []byte{}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue