agent: make the RPC endpoint overwrite mechanism more transparent

This patch hides the RPC handler overwrite mechanism from the
rest of the code so that it works in all cases and that there
is no cooperation required from the tested code, i.e. we can
drop a.getEndpoint().
This commit is contained in:
Frank Schroeder 2017-06-16 09:54:09 +02:00 committed by Frank Schröder
parent e15f9f9d90
commit 2b41f2e3a3
8 changed files with 68 additions and 83 deletions

View File

@ -174,7 +174,7 @@ func (m *aclManager) lookupACL(a *Agent, id string) (acl.ACL, error) {
args.ETag = cached.ETag
}
var reply structs.ACLPolicy
err := a.RPC(a.getEndpoint("ACL")+".GetPolicy", &args, &reply)
err := a.RPC("ACL.GetPolicy", &args, &reply)
if err != nil {
if strings.Contains(err.Error(), aclDisabled) {
a.logger.Printf("[DEBUG] agent: ACLs disabled on servers, will check again after %s", a.config.ACLDisabledTTL)

View File

@ -47,7 +47,7 @@ func TestACL_Version8(t *testing.T) {
defer a.Shutdown()
m := MockServer{}
if err := a.InjectEndpoint("ACL", &m); err != nil {
if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -70,7 +70,7 @@ func TestACL_Disabled(t *testing.T) {
defer a.Shutdown()
m := MockServer{}
if err := a.InjectEndpoint("ACL", &m); err != nil {
if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -123,7 +123,7 @@ func TestACL_Special_IDs(t *testing.T) {
defer a.Shutdown()
m := MockServer{}
if err := a.InjectEndpoint("ACL", &m); err != nil {
if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -176,7 +176,7 @@ func TestACL_Down_Deny(t *testing.T) {
defer a.Shutdown()
m := MockServer{}
if err := a.InjectEndpoint("ACL", &m); err != nil {
if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -206,7 +206,7 @@ func TestACL_Down_Allow(t *testing.T) {
defer a.Shutdown()
m := MockServer{}
if err := a.InjectEndpoint("ACL", &m); err != nil {
if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -236,7 +236,7 @@ func TestACL_Down_Extend(t *testing.T) {
defer a.Shutdown()
m := MockServer{}
if err := a.InjectEndpoint("ACL", &m); err != nil {
if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -313,7 +313,7 @@ func TestACL_Cache(t *testing.T) {
defer a.Shutdown()
m := MockServer{}
if err := a.InjectEndpoint("ACL", &m); err != nil {
if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -495,7 +495,7 @@ func TestACL_vetServiceRegister(t *testing.T) {
defer a.Shutdown()
m := MockServer{catalogPolicy}
if err := a.InjectEndpoint("ACL", &m); err != nil {
if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -541,7 +541,7 @@ func TestACL_vetServiceUpdate(t *testing.T) {
defer a.Shutdown()
m := MockServer{catalogPolicy}
if err := a.InjectEndpoint("ACL", &m); err != nil {
if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -577,7 +577,7 @@ func TestACL_vetCheckRegister(t *testing.T) {
defer a.Shutdown()
m := MockServer{catalogPolicy}
if err := a.InjectEndpoint("ACL", &m); err != nil {
if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -660,7 +660,7 @@ func TestACL_vetCheckUpdate(t *testing.T) {
defer a.Shutdown()
m := MockServer{catalogPolicy}
if err := a.InjectEndpoint("ACL", &m); err != nil {
if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -716,7 +716,7 @@ func TestACL_filterMembers(t *testing.T) {
defer a.Shutdown()
m := MockServer{catalogPolicy}
if err := a.InjectEndpoint("ACL", &m); err != nil {
if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -752,7 +752,7 @@ func TestACL_filterServices(t *testing.T) {
defer a.Shutdown()
m := MockServer{catalogPolicy}
if err := a.InjectEndpoint("ACL", &m); err != nil {
if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -783,7 +783,7 @@ func TestACL_filterChecks(t *testing.T) {
defer a.Shutdown()
m := MockServer{catalogPolicy}
if err := a.InjectEndpoint("ACL", &m); err != nil {
if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err)
}

View File

@ -14,7 +14,6 @@ import (
"net/http"
"os"
"path/filepath"
"reflect"
"regexp"
"strconv"
"strings"
@ -146,9 +145,9 @@ type Agent struct {
// attempts.
retryJoinCh chan error
// endpoints lets you override RPC endpoints for testing. Not all
// agent methods use this, so use with care and never override
// outside of a unit test.
// endpoints maps unique RPC endpoint names to common ones
// to allow overriding of RPC handlers since the golang
// net/rpc server does not allow this.
endpoints map[string]string
endpointsLock sync.RWMutex
@ -1068,9 +1067,34 @@ LOAD:
return nil
}
// RegisterEndpoint registers a handler for the consul RPC server
// under a unique name while making it accessible under the provided
// name. This allows overwriting handlers for the golang net/rpc
// service which does not allow this.
func (a *Agent) RegisterEndpoint(name string, handler interface{}) error {
srv, ok := a.delegate.(*consul.Server)
if !ok {
panic("agent must be a server")
}
realname := fmt.Sprintf("%s-%d", name, time.Now().UnixNano())
a.endpointsLock.Lock()
a.endpoints[name] = realname
a.endpointsLock.Unlock()
return srv.RegisterEndpoint(realname, handler)
}
// RPC is used to make an RPC call to the Consul servers
// This allows the agent to implement the Consul.Interface
func (a *Agent) RPC(method string, args interface{}, reply interface{}) error {
a.endpointsLock.Lock()
// fast path: only translate if there are overrides
if len(a.endpoints) > 0 {
p := strings.SplitN(method, ".", 2)
if e := a.endpoints[p[0]]; e != "" {
method = e + "." + p[1]
}
}
a.endpointsLock.Unlock()
return a.delegate.RPC(method, args, reply)
}
@ -2255,37 +2279,6 @@ func (a *Agent) DisableNodeMaintenance() {
a.logger.Printf("[INFO] agent: Node left maintenance mode")
}
// InjectEndpoint overrides the given endpoint with a substitute one. Note
// that not all agent methods use this mechanism, and that is should only
// be used for testing.
func (a *Agent) InjectEndpoint(endpoint string, handler interface{}) error {
srv, ok := a.delegate.(*consul.Server)
if !ok {
return fmt.Errorf("agent must be a server")
}
if err := srv.InjectEndpoint(handler); err != nil {
return err
}
name := reflect.Indirect(reflect.ValueOf(handler)).Type().Name()
a.endpointsLock.Lock()
a.endpoints[endpoint] = name
a.endpointsLock.Unlock()
a.logger.Printf("[WARN] agent: endpoint injected; this should only be used for testing")
return nil
}
// getEndpoint returns the endpoint name to use for the given endpoint,
// which may be overridden.
func (a *Agent) getEndpoint(endpoint string) string {
a.endpointsLock.RLock()
defer a.endpointsLock.RUnlock()
if override, ok := a.endpoints[endpoint]; ok {
return override
}
return endpoint
}
func (a *Agent) ReloadConfig(newCfg *Config) (bool, error) {
var errs error

View File

@ -977,10 +977,10 @@ func (s *Server) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io
return nil
}
// InjectEndpoint is used to substitute an endpoint for testing.
func (s *Server) InjectEndpoint(endpoint interface{}) error {
// RegisterEndpoint is used to substitute an endpoint for testing.
func (s *Server) RegisterEndpoint(name string, handler interface{}) error {
s.logger.Printf("[WARN] consul: endpoint injected; this should only be used for testing")
return s.rpcServer.Register(endpoint)
return s.rpcServer.RegisterName(name, handler)
}
// Stats is used to return statistics for debugging and insight

View File

@ -695,10 +695,9 @@ func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, req,
// likely work in practice, like 10*maxUDPAnswerLimit which should help
// reduce bandwidth if there are thousands of nodes available.
endpoint := d.agent.getEndpoint(preparedQueryEndpoint)
var out structs.PreparedQueryExecuteResponse
RPC:
if err := d.agent.RPC(endpoint+".Execute", &args, &out); err != nil {
if err := d.agent.RPC("PreparedQuery.Execute", &args, &out); err != nil {
// If they give a bogus query name, treat that as a name error,
// not a full on server error. We have to use a string compare
// here since the RPC layer loses the type information.

View File

@ -3932,7 +3932,7 @@ func TestDNS_PreparedQuery_AllowStale(t *testing.T) {
defer a.Shutdown()
m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -4013,7 +4013,7 @@ func TestDNS_PreparedQuery_AgentSource(t *testing.T) {
defer a.Shutdown()
m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}

View File

@ -37,8 +37,7 @@ func (s *HTTPServer) preparedQueryCreate(resp http.ResponseWriter, req *http.Req
}
var reply string
endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
if err := s.agent.RPC(endpoint+".Apply", &args, &reply); err != nil {
if err := s.agent.RPC("PreparedQuery.Apply", &args, &reply); err != nil {
return nil, err
}
return preparedQueryCreateResponse{reply}, nil
@ -52,8 +51,7 @@ func (s *HTTPServer) preparedQueryList(resp http.ResponseWriter, req *http.Reque
}
var reply structs.IndexedPreparedQueries
endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
if err := s.agent.RPC(endpoint+".List", &args, &reply); err != nil {
if err := s.agent.RPC("PreparedQuery.List", &args, &reply); err != nil {
return nil, err
}
@ -110,8 +108,7 @@ func (s *HTTPServer) preparedQueryExecute(id string, resp http.ResponseWriter, r
}
var reply structs.PreparedQueryExecuteResponse
endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
if err := s.agent.RPC(endpoint+".Execute", &args, &reply); err != nil {
if err := s.agent.RPC("PreparedQuery.Execute", &args, &reply); err != nil {
// We have to check the string since the RPC sheds
// the specific error type.
if err.Error() == consul.ErrQueryNotFound.Error() {
@ -155,8 +152,7 @@ func (s *HTTPServer) preparedQueryExplain(id string, resp http.ResponseWriter, r
}
var reply structs.PreparedQueryExplainResponse
endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
if err := s.agent.RPC(endpoint+".Explain", &args, &reply); err != nil {
if err := s.agent.RPC("PreparedQuery.Explain", &args, &reply); err != nil {
// We have to check the string since the RPC sheds
// the specific error type.
if err.Error() == consul.ErrQueryNotFound.Error() {
@ -179,8 +175,7 @@ func (s *HTTPServer) preparedQueryGet(id string, resp http.ResponseWriter, req *
}
var reply structs.IndexedPreparedQueries
endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
if err := s.agent.RPC(endpoint+".Get", &args, &reply); err != nil {
if err := s.agent.RPC("PreparedQuery.Get", &args, &reply); err != nil {
// We have to check the string since the RPC sheds
// the specific error type.
if err.Error() == consul.ErrQueryNotFound.Error() {
@ -212,8 +207,7 @@ func (s *HTTPServer) preparedQueryUpdate(id string, resp http.ResponseWriter, re
args.Query.ID = id
var reply string
endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
if err := s.agent.RPC(endpoint+".Apply", &args, &reply); err != nil {
if err := s.agent.RPC("PreparedQuery.Apply", &args, &reply); err != nil {
return nil, err
}
return nil, nil
@ -231,8 +225,7 @@ func (s *HTTPServer) preparedQueryDelete(id string, resp http.ResponseWriter, re
s.parseToken(req, &args.Token)
var reply string
endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
if err := s.agent.RPC(endpoint+".Apply", &args, &reply); err != nil {
if err := s.agent.RPC("PreparedQuery.Apply", &args, &reply); err != nil {
return nil, err
}
return nil, nil

View File

@ -74,7 +74,7 @@ func TestPreparedQuery_Create(t *testing.T) {
defer a.Shutdown()
m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -159,7 +159,7 @@ func TestPreparedQuery_List(t *testing.T) {
defer a.Shutdown()
m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -192,7 +192,7 @@ func TestPreparedQuery_List(t *testing.T) {
defer a.Shutdown()
m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -242,7 +242,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
defer a.Shutdown()
m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -275,7 +275,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
defer a.Shutdown()
m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -331,7 +331,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
defer a.Shutdown()
m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -365,7 +365,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
defer a.Shutdown()
m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -415,7 +415,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
defer a.Shutdown()
m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -479,7 +479,7 @@ func TestPreparedQuery_Explain(t *testing.T) {
defer a.Shutdown()
m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -552,7 +552,7 @@ func TestPreparedQuery_Get(t *testing.T) {
defer a.Shutdown()
m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -617,7 +617,7 @@ func TestPreparedQuery_Update(t *testing.T) {
defer a.Shutdown()
m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -695,7 +695,7 @@ func TestPreparedQuery_Delete(t *testing.T) {
defer a.Shutdown()
m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}