From a6482341a59c47df7566ec84031d39cf7a158784 Mon Sep 17 00:00:00 2001 From: Semir Patel Date: Wed, 4 Jan 2023 13:38:44 -0600 Subject: [PATCH] Wire up the rate limiter to net/rpc calls (#15879) --- agent/consul/rate/handler.go | 3 +++ .../consul/rate/mock_RequestLimitsHandler.go | 5 +++++ agent/consul/server.go | 20 +++++++++++-------- agent/rpc/middleware/interceptors.go | 18 +++++++++++++++++ 4 files changed, 38 insertions(+), 8 deletions(-) diff --git a/agent/consul/rate/handler.go b/agent/consul/rate/handler.go index 2e0708c767..bfdb587f4f 100644 --- a/agent/consul/rate/handler.go +++ b/agent/consul/rate/handler.go @@ -109,6 +109,7 @@ type RequestLimitsHandler interface { Run(ctx context.Context) Allow(op Operation) error UpdateConfig(cfg HandlerConfig) + Register(leaderStatusProvider LeaderStatusProvider) } // Handler enforces rate limits for incoming RPCs. @@ -310,3 +311,5 @@ func (nullRequestLimitsHandler) Allow(Operation) error { return nil } func (nullRequestLimitsHandler) Run(ctx context.Context) {} func (nullRequestLimitsHandler) UpdateConfig(cfg HandlerConfig) {} + +func (nullRequestLimitsHandler) Register(leaderStatusProvider LeaderStatusProvider) {} diff --git a/agent/consul/rate/mock_RequestLimitsHandler.go b/agent/consul/rate/mock_RequestLimitsHandler.go index 02569e56e5..9ff0e3baaf 100644 --- a/agent/consul/rate/mock_RequestLimitsHandler.go +++ b/agent/consul/rate/mock_RequestLimitsHandler.go @@ -37,6 +37,11 @@ func (_m *MockRequestLimitsHandler) UpdateConfig(cfg HandlerConfig) { _m.Called(cfg) } +// Register provides a mock function with given fields: leaderStatusProvider +func (_m *MockRequestLimitsHandler) Register(leaderStatusProvider LeaderStatusProvider) { + _m.Called(leaderStatusProvider) +} + type mockConstructorTestingTNewMockRequestLimitsHandler interface { mock.TestingT Cleanup(func()) diff --git a/agent/consul/server.go b/agent/consul/server.go index 8580de9ba4..b67e59c108 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -469,14 +469,14 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server, incom incomingRPCLimiter: incomingRPCLimiter, } + incomingRPCLimiter.Register(s) + s.hcpManager = hcp.NewManager(hcp.ManagerConfig{ Client: flat.HCP.Client, StatusFn: s.hcpServerStatus(flat), Logger: logger.Named("hcp_manager"), }) - s.incomingRPCLimiter.Run(&lib.StopChannelContext{StopCh: s.shutdownCh}) - var recorder *middleware.RequestRecorder if flat.NewRequestRecorderFunc != nil { recorder = flat.NewRequestRecorderFunc(serverLogger, s.IsLeader, s.config.Datacenter) @@ -487,15 +487,19 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server, incom return nil, fmt.Errorf("cannot initialize server with a nil RPC request recorder") } - if flat.GetNetRPCInterceptorFunc == nil { - s.rpcServer = rpc.NewServer() - s.insecureRPCServer = rpc.NewServer() - } else { - s.rpcServer = rpc.NewServerWithOpts(rpc.WithServerServiceCallInterceptor(flat.GetNetRPCInterceptorFunc(recorder))) - s.insecureRPCServer = rpc.NewServerWithOpts(rpc.WithServerServiceCallInterceptor(flat.GetNetRPCInterceptorFunc(recorder))) + rpcServerOpts := []func(*rpc.Server){ + rpc.WithPreBodyInterceptor(middleware.GetNetRPCRateLimitingInterceptor(s.incomingRPCLimiter)), } + if flat.GetNetRPCInterceptorFunc != nil { + rpcServerOpts = append(rpcServerOpts, rpc.WithServerServiceCallInterceptor(flat.GetNetRPCInterceptorFunc(recorder))) + } + + s.rpcServer = rpc.NewServerWithOpts(rpcServerOpts...) + s.insecureRPCServer = rpc.NewServerWithOpts(rpcServerOpts...) + s.rpcRecorder = recorder + s.incomingRPCLimiter.Run(&lib.StopChannelContext{StopCh: s.shutdownCh}) go s.publisher.Run(&lib.StopChannelContext{StopCh: s.shutdownCh}) diff --git a/agent/rpc/middleware/interceptors.go b/agent/rpc/middleware/interceptors.go index 6abcf0a443..a4aa432d62 100644 --- a/agent/rpc/middleware/interceptors.go +++ b/agent/rpc/middleware/interceptors.go @@ -1,6 +1,7 @@ package middleware import ( + "net" "reflect" "strconv" "strings" @@ -9,6 +10,7 @@ import ( "github.com/armon/go-metrics" "github.com/armon/go-metrics/prometheus" "github.com/hashicorp/consul-net-rpc/net/rpc" + rpcRate "github.com/hashicorp/consul/agent/consul/rate" "github.com/hashicorp/go-hclog" ) @@ -157,3 +159,19 @@ func GetNetRPCInterceptor(recorder *RequestRecorder) rpc.ServerServiceCallInterc recorder.Record(reqServiceMethod, RPCTypeNetRPC, reqStart, argv.Interface(), err != nil) } } + +func GetNetRPCRateLimitingInterceptor(requestLimitsHandler rpcRate.RequestLimitsHandler) rpc.PreBodyInterceptor { + + return func(reqServiceMethod string, sourceAddr net.Addr) error { + op := rpcRate.Operation{ + Name: reqServiceMethod, + SourceAddr: sourceAddr, + Type: rpcRateLimitSpecs[reqServiceMethod], + } + + // net/rpc does not provide a way to encode the nuances of the + // error response (retry or retry elsewhere) so the error string + // from the rate limiter is all that we have. + return requestLimitsHandler.Allow(op) + } +}