add necessary plumbing to implement per server ip based rate limiting (#17436)

This commit is contained in:
Dhia Ayachi 2023-05-23 15:37:01 -04:00 committed by GitHub
parent 304d641fb1
commit f526dfd0ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 113 additions and 81 deletions

View File

@ -8,6 +8,7 @@ import (
"context"
"errors"
"fmt"
"github.com/hashicorp/consul/agent/metadata"
"net"
"reflect"
"sync/atomic"
@ -153,14 +154,14 @@ type RequestLimitsHandler interface {
Allow(op Operation) error
UpdateConfig(cfg HandlerConfig)
UpdateIPConfig(cfg IPLimitConfig)
Register(leaderStatusProvider LeaderStatusProvider)
Register(serversStatusProvider ServersStatusProvider)
}
// Handler enforces rate limits for incoming RPCs.
type Handler struct {
globalCfg *atomic.Pointer[HandlerConfig]
ipCfg *atomic.Pointer[IPLimitConfig]
leaderStatusProvider LeaderStatusProvider
globalCfg *atomic.Pointer[HandlerConfig]
ipCfg *atomic.Pointer[IPLimitConfig]
serversStatusProvider ServersStatusProvider
limiter multilimiter.RateLimiter
@ -186,13 +187,14 @@ type HandlerConfig struct {
GlobalLimitConfig GlobalLimitConfig
}
//go:generate mockery --name LeaderStatusProvider --inpackage --filename mock_LeaderStatusProvider_test.go
type LeaderStatusProvider interface {
//go:generate mockery --name ServersStatusProvider --inpackage --filename mock_ServersStatusProvider_test.go
type ServersStatusProvider interface {
// IsLeader is used to determine whether the operation is being performed
// against the cluster leader, such that if it can _only_ be performed by
// the leader (e.g. write operations) we don't tell clients to retry against
// a different server.
IsLeader() bool
IsServer(addr string) bool
}
func isInfRate(cfg multilimiter.LimiterConfig) bool {
@ -237,11 +239,11 @@ func (h *Handler) Run(ctx context.Context) {
// because of an exhausted rate-limit.
func (h *Handler) Allow(op Operation) error {
if h.leaderStatusProvider == nil {
h.logger.Error("leaderStatusProvider required to be set via Register(). bailing on rate limiter")
if h.serversStatusProvider == nil {
h.logger.Error("serversStatusProvider required to be set via Register(). bailing on rate limiter")
return nil
// TODO: panic and make sure to use the server's recovery handler
// panic("leaderStatusProvider required to be set via Register(..)")
// panic("serversStatusProvider required to be set via Register(..)")
}
cfg := h.globalCfg.Load()
@ -249,7 +251,7 @@ func (h *Handler) Allow(op Operation) error {
return nil
}
allow, throttledLimits := h.allowAllLimits(h.limits(op))
allow, throttledLimits := h.allowAllLimits(h.limits(op), h.serversStatusProvider.IsServer(string(metadata.GetIP(op.SourceAddr))))
if !allow {
for _, l := range throttledLimits {
@ -277,7 +279,7 @@ func (h *Handler) Allow(op Operation) error {
})
if enforced {
if h.leaderStatusProvider.IsLeader() && op.Type == OperationTypeWrite {
if h.serversStatusProvider.IsLeader() && op.Type == OperationTypeWrite {
return ErrRetryLater
}
return ErrRetryElsewhere
@ -305,17 +307,18 @@ func (h *Handler) UpdateConfig(cfg HandlerConfig) {
}
func (h *Handler) Register(leaderStatusProvider LeaderStatusProvider) {
h.leaderStatusProvider = leaderStatusProvider
func (h *Handler) Register(serversStatusProvider ServersStatusProvider) {
h.serversStatusProvider = serversStatusProvider
}
type limit struct {
mode Mode
ent multilimiter.LimitedEntity
desc string
mode Mode
ent multilimiter.LimitedEntity
desc string
applyOnServer bool
}
func (h *Handler) allowAllLimits(limits []limit) (bool, []limit) {
func (h *Handler) allowAllLimits(limits []limit, isServer bool) (bool, []limit) {
allow := true
throttledLimits := make([]limit, 0)
@ -324,6 +327,10 @@ func (h *Handler) allowAllLimits(limits []limit) (bool, []limit) {
continue
}
if isServer && !l.applyOnServer {
continue
}
if !h.limiter.Allow(l.ent) {
throttledLimits = append(throttledLimits, l)
allow = false
@ -358,7 +365,7 @@ func (h *Handler) globalLimit(op Operation) *limit {
}
cfg := h.globalCfg.Load()
lim := &limit{mode: cfg.GlobalLimitConfig.Mode}
lim := &limit{mode: cfg.GlobalLimitConfig.Mode, applyOnServer: true}
switch op.Type {
case OperationTypeRead:
lim.desc = "global/read"
@ -409,4 +416,4 @@ func (nullRequestLimitsHandler) Run(_ context.Context) {}
func (nullRequestLimitsHandler) UpdateConfig(_ HandlerConfig) {}
func (nullRequestLimitsHandler) Register(_ LeaderStatusProvider) {}
func (nullRequestLimitsHandler) Register(_ ServersStatusProvider) {}

View File

@ -19,22 +19,6 @@ import (
"github.com/hashicorp/consul/agent/consul/multilimiter"
)
//
// Revisit test when handler.go:189 TODO implemented
//
// func TestHandler_Allow_PanicsWhenLeaderStatusProviderNotRegistered(t *testing.T) {
// defer func() {
// err := recover()
// if err == nil {
// t.Fatal("Run should panic")
// }
// }()
// handler := NewHandler(HandlerConfig{}, hclog.NewNullLogger())
// handler.Allow(Operation{})
// // intentionally skip handler.Register(...)
// }
func TestHandler(t *testing.T) {
var (
rpcName = "Foo.Bar"
@ -50,6 +34,7 @@ func TestHandler(t *testing.T) {
globalMode Mode
checks []limitCheck
isLeader bool
isServer bool
expectErr error
expectLog bool
expectMetric bool
@ -230,8 +215,9 @@ func TestHandler(t *testing.T) {
limiter.On("Allow", mock.Anything).Return(c.allow)
}
leaderStatusProvider := NewMockLeaderStatusProvider(t)
leaderStatusProvider.On("IsLeader").Return(tc.isLeader).Maybe()
serversStatusProvider := NewMockServersStatusProvider(t)
serversStatusProvider.On("IsLeader").Return(tc.isLeader).Maybe()
serversStatusProvider.On("IsServer", mock.Anything).Return(tc.isServer).Maybe()
var output bytes.Buffer
logger := hclog.NewInterceptLogger(&hclog.LoggerOptions{
@ -252,7 +238,7 @@ func TestHandler(t *testing.T) {
limiter,
logger,
)
handler.Register(leaderStatusProvider)
handler.Register(serversStatusProvider)
require.Equal(t, tc.expectErr, handler.Allow(tc.op))
@ -426,8 +412,9 @@ func TestAllow(t *testing.T) {
}
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
logger := hclog.NewNullLogger()
delegate := NewMockLeaderStatusProvider(t)
delegate := NewMockServersStatusProvider(t)
delegate.On("IsLeader").Return(true).Maybe()
delegate.On("IsServer", mock.Anything).Return(false).Maybe()
handler := NewHandlerWithLimiter(*tc.cfg, mockRateLimiter, logger)
handler.Register(delegate)
addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234"))

View File

@ -1,39 +0,0 @@
// Code generated by mockery v2.20.0. DO NOT EDIT.
package rate
import mock "github.com/stretchr/testify/mock"
// MockLeaderStatusProvider is an autogenerated mock type for the LeaderStatusProvider type
type MockLeaderStatusProvider struct {
mock.Mock
}
// IsLeader provides a mock function with given fields:
func (_m *MockLeaderStatusProvider) IsLeader() bool {
ret := _m.Called()
var r0 bool
if rf, ok := ret.Get(0).(func() bool); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
type mockConstructorTestingTNewMockLeaderStatusProvider interface {
mock.TestingT
Cleanup(func())
}
// NewMockLeaderStatusProvider creates a new instance of MockLeaderStatusProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewMockLeaderStatusProvider(t mockConstructorTestingTNewMockLeaderStatusProvider) *MockLeaderStatusProvider {
mock := &MockLeaderStatusProvider{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -27,9 +27,9 @@ func (_m *MockRequestLimitsHandler) Allow(op Operation) error {
return r0
}
// Register provides a mock function with given fields: leaderStatusProvider
func (_m *MockRequestLimitsHandler) Register(leaderStatusProvider LeaderStatusProvider) {
_m.Called(leaderStatusProvider)
// Register provides a mock function with given fields: serversStatusProvider
func (_m *MockRequestLimitsHandler) Register(serversStatusProvider ServersStatusProvider) {
_m.Called(serversStatusProvider)
}
// Run provides a mock function with given fields: ctx

View File

@ -0,0 +1,53 @@
// Code generated by mockery v2.20.0. DO NOT EDIT.
package rate
import mock "github.com/stretchr/testify/mock"
// MockServersStatusProvider is an autogenerated mock type for the ServersStatusProvider type
type MockServersStatusProvider struct {
mock.Mock
}
// IsLeader provides a mock function with given fields:
func (_m *MockServersStatusProvider) IsLeader() bool {
ret := _m.Called()
var r0 bool
if rf, ok := ret.Get(0).(func() bool); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// IsServer provides a mock function with given fields: addr
func (_m *MockServersStatusProvider) IsServer(addr string) bool {
ret := _m.Called(addr)
var r0 bool
if rf, ok := ret.Get(0).(func(string) bool); ok {
r0 = rf(addr)
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
type mockConstructorTestingTNewMockServersStatusProvider interface {
mock.TestingT
Cleanup(func())
}
// NewMockServersStatusProvider creates a new instance of MockServersStatusProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewMockServersStatusProvider(t mockConstructorTestingTNewMockServersStatusProvider) *MockServersStatusProvider {
mock := &MockServersStatusProvider{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -1660,6 +1660,20 @@ func (s *Server) IsLeader() bool {
return s.raft.State() == raft.Leader
}
// IsServer checks if this addr is of a server
func (s *Server) IsServer(addr string) bool {
for _, s := range s.raft.GetConfiguration().Configuration().Servers {
a, err := net.ResolveTCPAddr("tcp", string(s.Address))
if err != nil {
continue
}
if string(metadata.GetIP(a)) == addr {
return true
}
}
return false
}
// LeaderLastContact returns the time of last contact by a leader.
// This only makes sense if we are currently a follower.
func (s *Server) LeaderLastContact() time.Time {

View File

@ -221,3 +221,13 @@ func AddFeatureFlags(tags map[string]string, flags ...string) {
tags[featureFlagPrefix+flag] = "1"
}
}
func GetIP(addr net.Addr) []byte {
switch a := addr.(type) {
case *net.UDPAddr:
return []byte(a.IP.String())
case *net.TCPAddr:
return []byte(a.IP.String())
}
return []byte{}
}