diff --git a/agent/consul/client_test.go b/agent/consul/client_test.go index 06ef80efeb..d593f5aa9c 100644 --- a/agent/consul/client_test.go +++ b/agent/consul/client_test.go @@ -21,6 +21,7 @@ import ( "github.com/hashicorp/consul/agent/grpc/private/resolver" "github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/router" + "github.com/hashicorp/consul/agent/rpc/middleware" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/token" "github.com/hashicorp/consul/sdk/freeport" @@ -542,8 +543,10 @@ func newDefaultDeps(t *testing.T, c *Config) Deps { DialingFromServer: true, DialingFromDatacenter: c.Datacenter, }), - LeaderForwarder: builder, - EnterpriseDeps: newDefaultDepsEnterprise(t, logger, c), + LeaderForwarder: builder, + NewRequestRecorderFunc: middleware.NewRequestRecorder, + GetNetRPCInterceptorFunc: middleware.GetNetRPCInterceptor, + EnterpriseDeps: newDefaultDepsEnterprise(t, logger, c), } } diff --git a/agent/consul/options.go b/agent/consul/options.go index 3440b02450..e253864a56 100644 --- a/agent/consul/options.go +++ b/agent/consul/options.go @@ -1,11 +1,13 @@ package consul import ( + "github.com/hashicorp/consul-net-rpc/net/rpc" "github.com/hashicorp/go-hclog" "google.golang.org/grpc" "github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/router" + "github.com/hashicorp/consul/agent/rpc/middleware" "github.com/hashicorp/consul/agent/token" "github.com/hashicorp/consul/tlsutil" ) @@ -18,6 +20,12 @@ type Deps struct { ConnPool *pool.ConnPool GRPCConnPool GRPCClientConner LeaderForwarder LeaderForwarder + // GetNetRPCInterceptorFunc, if not nil, sets the net/rpc rpc.ServerServiceCallInterceptor on + // the server side to record metrics around the RPC requests. If nil, no interceptor is added to + // the rpc server. 
+ GetNetRPCInterceptorFunc func(recorder *middleware.RequestRecorder) rpc.ServerServiceCallInterceptor + // NewRequestRecorderFunc provides a middleware.RequestRecorder for the server to use; it cannot be nil + NewRequestRecorderFunc func(logger hclog.Logger) *middleware.RequestRecorder EnterpriseDeps } diff --git a/agent/consul/server.go b/agent/consul/server.go index c48204bb51..3ec3d61dde 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -379,24 +379,40 @@ func NewServer(config *Config, flat Deps, publicGRPCServer *grpc.Server) (*Serve serverLogger := flat.Logger.NamedIntercept(logging.ConsulServer) loggers := newLoggerStore(serverLogger) - recorder := middleware.NewRequestRecorder(serverLogger) + var recorder *middleware.RequestRecorder + if flat.NewRequestRecorderFunc == nil { + return nil, fmt.Errorf("cannot initialize server without an RPC request recorder provider") + } + recorder = flat.NewRequestRecorderFunc(serverLogger) + if recorder == nil { + return nil, fmt.Errorf("cannot initialize server without a non nil RPC request recorder") + } + + var rpcServer, insecureRPCServer *rpc.Server + if flat.GetNetRPCInterceptorFunc == nil { + rpcServer = rpc.NewServer() + insecureRPCServer = rpc.NewServer() + } else { + rpcServer = rpc.NewServerWithOpts(rpc.WithServerServiceCallInterceptor(flat.GetNetRPCInterceptorFunc(recorder))) + insecureRPCServer = rpc.NewServerWithOpts(rpc.WithServerServiceCallInterceptor(flat.GetNetRPCInterceptorFunc(recorder))) + } + // Create server. 
s := &Server{ - config: config, - tokens: flat.Tokens, - connPool: flat.ConnPool, - grpcConnPool: flat.GRPCConnPool, - eventChLAN: make(chan serf.Event, serfEventChSize), - eventChWAN: make(chan serf.Event, serfEventChSize), - logger: serverLogger, - loggers: loggers, - leaveCh: make(chan struct{}), - reconcileCh: make(chan serf.Member, reconcileChSize), - router: flat.Router, - rpcRecorder: recorder, - // TODO(rpc-metrics-improv): consider pulling out the interceptor from config in order to isolate testing - rpcServer: rpc.NewServerWithOpts(rpc.WithServerServiceCallInterceptor(middleware.GetNetRPCInterceptor(recorder))), - insecureRPCServer: rpc.NewServerWithOpts(rpc.WithServerServiceCallInterceptor(middleware.GetNetRPCInterceptor(recorder))), + config: config, + tokens: flat.Tokens, + connPool: flat.ConnPool, + grpcConnPool: flat.GRPCConnPool, + eventChLAN: make(chan serf.Event, serfEventChSize), + eventChWAN: make(chan serf.Event, serfEventChSize), + logger: serverLogger, + loggers: loggers, + leaveCh: make(chan struct{}), + reconcileCh: make(chan serf.Member, reconcileChSize), + router: flat.Router, + rpcRecorder: recorder, + rpcServer: rpcServer, + insecureRPCServer: insecureRPCServer, tlsConfigurator: flat.TLSConfigurator, publicGRPCServer: publicGRPCServer, reassertLeaderCh: make(chan chan error), diff --git a/agent/consul/server_test.go b/agent/consul/server_test.go index 6f953dd1c7..5c06fb4d96 100644 --- a/agent/consul/server_test.go +++ b/agent/consul/server_test.go @@ -5,16 +5,20 @@ import ( "fmt" "net" "os" + "reflect" "strings" "sync/atomic" "testing" "time" + "github.com/armon/go-metrics" "github.com/google/tcpproxy" + "github.com/hashicorp/go-hclog" "github.com/hashicorp/memberlist" "github.com/hashicorp/raft" "google.golang.org/grpc" + "github.com/hashicorp/consul/agent/rpc/middleware" "github.com/hashicorp/consul/ipaddr" "github.com/hashicorp/go-uuid" @@ -254,6 +258,10 @@ func testACLServerWithConfig(t *testing.T, cb func(*Config), 
initReplicationToke } func newServer(t *testing.T, c *Config) (*Server, error) { + return newServerWithDeps(t, c, newDefaultDeps(t, c)) +} + +func newServerWithDeps(t *testing.T, c *Config, deps Deps) (*Server, error) { // chain server up notification oldNotify := c.NotifyListen up := make(chan struct{}) @@ -264,7 +272,8 @@ func newServer(t *testing.T, c *Config) (*Server, error) { } } - srv, err := NewServer(c, newDefaultDeps(t, c), grpc.NewServer()) + srv, err := NewServer(c, deps, grpc.NewServer()) + if err != nil { return nil, err } @@ -1130,6 +1139,218 @@ func TestServer_RPC(t *testing.T) { } } +// TestServer_RPC_MetricsIntercept_Off proves that we can turn off net/rpc interceptors altogether. +func TestServer_RPC_MetricsIntercept_Off(t *testing.T) { + if testing.Short() { + t.Skip("too slow for testing.Short") + } + + storage := make(map[string]float32) + keyMakingFunc := func(key []string, labels []metrics.Label) string { + allKey := strings.Join(key, "+") + + for _, label := range labels { + if label.Name == "method" { + allKey = allKey + "+" + label.Value + } + } + + return allKey + } + + simpleRecorderFunc := func(key []string, val float32, labels []metrics.Label) { + storage[keyMakingFunc(key, labels)] = val + } + + t.Run("test no net/rpc interceptor metric with nil func", func(t *testing.T) { + _, conf := testServerConfig(t) + deps := newDefaultDeps(t, conf) + + // "disable" metrics net/rpc interceptor + deps.GetNetRPCInterceptorFunc = nil + // "hijack" the rpc recorder for asserts; + // note that there will be "internal" net/rpc calls made + // that will still show up; those don't go through the net/rpc interceptor; + // see consul.agent.rpc.middleware.RPCTypeInternal for context + deps.NewRequestRecorderFunc = func(logger hclog.Logger) *middleware.RequestRecorder { + return &middleware.RequestRecorder{ + Logger: hclog.NewInterceptLogger(&hclog.LoggerOptions{}), + RecorderFunc: simpleRecorderFunc, + } + } + + s1, err := NewServer(conf, deps, 
grpc.NewServer()) + if err != nil { + t.Fatalf("err: %v", err) + } + t.Cleanup(func() { s1.Shutdown() }) + + var out struct{} + if err := s1.RPC("Status.Ping", struct{}{}, &out); err != nil { + t.Fatalf("err: %v", err) + } + + key := keyMakingFunc(middleware.OneTwelveRPCSummary[0].Name, []metrics.Label{{Name: "method", Value: "Status.Ping"}}) + + if _, ok := storage[key]; ok { + t.Fatalf("Did not expect to find key %s in the metrics log, ", key) + } + }) + + t.Run("test no net/rpc interceptor metric with func that gives nil", func(t *testing.T) { + _, conf := testServerConfig(t) + deps := newDefaultDeps(t, conf) + + // "hijack" the rpc recorder for asserts; + // note that there will be "internal" net/rpc calls made + // that will still show up; those don't go through the net/rpc interceptor; + // see consul.agent.rpc.middleware.RPCTypeInternal for context + deps.NewRequestRecorderFunc = func(logger hclog.Logger) *middleware.RequestRecorder { + return &middleware.RequestRecorder{ + Logger: hclog.NewInterceptLogger(&hclog.LoggerOptions{}), + RecorderFunc: simpleRecorderFunc, + } + } + + deps.GetNetRPCInterceptorFunc = func(recorder *middleware.RequestRecorder) rpc.ServerServiceCallInterceptor { + return nil + } + + s2, err := NewServer(conf, deps, grpc.NewServer()) + if err != nil { + t.Fatalf("err: %v", err) + } + t.Cleanup(func() { s2.Shutdown() }) + + var out struct{} + if err := s2.RPC("Status.Ping", struct{}{}, &out); err != nil { + t.Fatalf("err: %v", err) + } + + key := keyMakingFunc(middleware.OneTwelveRPCSummary[0].Name, []metrics.Label{{Name: "method", Value: "Status.Ping"}}) + + if _, ok := storage[key]; ok { + t.Fatalf("Did not expect to find key %s in the metrics log, ", key) + } + }) +} + +// TestServer_RPC_RequestRecorder proves that we cannot make a server without a valid RequestRecorder provider func +// or a non-nil RequestRecorder. 
+func TestServer_RPC_RequestRecorder(t *testing.T) { + if testing.Short() { + t.Skip("too slow for testing.Short") + } + + t.Run("test nil func provider", func(t *testing.T) { + _, conf := testServerConfig(t) + deps := newDefaultDeps(t, conf) + deps.NewRequestRecorderFunc = nil + + s1, err := NewServer(conf, deps, grpc.NewServer()) + + t.Cleanup(func() { + if s1 != nil { + s1.Shutdown() + } + }) + + require.Error(t, err, "need err when provider func is nil") + require.EqualError(t, err, "cannot initialize server without an RPC request recorder provider") + }) + + t.Run("test nil RequestRecorder", func(t *testing.T) { + _, conf := testServerConfig(t) + deps := newDefaultDeps(t, conf) + deps.NewRequestRecorderFunc = func(logger hclog.Logger) *middleware.RequestRecorder { + return nil + } + + s2, err := NewServer(conf, deps, grpc.NewServer()) + + t.Cleanup(func() { + if s2 != nil { + s2.Shutdown() + } + }) + + require.Error(t, err, "need err when RequestRecorder is nil") + require.EqualError(t, err, "cannot initialize server without a non nil RPC request recorder") + }) +} + +// TestServer_RPC_MetricsIntercept mocks a request recorder and asserts that RPC calls are observed. 
+func TestServer_RPC_MetricsIntercept(t *testing.T) { + if testing.Short() { + t.Skip("too slow for testing.Short") + } + + _, conf := testServerConfig(t) + deps := newDefaultDeps(t, conf) + + // The method used to record metric observations here is similar to that used in + // interceptors_test.go; at present, we don't have a need to lock yet but if we do + // we can imitate that set up further or just factor it out as a "mock" metrics backend + storage := make(map[string]float32) + keyMakingFunc := func(key []string, labels []metrics.Label) string { + allKey := strings.Join(key, "+") + + for _, label := range labels { + allKey = allKey + "+" + label.Value + } + + return allKey + } + + simpleRecorderFunc := func(key []string, val float32, labels []metrics.Label) { + storage[keyMakingFunc(key, labels)] = val + } + deps.NewRequestRecorderFunc = func(logger hclog.Logger) *middleware.RequestRecorder { + return &middleware.RequestRecorder{ + Logger: hclog.NewInterceptLogger(&hclog.LoggerOptions{}), + RecorderFunc: simpleRecorderFunc, + } + } + + deps.GetNetRPCInterceptorFunc = func(recorder *middleware.RequestRecorder) rpc.ServerServiceCallInterceptor { + return func(reqServiceMethod string, argv, replyv reflect.Value, handler func() error) { + reqStart := time.Now() + + err := handler() + + recorder.Record(reqServiceMethod, "test", reqStart, argv.Interface(), err != nil) + } + } + + s, err := newServerWithDeps(t, conf, deps) + if err != nil { + t.Fatalf("err: %v", err) + } + defer s.Shutdown() + testrpc.WaitForTestAgent(t, s.RPC, "dc1") + + // asserts + t.Run("test happy path for metrics interceptor", func(t *testing.T) { + var out struct{} + if err := s.RPC("Status.Ping", struct{}{}, &out); err != nil { + t.Fatalf("err: %v", err) + } + + expectedLabels := []metrics.Label{ + {Name: "method", Value: "Status.Ping"}, + {Name: "errored", Value: "false"}, + {Name: "request_type", Value: "read"}, + {Name: "rpc_type", Value: "test"}, + } + + key := 
keyMakingFunc(middleware.OneTwelveRPCSummary[0].Name, expectedLabels) + + if _, ok := storage[key]; !ok { + t.Fatalf("Did not find key %s in the metrics log, ", key) + } + }) +} + func TestServer_JoinLAN_TLS(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") diff --git a/agent/rpc/middleware/interceptors.go b/agent/rpc/middleware/interceptors.go index e57ab647dc..ba6747c3a7 100644 --- a/agent/rpc/middleware/interceptors.go +++ b/agent/rpc/middleware/interceptors.go @@ -34,11 +34,11 @@ var OneTwelveRPCSummary = []prometheus.SummaryDefinition{ type RequestRecorder struct { Logger hclog.Logger - recorderFunc func(key []string, val float32, labels []metrics.Label) + RecorderFunc func(key []string, val float32, labels []metrics.Label) } func NewRequestRecorder(logger hclog.Logger) *RequestRecorder { - return &RequestRecorder{Logger: logger, recorderFunc: metrics.AddSampleWithLabels} + return &RequestRecorder{Logger: logger, RecorderFunc: metrics.AddSampleWithLabels} } func (r *RequestRecorder) Record(requestName string, rpcType string, start time.Time, request interface{}, respErrored bool) { @@ -53,7 +53,7 @@ func (r *RequestRecorder) Record(requestName string, rpcType string, start time. } // math.MaxInt64 < math.MaxFloat32 is true so we should be good! - r.recorderFunc(metricRPCRequest, float32(elapsed), labels) + r.RecorderFunc(metricRPCRequest, float32(elapsed), labels) r.Logger.Trace(requestLogName, "method", requestName, "errored", respErrored, diff --git a/agent/rpc/middleware/interceptors_test.go b/agent/rpc/middleware/interceptors_test.go index 23d7649622..63fbefecb3 100644 --- a/agent/rpc/middleware/interceptors_test.go +++ b/agent/rpc/middleware/interceptors_test.go @@ -18,7 +18,7 @@ type obs struct { labels []metrics.Label } -// recorderStore acts as an in-mem mock storage for all the RequestRecorder.Record() recorderFunc calls. +// recorderStore acts as an in-mem mock storage for all the RequestRecorder.Record() RecorderFunc calls. 
type recorderStore struct { lock sync.Mutex store map[string]obs @@ -59,9 +59,11 @@ func (wr writeRequest) IsRead() bool { // TestRequestRecorder_SimpleOK tests that the RequestRecorder can record a simple request. func TestRequestRecorder_SimpleOK(t *testing.T) { + t.Parallel() + r := RequestRecorder{ Logger: hclog.NewInterceptLogger(&hclog.LoggerOptions{}), - recorderFunc: simpleRecorderFunc, + RecorderFunc: simpleRecorderFunc, } start := time.Now() @@ -82,9 +84,11 @@ func TestRequestRecorder_SimpleOK(t *testing.T) { // TestRequestRecorder_ReadRequest tests that RequestRecorder can record a read request AND a responseErrored arg. func TestRequestRecorder_ReadRequest(t *testing.T) { + t.Parallel() + r := RequestRecorder{ Logger: hclog.NewInterceptLogger(&hclog.LoggerOptions{}), - recorderFunc: simpleRecorderFunc, + RecorderFunc: simpleRecorderFunc, } start := time.Now() @@ -104,9 +108,11 @@ func TestRequestRecorder_ReadRequest(t *testing.T) { // TestRequestRecorder_WriteRequest tests that RequestRecorder can record a write request. 
func TestRequestRecorder_WriteRequest(t *testing.T) { + t.Parallel() + r := RequestRecorder{ Logger: hclog.NewInterceptLogger(&hclog.LoggerOptions{}), - recorderFunc: simpleRecorderFunc, + RecorderFunc: simpleRecorderFunc, } start := time.Now() diff --git a/agent/setup.go b/agent/setup.go index 0799c472a3..322f170b25 100644 --- a/agent/setup.go +++ b/agent/setup.go @@ -25,6 +25,7 @@ import ( "github.com/hashicorp/consul/agent/local" "github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/router" + "github.com/hashicorp/consul/agent/rpc/middleware" "github.com/hashicorp/consul/agent/submatview" "github.com/hashicorp/consul/agent/token" "github.com/hashicorp/consul/agent/xds" @@ -151,6 +152,9 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer) (BaseDeps, error) return d, err } + d.NewRequestRecorderFunc = middleware.NewRequestRecorder + d.GetNetRPCInterceptorFunc = middleware.GetNetRPCInterceptor + return d, nil }