mirror of https://github.com/status-im/consul.git
Adds a slightly more flexible mock system so we can test DNS.
This commit is contained in:
parent
da20e6668b
commit
5e7523ea4b
|
@ -9,6 +9,7 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
@ -104,6 +105,11 @@ type Agent struct {
|
|||
shutdown bool
|
||||
shutdownCh chan struct{}
|
||||
shutdownLock sync.Mutex
|
||||
|
||||
// endpoints lets you override RPC endpoints for testing. Not all
|
||||
// agent methods use this, so use with care and never override
|
||||
// outside of a unit test.
|
||||
endpoints map[string]string
|
||||
}
|
||||
|
||||
// Create is used to create a new Agent. Returns
|
||||
|
@ -158,6 +164,7 @@ func Create(config *Config, logOutput io.Writer) (*Agent, error) {
|
|||
eventCh: make(chan serf.UserEvent, 1024),
|
||||
eventBuf: make([]*UserEvent, 256),
|
||||
shutdownCh: make(chan struct{}),
|
||||
endpoints: make(map[string]string),
|
||||
}
|
||||
|
||||
// Initialize the local state
|
||||
|
@ -1456,3 +1463,30 @@ func (a *Agent) DisableNodeMaintenance() {
|
|||
a.RemoveCheck(nodeMaintCheckID, true)
|
||||
a.logger.Printf("[INFO] agent: Node left maintenance mode")
|
||||
}
|
||||
|
||||
// InjectEndpoint overrides the given endpoint with a substitute one. Note
|
||||
// that not all agent methods use this mechanism, and that is should only
|
||||
// be used for testing.
|
||||
func (a *Agent) InjectEndpoint(endpoint string, handler interface{}) error {
|
||||
if a.server == nil {
|
||||
return fmt.Errorf("agent must be a server")
|
||||
}
|
||||
|
||||
if err := a.server.InjectEndpoint(handler); err != nil {
|
||||
return err
|
||||
}
|
||||
name := reflect.Indirect(reflect.ValueOf(handler)).Type().Name()
|
||||
a.endpoints[endpoint] = name
|
||||
|
||||
a.logger.Printf("[WARN] agent: endpoint injected; this should only be used for testing")
|
||||
return nil
|
||||
}
|
||||
|
||||
// getEndpoint returns the endpoint name to use for the given endpoint,
|
||||
// which may be overridden.
|
||||
func (a *Agent) getEndpoint(endpoint string) string {
|
||||
if override, ok := a.endpoints[endpoint]; ok {
|
||||
return override
|
||||
}
|
||||
return endpoint
|
||||
}
|
||||
|
|
|
@ -21,12 +21,7 @@ type preparedQueryCreateResponse struct {
|
|||
|
||||
// PreparedQueryGeneral handles all the general prepared query requests.
|
||||
func (s *HTTPServer) PreparedQueryGeneral(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
return s.preparedQueryGeneral(preparedQueryEndpoint, resp, req)
|
||||
}
|
||||
|
||||
// preparedQueryGeneral is the internal method that does the work on behalf of
|
||||
// PreparedQueryGeneral. The RPC endpoint is parameterized to ease testing.
|
||||
func (s *HTTPServer) preparedQueryGeneral(endpoint string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
|
||||
switch req.Method {
|
||||
case "POST": // Create a new prepared query.
|
||||
args := structs.PreparedQueryRequest{
|
||||
|
@ -82,12 +77,6 @@ func parseLimit(req *http.Request, limit *int) error {
|
|||
// PreparedQuerySpecifc handles all the prepared query requests specific to a
|
||||
// particular query.
|
||||
func (s *HTTPServer) PreparedQuerySpecific(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
return s.preparedQuerySpecific(preparedQueryEndpoint, resp, req)
|
||||
}
|
||||
|
||||
// preparedQuerySpecific is the internal method that does the work on behalf of
|
||||
// PreparedQuerySpecific. The RPC endpoint is parameterized to ease testing.
|
||||
func (s *HTTPServer) preparedQuerySpecific(endpoint string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
id := strings.TrimPrefix(req.URL.Path, "/v1/query/")
|
||||
execute := false
|
||||
if strings.HasSuffix(id, preparedQueryExecuteSuffix) {
|
||||
|
@ -95,6 +84,7 @@ func (s *HTTPServer) preparedQuerySpecific(endpoint string, resp http.ResponseWr
|
|||
id = strings.TrimSuffix(id, preparedQueryExecuteSuffix)
|
||||
}
|
||||
|
||||
endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
|
||||
switch req.Method {
|
||||
case "GET": // Execute or retrieve a prepared query.
|
||||
if execute {
|
||||
|
|
|
@ -62,7 +62,7 @@ func (m *MockPreparedQuery) Execute(args *structs.PreparedQueryExecuteRequest,
|
|||
func TestPreparedQuery_Create(t *testing.T) {
|
||||
httpTest(t, func(srv *HTTPServer) {
|
||||
m := MockPreparedQuery{}
|
||||
if err := srv.agent.server.InjectEndpoint(&m); err != nil {
|
||||
if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -126,7 +126,7 @@ func TestPreparedQuery_Create(t *testing.T) {
|
|||
}
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := srv.preparedQueryGeneral("MockPreparedQuery", resp, req)
|
||||
obj, err := srv.PreparedQueryGeneral(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ func TestPreparedQuery_Create(t *testing.T) {
|
|||
func TestPreparedQuery_List(t *testing.T) {
|
||||
httpTest(t, func(srv *HTTPServer) {
|
||||
m := MockPreparedQuery{}
|
||||
if err := srv.agent.server.InjectEndpoint(&m); err != nil {
|
||||
if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -176,7 +176,7 @@ func TestPreparedQuery_List(t *testing.T) {
|
|||
}
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := srv.preparedQueryGeneral("MockPreparedQuery", resp, req)
|
||||
obj, err := srv.PreparedQueryGeneral(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -196,7 +196,7 @@ func TestPreparedQuery_List(t *testing.T) {
|
|||
func TestPreparedQuery_Execute(t *testing.T) {
|
||||
httpTest(t, func(srv *HTTPServer) {
|
||||
m := MockPreparedQuery{}
|
||||
if err := srv.agent.server.InjectEndpoint(&m); err != nil {
|
||||
if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -230,7 +230,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
|
|||
}
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := srv.preparedQuerySpecific("MockPreparedQuery", resp, req)
|
||||
obj, err := srv.PreparedQuerySpecific(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -250,7 +250,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
|
|||
func TestPreparedQuery_Get(t *testing.T) {
|
||||
httpTest(t, func(srv *HTTPServer) {
|
||||
m := MockPreparedQuery{}
|
||||
if err := srv.agent.server.InjectEndpoint(&m); err != nil {
|
||||
if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -281,7 +281,7 @@ func TestPreparedQuery_Get(t *testing.T) {
|
|||
}
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := srv.preparedQuerySpecific("MockPreparedQuery", resp, req)
|
||||
obj, err := srv.PreparedQuerySpecific(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -301,7 +301,7 @@ func TestPreparedQuery_Get(t *testing.T) {
|
|||
func TestPreparedQuery_Update(t *testing.T) {
|
||||
httpTest(t, func(srv *HTTPServer) {
|
||||
m := MockPreparedQuery{}
|
||||
if err := srv.agent.server.InjectEndpoint(&m); err != nil {
|
||||
if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -367,7 +367,7 @@ func TestPreparedQuery_Update(t *testing.T) {
|
|||
}
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
_, err = srv.preparedQuerySpecific("MockPreparedQuery", resp, req)
|
||||
_, err = srv.PreparedQuerySpecific(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -380,7 +380,7 @@ func TestPreparedQuery_Update(t *testing.T) {
|
|||
func TestPreparedQuery_Delete(t *testing.T) {
|
||||
httpTest(t, func(srv *HTTPServer) {
|
||||
m := MockPreparedQuery{}
|
||||
if err := srv.agent.server.InjectEndpoint(&m); err != nil {
|
||||
if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -418,7 +418,7 @@ func TestPreparedQuery_Delete(t *testing.T) {
|
|||
}
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
_, err = srv.preparedQuerySpecific("MockPreparedQuery", resp, req)
|
||||
_, err = srv.PreparedQuerySpecific(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue