Adds a slightly more flexible mock system so we can test DNS.

This commit is contained in:
James Phillips 2015-11-12 09:19:33 -08:00
parent da20e6668b
commit 5e7523ea4b
3 changed files with 48 additions and 24 deletions

View File

@ -9,6 +9,7 @@ import (
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"regexp" "regexp"
"strconv" "strconv"
"sync" "sync"
@ -104,6 +105,11 @@ type Agent struct {
shutdown bool shutdown bool
shutdownCh chan struct{} shutdownCh chan struct{}
shutdownLock sync.Mutex shutdownLock sync.Mutex
// endpoints lets you override RPC endpoints for testing. Not all
// agent methods use this, so use with care and never override
// outside of a unit test.
endpoints map[string]string
} }
// Create is used to create a new Agent. Returns // Create is used to create a new Agent. Returns
@ -158,6 +164,7 @@ func Create(config *Config, logOutput io.Writer) (*Agent, error) {
eventCh: make(chan serf.UserEvent, 1024), eventCh: make(chan serf.UserEvent, 1024),
eventBuf: make([]*UserEvent, 256), eventBuf: make([]*UserEvent, 256),
shutdownCh: make(chan struct{}), shutdownCh: make(chan struct{}),
endpoints: make(map[string]string),
} }
// Initialize the local state // Initialize the local state
@ -1456,3 +1463,30 @@ func (a *Agent) DisableNodeMaintenance() {
a.RemoveCheck(nodeMaintCheckID, true) a.RemoveCheck(nodeMaintCheckID, true)
a.logger.Printf("[INFO] agent: Node left maintenance mode") a.logger.Printf("[INFO] agent: Node left maintenance mode")
} }
// InjectEndpoint overrides the given endpoint with a substitute one. Note
// that not all agent methods use this mechanism, and that is should only
// be used for testing.
func (a *Agent) InjectEndpoint(endpoint string, handler interface{}) error {
if a.server == nil {
return fmt.Errorf("agent must be a server")
}
if err := a.server.InjectEndpoint(handler); err != nil {
return err
}
name := reflect.Indirect(reflect.ValueOf(handler)).Type().Name()
a.endpoints[endpoint] = name
a.logger.Printf("[WARN] agent: endpoint injected; this should only be used for testing")
return nil
}
// getEndpoint returns the endpoint name to use for the given endpoint,
// which may be overridden.
func (a *Agent) getEndpoint(endpoint string) string {
if override, ok := a.endpoints[endpoint]; ok {
return override
}
return endpoint
}

View File

@ -21,12 +21,7 @@ type preparedQueryCreateResponse struct {
// PreparedQueryGeneral handles all the general prepared query requests. // PreparedQueryGeneral handles all the general prepared query requests.
func (s *HTTPServer) PreparedQueryGeneral(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPServer) PreparedQueryGeneral(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
return s.preparedQueryGeneral(preparedQueryEndpoint, resp, req) endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
}
// preparedQueryGeneral is the internal method that does the work on behalf of
// PreparedQueryGeneral. The RPC endpoint is parameterized to ease testing.
func (s *HTTPServer) preparedQueryGeneral(endpoint string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
switch req.Method { switch req.Method {
case "POST": // Create a new prepared query. case "POST": // Create a new prepared query.
args := structs.PreparedQueryRequest{ args := structs.PreparedQueryRequest{
@ -82,12 +77,6 @@ func parseLimit(req *http.Request, limit *int) error {
// PreparedQuerySpecifc handles all the prepared query requests specific to a // PreparedQuerySpecifc handles all the prepared query requests specific to a
// particular query. // particular query.
func (s *HTTPServer) PreparedQuerySpecific(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPServer) PreparedQuerySpecific(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
return s.preparedQuerySpecific(preparedQueryEndpoint, resp, req)
}
// preparedQuerySpecific is the internal method that does the work on behalf of
// PreparedQuerySpecific. The RPC endpoint is parameterized to ease testing.
func (s *HTTPServer) preparedQuerySpecific(endpoint string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
id := strings.TrimPrefix(req.URL.Path, "/v1/query/") id := strings.TrimPrefix(req.URL.Path, "/v1/query/")
execute := false execute := false
if strings.HasSuffix(id, preparedQueryExecuteSuffix) { if strings.HasSuffix(id, preparedQueryExecuteSuffix) {
@ -95,6 +84,7 @@ func (s *HTTPServer) preparedQuerySpecific(endpoint string, resp http.ResponseWr
id = strings.TrimSuffix(id, preparedQueryExecuteSuffix) id = strings.TrimSuffix(id, preparedQueryExecuteSuffix)
} }
endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
switch req.Method { switch req.Method {
case "GET": // Execute or retrieve a prepared query. case "GET": // Execute or retrieve a prepared query.
if execute { if execute {

View File

@ -62,7 +62,7 @@ func (m *MockPreparedQuery) Execute(args *structs.PreparedQueryExecuteRequest,
func TestPreparedQuery_Create(t *testing.T) { func TestPreparedQuery_Create(t *testing.T) {
httpTest(t, func(srv *HTTPServer) { httpTest(t, func(srv *HTTPServer) {
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := srv.agent.server.InjectEndpoint(&m); err != nil { if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -126,7 +126,7 @@ func TestPreparedQuery_Create(t *testing.T) {
} }
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := srv.preparedQueryGeneral("MockPreparedQuery", resp, req) obj, err := srv.PreparedQueryGeneral(resp, req)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -146,7 +146,7 @@ func TestPreparedQuery_Create(t *testing.T) {
func TestPreparedQuery_List(t *testing.T) { func TestPreparedQuery_List(t *testing.T) {
httpTest(t, func(srv *HTTPServer) { httpTest(t, func(srv *HTTPServer) {
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := srv.agent.server.InjectEndpoint(&m); err != nil { if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -176,7 +176,7 @@ func TestPreparedQuery_List(t *testing.T) {
} }
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := srv.preparedQueryGeneral("MockPreparedQuery", resp, req) obj, err := srv.PreparedQueryGeneral(resp, req)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -196,7 +196,7 @@ func TestPreparedQuery_List(t *testing.T) {
func TestPreparedQuery_Execute(t *testing.T) { func TestPreparedQuery_Execute(t *testing.T) {
httpTest(t, func(srv *HTTPServer) { httpTest(t, func(srv *HTTPServer) {
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := srv.agent.server.InjectEndpoint(&m); err != nil { if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -230,7 +230,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
} }
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := srv.preparedQuerySpecific("MockPreparedQuery", resp, req) obj, err := srv.PreparedQuerySpecific(resp, req)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -250,7 +250,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
func TestPreparedQuery_Get(t *testing.T) { func TestPreparedQuery_Get(t *testing.T) {
httpTest(t, func(srv *HTTPServer) { httpTest(t, func(srv *HTTPServer) {
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := srv.agent.server.InjectEndpoint(&m); err != nil { if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -281,7 +281,7 @@ func TestPreparedQuery_Get(t *testing.T) {
} }
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := srv.preparedQuerySpecific("MockPreparedQuery", resp, req) obj, err := srv.PreparedQuerySpecific(resp, req)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -301,7 +301,7 @@ func TestPreparedQuery_Get(t *testing.T) {
func TestPreparedQuery_Update(t *testing.T) { func TestPreparedQuery_Update(t *testing.T) {
httpTest(t, func(srv *HTTPServer) { httpTest(t, func(srv *HTTPServer) {
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := srv.agent.server.InjectEndpoint(&m); err != nil { if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -367,7 +367,7 @@ func TestPreparedQuery_Update(t *testing.T) {
} }
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err = srv.preparedQuerySpecific("MockPreparedQuery", resp, req) _, err = srv.PreparedQuerySpecific(resp, req)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -380,7 +380,7 @@ func TestPreparedQuery_Update(t *testing.T) {
func TestPreparedQuery_Delete(t *testing.T) { func TestPreparedQuery_Delete(t *testing.T) {
httpTest(t, func(srv *HTTPServer) { httpTest(t, func(srv *HTTPServer) {
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := srv.agent.server.InjectEndpoint(&m); err != nil { if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -418,7 +418,7 @@ func TestPreparedQuery_Delete(t *testing.T) {
} }
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err = srv.preparedQuerySpecific("MockPreparedQuery", resp, req) _, err = srv.PreparedQuerySpecific(resp, req)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }