diff --git a/agent/acl.go b/agent/acl.go index 871488b62f..b84f40ffca 100644 --- a/agent/acl.go +++ b/agent/acl.go @@ -174,7 +174,7 @@ func (m *aclManager) lookupACL(a *Agent, id string) (acl.ACL, error) { args.ETag = cached.ETag } var reply structs.ACLPolicy - err := a.RPC(a.getEndpoint("ACL")+".GetPolicy", &args, &reply) + err := a.RPC("ACL.GetPolicy", &args, &reply) if err != nil { if strings.Contains(err.Error(), aclDisabled) { a.logger.Printf("[DEBUG] agent: ACLs disabled on servers, will check again after %s", a.config.ACLDisabledTTL) diff --git a/agent/acl_test.go b/agent/acl_test.go index 59bf99b1a7..58f0f76d14 100644 --- a/agent/acl_test.go +++ b/agent/acl_test.go @@ -47,7 +47,7 @@ func TestACL_Version8(t *testing.T) { defer a.Shutdown() m := MockServer{} - if err := a.InjectEndpoint("ACL", &m); err != nil { + if err := a.RegisterEndpoint("ACL", &m); err != nil { t.Fatalf("err: %v", err) } @@ -70,7 +70,7 @@ func TestACL_Disabled(t *testing.T) { defer a.Shutdown() m := MockServer{} - if err := a.InjectEndpoint("ACL", &m); err != nil { + if err := a.RegisterEndpoint("ACL", &m); err != nil { t.Fatalf("err: %v", err) } @@ -123,7 +123,7 @@ func TestACL_Special_IDs(t *testing.T) { defer a.Shutdown() m := MockServer{} - if err := a.InjectEndpoint("ACL", &m); err != nil { + if err := a.RegisterEndpoint("ACL", &m); err != nil { t.Fatalf("err: %v", err) } @@ -176,7 +176,7 @@ func TestACL_Down_Deny(t *testing.T) { defer a.Shutdown() m := MockServer{} - if err := a.InjectEndpoint("ACL", &m); err != nil { + if err := a.RegisterEndpoint("ACL", &m); err != nil { t.Fatalf("err: %v", err) } @@ -206,7 +206,7 @@ func TestACL_Down_Allow(t *testing.T) { defer a.Shutdown() m := MockServer{} - if err := a.InjectEndpoint("ACL", &m); err != nil { + if err := a.RegisterEndpoint("ACL", &m); err != nil { t.Fatalf("err: %v", err) } @@ -236,7 +236,7 @@ func TestACL_Down_Extend(t *testing.T) { defer a.Shutdown() m := MockServer{} - if err := a.InjectEndpoint("ACL", &m); err != nil { + if err := a.RegisterEndpoint("ACL", &m); err != nil { t.Fatalf("err: %v", err) } @@ -313,7 +313,7 @@ func TestACL_Cache(t *testing.T) { defer a.Shutdown() m := MockServer{} - if err := a.InjectEndpoint("ACL", &m); err != nil { + if err := a.RegisterEndpoint("ACL", &m); err != nil { t.Fatalf("err: %v", err) } @@ -495,7 +495,7 @@ func TestACL_vetServiceRegister(t *testing.T) { defer a.Shutdown() m := MockServer{catalogPolicy} - if err := a.InjectEndpoint("ACL", &m); err != nil { + if err := a.RegisterEndpoint("ACL", &m); err != nil { t.Fatalf("err: %v", err) } @@ -541,7 +541,7 @@ func TestACL_vetServiceUpdate(t *testing.T) { defer a.Shutdown() m := MockServer{catalogPolicy} - if err := a.InjectEndpoint("ACL", &m); err != nil { + if err := a.RegisterEndpoint("ACL", &m); err != nil { t.Fatalf("err: %v", err) } @@ -577,7 +577,7 @@ func TestACL_vetCheckRegister(t *testing.T) { defer a.Shutdown() m := MockServer{catalogPolicy} - if err := a.InjectEndpoint("ACL", &m); err != nil { + if err := a.RegisterEndpoint("ACL", &m); err != nil { t.Fatalf("err: %v", err) } @@ -660,7 +660,7 @@ func TestACL_vetCheckUpdate(t *testing.T) { defer a.Shutdown() m := MockServer{catalogPolicy} - if err := a.InjectEndpoint("ACL", &m); err != nil { + if err := a.RegisterEndpoint("ACL", &m); err != nil { t.Fatalf("err: %v", err) } @@ -716,7 +716,7 @@ func TestACL_filterMembers(t *testing.T) { defer a.Shutdown() m := MockServer{catalogPolicy} - if err := a.InjectEndpoint("ACL", &m); err != nil { + if err := a.RegisterEndpoint("ACL", &m); err != nil { t.Fatalf("err: %v", err) } @@ -752,7 +752,7 @@ func TestACL_filterServices(t *testing.T) { defer a.Shutdown() m := MockServer{catalogPolicy} - if err := a.InjectEndpoint("ACL", &m); err != nil { + if err := a.RegisterEndpoint("ACL", &m); err != nil { t.Fatalf("err: %v", err) } @@ -783,7 +783,7 @@ func TestACL_filterChecks(t *testing.T) { defer a.Shutdown() m := MockServer{catalogPolicy} - if err := a.InjectEndpoint("ACL", &m); err != nil { + if err := a.RegisterEndpoint("ACL", &m); err != nil { t.Fatalf("err: %v", err) } diff --git a/agent/agent.go b/agent/agent.go index 205d44e0cd..6cef976957 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -14,7 +14,6 @@ import ( "net/http" "os" "path/filepath" - "reflect" "regexp" "strconv" "strings" @@ -146,9 +145,9 @@ type Agent struct { // attempts. retryJoinCh chan error - // endpoints lets you override RPC endpoints for testing. Not all - // agent methods use this, so use with care and never override - // outside of a unit test. + // endpoints maps unique RPC endpoint names to common ones + // to allow overriding of RPC handlers since the golang + // net/rpc server does not allow this. endpoints map[string]string endpointsLock sync.RWMutex @@ -1068,9 +1067,34 @@ LOAD: return nil } +// RegisterEndpoint registers a handler for the consul RPC server +// under a unique name while making it accessible under the provided +// name. This allows overwriting handlers for the golang net/rpc +// service which does not allow this. +func (a *Agent) RegisterEndpoint(name string, handler interface{}) error { + srv, ok := a.delegate.(*consul.Server) + if !ok { + panic("agent must be a server") + } + realname := fmt.Sprintf("%s-%d", name, time.Now().UnixNano()) + a.endpointsLock.Lock() + a.endpoints[name] = realname + a.endpointsLock.Unlock() + return srv.RegisterEndpoint(realname, handler) +} + // RPC is used to make an RPC call to the Consul servers // This allows the agent to implement the Consul.Interface func (a *Agent) RPC(method string, args interface{}, reply interface{}) error { + a.endpointsLock.Lock() + // fast path: only translate if there are overrides + if len(a.endpoints) > 0 { + p := strings.SplitN(method, ".", 2) + if e := a.endpoints[p[0]]; e != "" { + method = e + "." + p[1] + } + } + a.endpointsLock.Unlock() return a.delegate.RPC(method, args, reply) } @@ -2255,37 +2279,6 @@ func (a *Agent) DisableNodeMaintenance() { a.logger.Printf("[INFO] agent: Node left maintenance mode") } -// InjectEndpoint overrides the given endpoint with a substitute one. Note -// that not all agent methods use this mechanism, and that is should only -// be used for testing. -func (a *Agent) InjectEndpoint(endpoint string, handler interface{}) error { - srv, ok := a.delegate.(*consul.Server) - if !ok { - return fmt.Errorf("agent must be a server") - } - if err := srv.InjectEndpoint(handler); err != nil { - return err - } - name := reflect.Indirect(reflect.ValueOf(handler)).Type().Name() - a.endpointsLock.Lock() - a.endpoints[endpoint] = name - a.endpointsLock.Unlock() - - a.logger.Printf("[WARN] agent: endpoint injected; this should only be used for testing") - return nil -} - -// getEndpoint returns the endpoint name to use for the given endpoint, -// which may be overridden. -func (a *Agent) getEndpoint(endpoint string) string { - a.endpointsLock.RLock() - defer a.endpointsLock.RUnlock() - if override, ok := a.endpoints[endpoint]; ok { - return override - } - return endpoint -} - func (a *Agent) ReloadConfig(newCfg *Config) (bool, error) { var errs error diff --git a/agent/consul/server.go b/agent/consul/server.go index 6047689b22..fbe19564b5 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -977,10 +977,10 @@ func (s *Server) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io return nil } -// InjectEndpoint is used to substitute an endpoint for testing. -func (s *Server) InjectEndpoint(endpoint interface{}) error { +// RegisterEndpoint is used to substitute an endpoint for testing. +func (s *Server) RegisterEndpoint(name string, handler interface{}) error { s.logger.Printf("[WARN] consul: endpoint injected; this should only be used for testing") - return s.rpcServer.Register(endpoint) + return s.rpcServer.RegisterName(name, handler) } // Stats is used to return statistics for debugging and insight diff --git a/agent/dns.go b/agent/dns.go index e9dbdb878c..dbceb17964 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -695,10 +695,9 @@ func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, req, // likely work in practice, like 10*maxUDPAnswerLimit which should help // reduce bandwidth if there are thousands of nodes available. - endpoint := d.agent.getEndpoint(preparedQueryEndpoint) var out structs.PreparedQueryExecuteResponse RPC: - if err := d.agent.RPC(endpoint+".Execute", &args, &out); err != nil { + if err := d.agent.RPC("PreparedQuery.Execute", &args, &out); err != nil { // If they give a bogus query name, treat that as a name error, // not a full on server error. We have to use a string compare // here since the RPC layer loses the type information. diff --git a/agent/dns_test.go b/agent/dns_test.go index fd69dcb5d6..2a0bd93027 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -3932,7 +3932,7 @@ func TestDNS_PreparedQuery_AllowStale(t *testing.T) { defer a.Shutdown() m := MockPreparedQuery{} - if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { + if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) } @@ -4013,7 +4013,7 @@ func TestDNS_PreparedQuery_AgentSource(t *testing.T) { defer a.Shutdown() m := MockPreparedQuery{} - if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { + if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) } diff --git a/agent/prepared_query_endpoint.go b/agent/prepared_query_endpoint.go index bbfc9e6ea7..bda60706e1 100644 --- a/agent/prepared_query_endpoint.go +++ b/agent/prepared_query_endpoint.go @@ -37,8 +37,7 @@ func (s *HTTPServer) preparedQueryCreate(resp http.ResponseWriter, req *http.Req } var reply string - endpoint := s.agent.getEndpoint(preparedQueryEndpoint) - if err := s.agent.RPC(endpoint+".Apply", &args, &reply); err != nil { + if err := s.agent.RPC("PreparedQuery.Apply", &args, &reply); err != nil { return nil, err } return preparedQueryCreateResponse{reply}, nil @@ -52,8 +51,7 @@ func (s *HTTPServer) preparedQueryList(resp http.ResponseWriter, req *http.Reque } var reply structs.IndexedPreparedQueries - endpoint := s.agent.getEndpoint(preparedQueryEndpoint) - if err := s.agent.RPC(endpoint+".List", &args, &reply); err != nil { + if err := s.agent.RPC("PreparedQuery.List", &args, &reply); err != nil { return nil, err } @@ -110,8 +108,7 @@ func (s *HTTPServer) preparedQueryExecute(id string, resp http.ResponseWriter, r } var reply structs.PreparedQueryExecuteResponse - endpoint := s.agent.getEndpoint(preparedQueryEndpoint) - if err := s.agent.RPC(endpoint+".Execute", &args, &reply); err != nil { + if err := s.agent.RPC("PreparedQuery.Execute", &args, &reply); err != nil { // We have to check the string since the RPC sheds // the specific error type. if err.Error() == consul.ErrQueryNotFound.Error() { @@ -155,8 +152,7 @@ func (s *HTTPServer) preparedQueryExplain(id string, resp http.ResponseWriter, r } var reply structs.PreparedQueryExplainResponse - endpoint := s.agent.getEndpoint(preparedQueryEndpoint) - if err := s.agent.RPC(endpoint+".Explain", &args, &reply); err != nil { + if err := s.agent.RPC("PreparedQuery.Explain", &args, &reply); err != nil { // We have to check the string since the RPC sheds // the specific error type. if err.Error() == consul.ErrQueryNotFound.Error() { @@ -179,8 +175,7 @@ func (s *HTTPServer) preparedQueryGet(id string, resp http.ResponseWriter, req * } var reply structs.IndexedPreparedQueries - endpoint := s.agent.getEndpoint(preparedQueryEndpoint) - if err := s.agent.RPC(endpoint+".Get", &args, &reply); err != nil { + if err := s.agent.RPC("PreparedQuery.Get", &args, &reply); err != nil { // We have to check the string since the RPC sheds // the specific error type. if err.Error() == consul.ErrQueryNotFound.Error() { @@ -212,8 +207,7 @@ func (s *HTTPServer) preparedQueryUpdate(id string, resp http.ResponseWriter, re args.Query.ID = id var reply string - endpoint := s.agent.getEndpoint(preparedQueryEndpoint) - if err := s.agent.RPC(endpoint+".Apply", &args, &reply); err != nil { + if err := s.agent.RPC("PreparedQuery.Apply", &args, &reply); err != nil { return nil, err } return nil, nil @@ -231,8 +225,7 @@ func (s *HTTPServer) preparedQueryDelete(id string, resp http.ResponseWriter, re s.parseToken(req, &args.Token) var reply string - endpoint := s.agent.getEndpoint(preparedQueryEndpoint) - if err := s.agent.RPC(endpoint+".Apply", &args, &reply); err != nil { + if err := s.agent.RPC("PreparedQuery.Apply", &args, &reply); err != nil { return nil, err } return nil, nil diff --git a/agent/prepared_query_endpoint_test.go b/agent/prepared_query_endpoint_test.go index 0d1ef6d5c0..171599fa0e 100644 --- a/agent/prepared_query_endpoint_test.go +++ b/agent/prepared_query_endpoint_test.go @@ -74,7 +74,7 @@ func TestPreparedQuery_Create(t *testing.T) { defer a.Shutdown() m := MockPreparedQuery{} - if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { + if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) } @@ -159,7 +159,7 @@ func TestPreparedQuery_List(t *testing.T) { defer a.Shutdown() m := MockPreparedQuery{} - if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { + if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) } @@ -192,7 +192,7 @@ func TestPreparedQuery_List(t *testing.T) { defer a.Shutdown() m := MockPreparedQuery{} - if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { + if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) } @@ -242,7 +242,7 @@ func TestPreparedQuery_Execute(t *testing.T) { defer a.Shutdown() m := MockPreparedQuery{} - if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { + if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) } @@ -275,7 +275,7 @@ func TestPreparedQuery_Execute(t *testing.T) { defer a.Shutdown() m := MockPreparedQuery{} - if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { + if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) } @@ -331,7 +331,7 @@ func TestPreparedQuery_Execute(t *testing.T) { defer a.Shutdown() m := MockPreparedQuery{} - if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { + if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) } @@ -365,7 +365,7 @@ func TestPreparedQuery_Execute(t *testing.T) { defer a.Shutdown() m := MockPreparedQuery{} - if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { + if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) } @@ -415,7 +415,7 @@ func TestPreparedQuery_Execute(t *testing.T) { defer a.Shutdown() m := MockPreparedQuery{} - if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { + if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) } @@ -479,7 +479,7 @@ func TestPreparedQuery_Explain(t *testing.T) { defer a.Shutdown() m := MockPreparedQuery{} - if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { + if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) } @@ -552,7 +552,7 @@ func TestPreparedQuery_Get(t *testing.T) { defer a.Shutdown() m := MockPreparedQuery{} - if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { + if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) } @@ -617,7 +617,7 @@ func TestPreparedQuery_Update(t *testing.T) { defer a.Shutdown() m := MockPreparedQuery{} - if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { + if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) } @@ -695,7 +695,7 @@ func TestPreparedQuery_Delete(t *testing.T) { defer a.Shutdown() m := MockPreparedQuery{} - if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { + if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) }