mirror of
https://github.com/status-im/consul.git
synced 2025-01-11 06:16:08 +00:00
Add rate limiting to RPCs sent within a server instance too (#5927)
This commit is contained in:
parent
3517e47ad1
commit
acfcc7daf4
@ -18,6 +18,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
metrics "github.com/armon/go-metrics"
|
||||
ca "github.com/hashicorp/consul/agent/connect/ca"
|
||||
"github.com/hashicorp/consul/agent/consul/autopilot"
|
||||
"github.com/hashicorp/consul/agent/consul/fsm"
|
||||
@ -34,6 +35,7 @@ import (
|
||||
"github.com/hashicorp/raft"
|
||||
raftboltdb "github.com/hashicorp/raft-boltdb"
|
||||
"github.com/hashicorp/serf/serf"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// These are the protocol versions that Consul can _understand_. These are
|
||||
@ -206,6 +208,10 @@ type Server struct {
|
||||
// Enterprise user-defined areas.
|
||||
router *router.Router
|
||||
|
||||
// rpcLimiter is used to rate limit the total number of RPCs initiated
|
||||
// from an agent.
|
||||
rpcLimiter atomic.Value
|
||||
|
||||
// Listener is used to listen for incoming connections
|
||||
Listener net.Listener
|
||||
rpcServer *rpc.Server
|
||||
@ -360,6 +366,8 @@ func NewServerLogger(config *Config, logger *log.Logger, tokens *token.Store, tl
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.rpcLimiter.Store(rate.NewLimiter(config.RPCRate, config.RPCMaxBurst))
|
||||
|
||||
configReplicatorConfig := ReplicatorConfig{
|
||||
Name: "Config Entry",
|
||||
ReplicateFn: s.replicateConfig,
|
||||
@ -1028,6 +1036,19 @@ func (s *Server) RPC(method string, args interface{}, reply interface{}) error {
|
||||
args: args,
|
||||
reply: reply,
|
||||
}
|
||||
|
||||
// Enforce the RPC limit.
|
||||
//
|
||||
// "client" metric path because the internal client API is calling to the
|
||||
// internal server API. It's odd that the same request directed to a server is
|
||||
// recorded differently. On the other hand this possibly masks the different
|
||||
// between regular client requests that traverse the network and these which
|
||||
// don't (unless forwarded). This still seems most sane.
|
||||
metrics.IncrCounter([]string{"client", "rpc"}, 1)
|
||||
if !s.rpcLimiter.Load().(*rate.Limiter).Allow() {
|
||||
metrics.IncrCounter([]string{"client", "rpc", "exceeded"}, 1)
|
||||
return structs.ErrRPCRateExceeded
|
||||
}
|
||||
if err := s.rpcServer.ServeRequest(codec); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -1039,6 +1060,19 @@ func (s *Server) RPC(method string, args interface{}, reply interface{}) error {
|
||||
func (s *Server) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io.Writer,
|
||||
replyFn structs.SnapshotReplyFn) error {
|
||||
|
||||
// Enforce the RPC limit.
|
||||
//
|
||||
// "client" metric path because the internal client API is calling to the
|
||||
// internal server API. It's odd that the same request directed to a server is
|
||||
// recorded differently. On the other hand this possibly masks the different
|
||||
// between regular client requests that traverse the network and these which
|
||||
// don't (unless forwarded). This still seems most sane.
|
||||
metrics.IncrCounter([]string{"client", "rpc"}, 1)
|
||||
if !s.rpcLimiter.Load().(*rate.Limiter).Allow() {
|
||||
metrics.IncrCounter([]string{"client", "rpc", "exceeded"}, 1)
|
||||
return structs.ErrRPCRateExceeded
|
||||
}
|
||||
|
||||
// Perform the operation.
|
||||
var reply structs.SnapshotResponse
|
||||
snap, err := s.dispatchSnapshotRequest(args, in, &reply)
|
||||
@ -1141,6 +1175,8 @@ func (s *Server) GetLANCoordinate() (lib.CoordinateSet, error) {
|
||||
// ReloadConfig is used to have the Server do an online reload of
|
||||
// relevant configuration information
|
||||
func (s *Server) ReloadConfig(config *Config) error {
|
||||
s.rpcLimiter.Store(rate.NewLimiter(config.RPCRate, config.RPCMaxBurst))
|
||||
|
||||
if s.IsLeader() {
|
||||
// only bootstrap the config entries if we are the leader
|
||||
// this will error if we lose leadership while bootstrapping here.
|
||||
|
@ -21,6 +21,7 @@ import (
|
||||
"github.com/hashicorp/consul/tlsutil"
|
||||
"github.com/hashicorp/consul/types"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -988,6 +989,8 @@ func TestServer_Reload(t *testing.T) {
|
||||
|
||||
dir1, s := testServerWithConfig(t, func(c *Config) {
|
||||
c.Build = "1.5.0"
|
||||
c.RPCRate = 500
|
||||
c.RPCMaxBurst = 5000
|
||||
})
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s.Shutdown()
|
||||
@ -998,6 +1001,14 @@ func TestServer_Reload(t *testing.T) {
|
||||
global_entry_init,
|
||||
}
|
||||
|
||||
limiter := s.rpcLimiter.Load().(*rate.Limiter)
|
||||
require.Equal(t, rate.Limit(500), limiter.Limit())
|
||||
require.Equal(t, 5000, limiter.Burst())
|
||||
|
||||
// Change rate limit
|
||||
s.config.RPCRate = 1000
|
||||
s.config.RPCMaxBurst = 10000
|
||||
|
||||
s.ReloadConfig(s.config)
|
||||
|
||||
_, entry, err := s.fsm.State().ConfigEntry(nil, structs.ProxyDefaults, structs.ProxyConfigGlobal)
|
||||
@ -1008,4 +1019,30 @@ func TestServer_Reload(t *testing.T) {
|
||||
require.Equal(t, global_entry_init.Kind, global.Kind)
|
||||
require.Equal(t, global_entry_init.Name, global.Name)
|
||||
require.Equal(t, global_entry_init.Config, global.Config)
|
||||
|
||||
// Check rate limiter got updated
|
||||
limiter = s.rpcLimiter.Load().(*rate.Limiter)
|
||||
require.Equal(t, rate.Limit(1000), limiter.Limit())
|
||||
require.Equal(t, 10000, limiter.Burst())
|
||||
}
|
||||
|
||||
func TestServer_RPC_RateLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir1, conf1 := testServerConfig(t)
|
||||
conf1.RPCRate = 2
|
||||
conf1.RPCMaxBurst = 2
|
||||
s1, err := NewServer(conf1)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
retry.Run(t, func(r *retry.R) {
|
||||
var out struct{}
|
||||
if err := s1.RPC("Status.Ping", struct{}{}, &out); err != structs.ErrRPCRateExceeded {
|
||||
r.Fatalf("err: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user