diff --git a/agent/agent_endpoint.go b/agent/agent_endpoint.go index 9646f7f288..c7ec726611 100644 --- a/agent/agent_endpoint.go +++ b/agent/agent_endpoint.go @@ -1005,7 +1005,7 @@ func (s *HTTPHandlers) AgentHealthServiceByID(resp http.ResponseWriter, req *htt } notFoundReason := fmt.Sprintf("ServiceId %s not found", sid.String()) if returnTextPlain(req) { - return notFoundReason, CodeWithPayloadError{StatusCode: http.StatusNotFound, Reason: notFoundReason, ContentType: "application/json"} + return notFoundReason, CodeWithPayloadError{StatusCode: http.StatusNotFound, Reason: notFoundReason, ContentType: "text/plain"} } return &api.AgentServiceChecksInfo{ AggregatedStatus: api.HealthCritical, diff --git a/agent/agent_endpoint_test.go b/agent/agent_endpoint_test.go index 2b1087479e..3d4a80beb4 100644 --- a/agent/agent_endpoint_test.go +++ b/agent/agent_endpoint_test.go @@ -53,8 +53,8 @@ func createACLTokenWithAgentReadPolicy(t *testing.T, srv *HTTPHandlers) string { req, _ := http.NewRequest("PUT", "/v1/acl/policy?token=root", jsonReader(policyReq)) resp := httptest.NewRecorder() - _, err := srv.ACLPolicyCreate(resp, req) - require.NoError(t, err) + srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) tokenReq := &structs.ACLToken{ Description: "agent-read-token-for-test", @@ -63,10 +63,12 @@ func createACLTokenWithAgentReadPolicy(t *testing.T, srv *HTTPHandlers) string { req, _ = http.NewRequest("PUT", "/v1/acl/token?token=root", jsonReader(tokenReq)) resp = httptest.NewRecorder() - tokInf, err := srv.ACLTokenCreate(resp, req) + srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) + svcToken := &structs.ACLToken{} + dec := json.NewDecoder(resp.Body) + err := dec.Decode(svcToken) require.NoError(t, err) - svcToken, ok := tokInf.(*structs.ACLToken) - require.True(t, ok) return svcToken.SecretID } @@ -283,13 +285,21 @@ func TestAgent_Services_MeshGateway(t *testing.T) { a.State.AddService(srv1, "") req, _ := http.NewRequest("GET", "/v1/agent/services", nil) - obj, err := a.srv.AgentServices(httptest.NewRecorder(), req) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + dec := json.NewDecoder(resp.Body) + var val map[string]*api.AgentService + err := dec.Decode(&val) require.NoError(t, err) - val := obj.(map[string]*api.AgentService) + require.Len(t, val, 1) actual := val["mg-dc1-01"] require.NotNil(t, actual) require.Equal(t, api.ServiceKindMeshGateway, actual.Kind) + // Proxy.ToAPI() creates an empty Upstream list instead of keeping nil so do the same with actual. + if actual.Proxy.Upstreams == nil { + actual.Proxy.Upstreams = make([]api.Upstream, 0) + } require.Equal(t, srv1.Proxy.ToAPI(), actual.Proxy) } @@ -319,13 +329,21 @@ func TestAgent_Services_TerminatingGateway(t *testing.T) { require.NoError(t, a.State.AddService(srv1, "")) req, _ := http.NewRequest("GET", "/v1/agent/services", nil) - obj, err := a.srv.AgentServices(httptest.NewRecorder(), req) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + dec := json.NewDecoder(resp.Body) + var val map[string]*api.AgentService + err := dec.Decode(&val) require.NoError(t, err) - val := obj.(map[string]*api.AgentService) + require.Len(t, val, 1) actual := val["tg-dc1-01"] require.NotNil(t, actual) require.Equal(t, api.ServiceKindTerminatingGateway, actual.Kind) + // Proxy.ToAPI() creates an empty Upstream list instead of keeping nil so do the same with actual. + if actual.Proxy.Upstreams == nil { + actual.Proxy.Upstreams = make([]api.Upstream, 0) + } require.Equal(t, srv1.Proxy.ToAPI(), actual.Proxy) } @@ -357,17 +375,21 @@ func TestAgent_Services_ACLFilter(t *testing.T) { } t.Run("no token", func(t *testing.T) { - require := require.New(t) + req, _ := http.NewRequest("GET", "/v1/agent/services", nil) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + dec := json.NewDecoder(resp.Body) + var val map[string]*api.AgentService + err := dec.Decode(&val) + if err != nil { + t.Fatalf("Err: %v", err) + } - req := httptest.NewRequest("GET", "/v1/agent/services", nil) - rsp := httptest.NewRecorder() - - obj, err := a.srv.AgentServices(rsp, req) - require.NoError(err) - - val := obj.(map[string]*api.AgentService) - require.Empty(val) - require.Empty(rsp.Header().Get("X-Consul-Results-Filtered-By-ACLs")) + if len(val) != 0 { + t.Fatalf("bad: %v", val) + } + require.Len(t, val, 0) + require.Empty(t, resp.Header().Get("X-Consul-Results-Filtered-By-ACLs")) }) t.Run("limited token", func(t *testing.T) { @@ -380,28 +402,30 @@ func TestAgent_Services_ACLFilter(t *testing.T) { `) req := httptest.NewRequest("GET", fmt.Sprintf("/v1/agent/services?token=%s", token), nil) - rsp := httptest.NewRecorder() + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) - obj, err := a.srv.AgentServices(rsp, req) - require.NoError(err) - - val := obj.(map[string]*api.AgentService) + dec := json.NewDecoder(resp.Body) + var val map[string]*api.AgentService + if err := dec.Decode(&val); err != nil { + t.Fatalf("Err: %v", err) + } require.Len(val, 1) - require.NotEmpty(rsp.Header().Get("X-Consul-Results-Filtered-By-ACLs")) + require.NotEmpty(resp.Header().Get("X-Consul-Results-Filtered-By-ACLs")) }) t.Run("root token", func(t *testing.T) { - require := require.New(t) - - req := httptest.NewRequest("GET", "/v1/agent/services?token=root", nil) - rsp := httptest.NewRecorder() - - obj, err := a.srv.AgentServices(rsp, req) - require.NoError(err) - - val := obj.(map[string]*api.AgentService) - require.Len(val, 2) - require.Empty(rsp.Header().Get("X-Consul-Results-Filtered-By-ACLs")) + req, _ := http.NewRequest("GET", "/v1/agent/services?token=root", nil) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + dec := json.NewDecoder(resp.Body) + var val map[string]*api.AgentService + err := dec.Decode(&val) + if err != nil { + t.Fatalf("Err: %v", err) + } + require.Len(t, val, 2) + require.Empty(t, resp.Header().Get("X-Consul-Results-Filtered-By-ACLs")) }) } @@ -552,8 +576,7 @@ func TestAgent_Service(t *testing.T) { // don't alter it and affect later test cases. req, _ := http.NewRequest("PUT", "/v1/agent/service/register?token=root", jsonReader(updatedProxy)) resp := httptest.NewRecorder() - _, err := a.srv.AgentRegisterService(resp, req) - require.NoError(t, err) + a.srv.h.ServeHTTP(resp, req) require.Equal(t, 200, resp.Code, "body: %s", resp.Body.String()) }, wantWait: 100 * time.Millisecond, @@ -586,8 +609,7 @@ func TestAgent_Service(t *testing.T) { // Re-register with _same_ proxy config req, _ := http.NewRequest("PUT", "/v1/agent/service/register?token=root", jsonReader(sidecarProxy)) resp := httptest.NewRecorder() - _, err := a.srv.AgentRegisterService(resp, req) - require.NoError(t, err) + a.srv.h.ServeHTTP(resp, req) require.Equal(t, 200, resp.Code, "body: %s", resp.Body.String()) }, wantWait: 200 * time.Millisecond, @@ -679,8 +701,7 @@ func TestAgent_Service(t *testing.T) { { req, _ := http.NewRequest("PUT", "/v1/agent/service/register?token=root", jsonReader(sidecarProxy)) resp := httptest.NewRecorder() - _, err := a.srv.AgentRegisterService(resp, req) - require.NoError(err) + a.srv.h.ServeHTTP(resp, req) require.Equal(200, resp.Code, "body: %s", resp.Body.String()) } @@ -698,14 +719,11 @@ func TestAgent_Service(t *testing.T) { go tt.updateFunc() } start := time.Now() - obj, err := a.srv.AgentService(resp, req) + a.srv.h.ServeHTTP(resp, req) elapsed := time.Since(start) if tt.wantErr != "" { - require.Error(err) - require.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.wantErr)) - } else { - require.NoError(err) + require.Contains(strings.ToLower(resp.Body.String()), strings.ToLower(tt.wantErr)) } if tt.wantCode != 0 { require.Equal(tt.wantCode, resp.Code, "body: %s", resp.Body.String()) @@ -719,12 +737,13 @@ func TestAgent_Service(t *testing.T) { } if tt.wantResp != nil { - assert.Equal(tt.wantResp, obj) + dec := json.NewDecoder(resp.Body) + val := &api.AgentService{} + err := dec.Decode(&val) + require.NoError(err) + + assert.Equal(tt.wantResp, val) assert.Equal(tt.wantResp.ContentHash, resp.Header().Get("X-Consul-ContentHash")) - } else { - // Janky but Equal doesn't help here because nil != - // *api.AgentService((*api.AgentService)(nil)) - assert.Nil(obj) } }) } @@ -751,25 +770,29 @@ func TestAgent_Checks(t *testing.T) { a.State.AddCheck(chk1, "") req, _ := http.NewRequest("GET", "/v1/agent/checks", nil) - obj, err := a.srv.AgentChecks(nil, req) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + dec := json.NewDecoder(resp.Body) + var val map[types.CheckID]*structs.HealthCheck + err := dec.Decode(&val) if err != nil { t.Fatalf("Err: %v", err) } - val := obj.(map[types.CheckID]*structs.HealthCheck) + if len(val) != 1 { - t.Fatalf("bad checks: %v", obj) + t.Fatalf("bad checks: %v", val) } if val["mysql"].Status != api.HealthPassing { - t.Fatalf("bad check: %v", obj) + t.Fatalf("bad check: %v", val) } if val["mysql"].Node != chk1.Node { - t.Fatalf("bad check: %v", obj) + t.Fatalf("bad check: %v", val) } if val["mysql"].Interval != chk1.Interval { - t.Fatalf("bad check: %v", obj) + t.Fatalf("bad check: %v", val) } if val["mysql"].Timeout != chk1.Timeout { - t.Fatalf("bad check: %v", obj) + t.Fatalf("bad check: %v", val) } } @@ -800,9 +823,13 @@ func TestAgent_ChecksWithFilter(t *testing.T) { a.State.AddCheck(chk2, "") req, _ := http.NewRequest("GET", "/v1/agent/checks?filter="+url.QueryEscape("Name == `redis`"), nil) - obj, err := a.srv.AgentChecks(nil, req) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + dec := json.NewDecoder(resp.Body) + var val map[types.CheckID]*structs.HealthCheck + err := dec.Decode(&val) require.NoError(t, err) - val := obj.(map[types.CheckID]*structs.HealthCheck) + require.Len(t, val, 1) _, ok := val["redis"] require.True(t, ok) @@ -822,21 +849,29 @@ func TestAgent_HealthServiceByID(t *testing.T) { ID: "mysql", Service: "mysql", } - if err := a.addServiceFromSource(service, nil, false, "", ConfigSourceLocal); err != nil { + + serviceReq := AddServiceRequest{ + Service: service, + chkTypes: nil, + persist: false, + token: "", + Source: ConfigSourceLocal, + } + if err := a.AddService(serviceReq); err != nil { t.Fatalf("err: %v", err) } - service = &structs.NodeService{ + serviceReq.Service = &structs.NodeService{ ID: "mysql2", Service: "mysql2", } - if err := a.addServiceFromSource(service, nil, false, "", ConfigSourceLocal); err != nil { + if err := a.AddService(serviceReq); err != nil { t.Fatalf("err: %v", err) } - service = &structs.NodeService{ + serviceReq.Service = &structs.NodeService{ ID: "mysql3", Service: "mysql3", } - if err := a.addServiceFromSource(service, nil, false, "", ConfigSourceLocal); err != nil { + if err := a.AddService(serviceReq); err != nil { t.Fatalf("err: %v", err) } @@ -918,41 +953,28 @@ func TestAgent_HealthServiceByID(t *testing.T) { t.Helper() req, _ := http.NewRequest("GET", url+"?format=text", nil) resp := httptest.NewRecorder() - data, err := a.srv.AgentHealthServiceByID(resp, req) - codeWithPayload, ok := err.(CodeWithPayloadError) - if !ok { - t.Fatalf("Err: %v", err) - } - if got, want := codeWithPayload.StatusCode, expectedCode; got != want { - t.Fatalf("returned bad status: expected %d, but had: %d in %#v", expectedCode, codeWithPayload.StatusCode, codeWithPayload) - } - body, ok := data.(string) - if !ok { - t.Fatalf("Cannot get result as string in := %#v", data) + a.srv.h.ServeHTTP(resp, req) + body := resp.Body.String() + if got, want := resp.Code, expectedCode; got != want { + t.Fatalf("returned bad status: expected %d, but had: %d", expectedCode, resp.Code) } if got, want := body, expected; got != want { t.Fatalf("got body %q want %q", got, want) } - if got, want := codeWithPayload.Reason, expected; got != want { - t.Fatalf("got body %q want %q", got, want) - } }) t.Run("format=json", func(t *testing.T) { req, _ := http.NewRequest("GET", url, nil) resp := httptest.NewRecorder() - dataRaw, err := a.srv.AgentHealthServiceByID(resp, req) - codeWithPayload, ok := err.(CodeWithPayloadError) - if !ok { - t.Fatalf("Err: %v", err) + a.srv.h.ServeHTTP(resp, req) + if got, want := resp.Code, expectedCode; got != want { + t.Fatalf("returned bad status: expected %d, but had: %d", expectedCode, resp.Code) } - if got, want := codeWithPayload.StatusCode, expectedCode; got != want { - t.Fatalf("returned bad status: expected %d, but had: %d in %#v", expectedCode, codeWithPayload.StatusCode, codeWithPayload) + dec := json.NewDecoder(resp.Body) + data := &api.AgentServiceChecksInfo{} + if err := dec.Decode(data); err != nil { + t.Fatalf("Cannot convert result from JSON: %v", err) } - data, ok := dataRaw.(*api.AgentServiceChecksInfo) - if !ok { - t.Fatalf("Cannot connvert result to JSON: %#v", dataRaw) - } - if codeWithPayload.StatusCode != http.StatusNotFound { + if resp.Code != http.StatusNotFound { if data != nil && data.AggregatedStatus != expected { t.Fatalf("got body %v want %v", data, expected) } @@ -1020,42 +1042,49 @@ func TestAgent_HealthServiceByName(t *testing.T) { ID: "mysql1", Service: "mysql-pool-r", } - if err := a.addServiceFromSource(service, nil, false, "", ConfigSourceLocal); err != nil { + serviceReq := AddServiceRequest{ + Service: service, + chkTypes: nil, + persist: false, + token: "", + Source: ConfigSourceLocal, + } + if err := a.AddService(serviceReq); err != nil { t.Fatalf("err: %v", err) } - service = &structs.NodeService{ + serviceReq.Service = &structs.NodeService{ ID: "mysql2", Service: "mysql-pool-r", } - if err := a.addServiceFromSource(service, nil, false, "", ConfigSourceLocal); err != nil { + if err := a.AddService(serviceReq); err != nil { t.Fatalf("err: %v", err) } - service = &structs.NodeService{ + serviceReq.Service = &structs.NodeService{ ID: "mysql3", Service: "mysql-pool-rw", } - if err := a.addServiceFromSource(service, nil, false, "", ConfigSourceLocal); err != nil { + if err := a.AddService(serviceReq); err != nil { t.Fatalf("err: %v", err) } - service = &structs.NodeService{ + serviceReq.Service = &structs.NodeService{ ID: "mysql4", Service: "mysql-pool-rw", } - if err := a.addServiceFromSource(service, nil, false, "", ConfigSourceLocal); err != nil { + if err := a.AddService(serviceReq); err != nil { t.Fatalf("err: %v", err) } - service = &structs.NodeService{ + serviceReq.Service = &structs.NodeService{ ID: "httpd1", Service: "httpd", } - if err := a.addServiceFromSource(service, nil, false, "", ConfigSourceLocal); err != nil { + if err := a.AddService(serviceReq); err != nil { t.Fatalf("err: %v", err) } - service = &structs.NodeService{ + serviceReq.Service = &structs.NodeService{ ID: "httpd2", Service: "httpd", } - if err := a.addServiceFromSource(service, nil, false, "", ConfigSourceLocal); err != nil { + if err := a.AddService(serviceReq); err != nil { t.Fatalf("err: %v", err) } @@ -1169,18 +1198,11 @@ func TestAgent_HealthServiceByName(t *testing.T) { t.Helper() req, _ := http.NewRequest("GET", url+"?format=text", nil) resp := httptest.NewRecorder() - data, err := a.srv.AgentHealthServiceByName(resp, req) - codeWithPayload, ok := err.(CodeWithPayloadError) - if !ok { - t.Fatalf("Err: %v", err) - } - if got, want := codeWithPayload.StatusCode, expectedCode; got != want { + a.srv.h.ServeHTTP(resp, req) + if got, want := resp.Code, expectedCode; got != want { t.Fatalf("returned bad status: %d. Body: %q", resp.Code, resp.Body.String()) } - if got, want := codeWithPayload.Reason, expected; got != want { - t.Fatalf("got reason %q want %q", got, want) - } - if got, want := data, expected; got != want { + if got, want := resp.Body.String(), expected; got != want { t.Fatalf("got body %q want %q", got, want) } }) @@ -1188,21 +1210,26 @@ func TestAgent_HealthServiceByName(t *testing.T) { t.Helper() req, _ := http.NewRequest("GET", url, nil) resp := httptest.NewRecorder() - dataRaw, err := a.srv.AgentHealthServiceByName(resp, req) - codeWithPayload, ok := err.(CodeWithPayloadError) - if !ok { - t.Fatalf("Err: %v", err) + a.srv.h.ServeHTTP(resp, req) + dec := json.NewDecoder(resp.Body) + data := make([]*api.AgentServiceChecksInfo, 0) + if err := dec.Decode(&data); err != nil { + t.Fatalf("Cannot convert result from JSON: %v", err) } - data, ok := dataRaw.([]api.AgentServiceChecksInfo) - if !ok { - t.Fatalf("Cannot connvert result to JSON") - } - if got, want := codeWithPayload.StatusCode, expectedCode; got != want { + if got, want := resp.Code, expectedCode; got != want { t.Fatalf("returned bad code: %d. Body: %#v", resp.Code, data) } if resp.Code != http.StatusNotFound { - if codeWithPayload.Reason != expected { - t.Fatalf("got wrong status %#v want %#v", codeWithPayload, expected) + matched := false + for _, d := range data { + if d.AggregatedStatus == expected { + matched = true + break + } + } + + if !matched { + t.Fatalf("got wrong status, wanted %#v", expected) } } }) @@ -1267,29 +1294,36 @@ func TestAgent_HealthServicesACLEnforcement(t *testing.T) { ID: "mysql1", Service: "mysql", } - require.NoError(t, a.addServiceFromSource(service, nil, false, "", ConfigSourceLocal)) + serviceReq := AddServiceRequest{ + Service: service, + chkTypes: nil, + persist: false, + token: "", + Source: ConfigSourceLocal, + } + require.NoError(t, a.AddService(serviceReq)) - service = &structs.NodeService{ + serviceReq.Service = &structs.NodeService{ ID: "foo1", Service: "foo", } - require.NoError(t, a.addServiceFromSource(service, nil, false, "", ConfigSourceLocal)) + require.NoError(t, a.AddService(serviceReq)) // no token t.Run("no-token-health-by-id", func(t *testing.T) { req, err := http.NewRequest("GET", "/v1/agent/health/service/id/mysql1", nil) require.NoError(t, err) resp := httptest.NewRecorder() - _, err = a.srv.AgentHealthServiceByID(resp, req) - require.Equal(t, acl.ErrPermissionDenied, err) + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) t.Run("no-token-health-by-name", func(t *testing.T) { req, err := http.NewRequest("GET", "/v1/agent/health/service/name/mysql", nil) require.NoError(t, err) resp := httptest.NewRecorder() - _, err = a.srv.AgentHealthServiceByName(resp, req) - require.Equal(t, acl.ErrPermissionDenied, err) + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) t.Run("root-token-health-by-id", func(t *testing.T) { @@ -1297,8 +1331,8 @@ func TestAgent_HealthServicesACLEnforcement(t *testing.T) { require.NoError(t, err) req.Header.Add("X-Consul-Token", TestDefaultInitialManagementToken) resp := httptest.NewRecorder() - _, err = a.srv.AgentHealthServiceByID(resp, req) - require.NotEqual(t, acl.ErrPermissionDenied, err) + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) }) t.Run("root-token-health-by-name", func(t *testing.T) { @@ -1306,8 +1340,8 @@ func TestAgent_HealthServicesACLEnforcement(t *testing.T) { require.NoError(t, err) req.Header.Add("X-Consul-Token", TestDefaultInitialManagementToken) resp := httptest.NewRecorder() - _, err = a.srv.AgentHealthServiceByName(resp, req) - require.NotEqual(t, acl.ErrPermissionDenied, err) + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) }) } @@ -1341,17 +1375,18 @@ func TestAgent_Checks_ACLFilter(t *testing.T) { } t.Run("no token", func(t *testing.T) { - require := require.New(t) + req, _ := http.NewRequest("GET", "/v1/agent/checks", nil) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) - req := httptest.NewRequest("GET", "/v1/agent/checks", nil) - rsp := httptest.NewRecorder() + dec := json.NewDecoder(resp.Body) + val := make(map[types.CheckID]*structs.HealthCheck) + if err := dec.Decode(&val); err != nil { + t.Fatalf("Err: %v", err) + } - obj, err := a.srv.AgentChecks(rsp, req) - require.NoError(err) - - val := obj.(map[types.CheckID]*structs.HealthCheck) - require.Empty(val) - require.Empty(rsp.Header().Get("X-Consul-Results-Filtered-By-ACLs")) + require.Len(t, val, 0) + require.Empty(t, resp.Header().Get("X-Consul-Results-Filtered-By-ACLs")) }) t.Run("limited token", func(t *testing.T) { @@ -1367,28 +1402,30 @@ func TestAgent_Checks_ACLFilter(t *testing.T) { `, a.Config.NodeName)) req := httptest.NewRequest("GET", fmt.Sprintf("/v1/agent/checks?token=%s", token), nil) - rsp := httptest.NewRecorder() + resp := httptest.NewRecorder() - obj, err := a.srv.AgentChecks(rsp, req) - require.NoError(err) - - val := obj.(map[types.CheckID]*structs.HealthCheck) + a.srv.h.ServeHTTP(resp, req) + dec := json.NewDecoder(resp.Body) + var val map[types.CheckID]*structs.HealthCheck + if err := dec.Decode(&val); err != nil { + t.Fatalf("Err: %v", err) + } require.Len(val, 1) - require.NotEmpty(rsp.Header().Get("X-Consul-Results-Filtered-By-ACLs")) + require.NotEmpty(resp.Header().Get("X-Consul-Results-Filtered-By-ACLs")) }) t.Run("root token", func(t *testing.T) { - require := require.New(t) + req, _ := http.NewRequest("GET", "/v1/agent/checks?token=root", nil) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) - req := httptest.NewRequest("GET", "/v1/agent/checks?token=root", nil) - rsp := httptest.NewRecorder() - - obj, err := a.srv.AgentChecks(rsp, req) - require.NoError(err) - - val := obj.(map[types.CheckID]*structs.HealthCheck) - require.Len(val, 2) - require.Empty(rsp.Header().Get("X-Consul-Results-Filtered-By-ACLs")) + dec := json.NewDecoder(resp.Body) + val := make(map[types.CheckID]*structs.HealthCheck) + if err := dec.Decode(&val); err != nil { + t.Fatalf("Err: %v", err) + } + require.Len(t, val, 2) + require.Empty(t, resp.Header().Get("X-Consul-Results-Filtered-By-ACLs")) }) } @@ -1432,12 +1469,15 @@ func TestAgent_Self(t *testing.T) { testrpc.WaitForTestAgent(t, a.RPC, "dc1") req, _ := http.NewRequest("GET", "/v1/agent/self", nil) - obj, err := a.srv.AgentSelf(nil, req) - require.NoError(t, err) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + + dec := json.NewDecoder(resp.Body) + val := &Self{} + require.NoError(t, dec.Decode(val)) - val := obj.(Self) require.Equal(t, a.Config.SerfPortLAN, int(val.Member.Port)) - require.Equal(t, a.Config.SerfPortLAN, val.DebugConfig["SerfPortLAN"].(int)) + require.Equal(t, a.Config.SerfPortLAN, int(val.DebugConfig["SerfPortLAN"].(float64))) cs, err := a.GetLANCoordinate() require.NoError(t, err) @@ -1472,24 +1512,24 @@ func TestAgent_Self_ACLDeny(t *testing.T) { testrpc.WaitForLeader(t, a.RPC, "dc1") t.Run("no token", func(t *testing.T) { req, _ := http.NewRequest("GET", "/v1/agent/self", nil) - if _, err := a.srv.AgentSelf(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) t.Run("agent master token", func(t *testing.T) { req, _ := http.NewRequest("GET", "/v1/agent/self?token=towel", nil) - if _, err := a.srv.AgentSelf(nil, req); err != nil { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) }) t.Run("read-only token", func(t *testing.T) { ro := createACLTokenWithAgentReadPolicy(t, a.srv) req, _ := http.NewRequest("GET", fmt.Sprintf("/v1/agent/self?token=%s", ro), nil) - if _, err := a.srv.AgentSelf(nil, req); err != nil { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) }) } @@ -1505,24 +1545,24 @@ func TestAgent_Metrics_ACLDeny(t *testing.T) { testrpc.WaitForLeader(t, a.RPC, "dc1") t.Run("no token", func(t *testing.T) { req, _ := http.NewRequest("GET", "/v1/agent/metrics", nil) - if _, err := a.srv.AgentMetrics(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) t.Run("agent master token", func(t *testing.T) { req, _ := http.NewRequest("GET", "/v1/agent/metrics?token=towel", nil) - if _, err := a.srv.AgentMetrics(nil, req); err != nil { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) }) t.Run("read-only token", func(t *testing.T) { ro := createACLTokenWithAgentReadPolicy(t, a.srv) req, _ := http.NewRequest("GET", fmt.Sprintf("/v1/agent/metrics?token=%s", ro), nil) - if _, err := a.srv.AgentMetrics(nil, req); err != nil { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) }) } @@ -1854,17 +1894,17 @@ func TestAgent_Reload_ACLDeny(t *testing.T) { testrpc.WaitForLeader(t, a.RPC, "dc1") t.Run("no token", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/reload", nil) - if _, err := a.srv.AgentReload(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) t.Run("read-only token", func(t *testing.T) { ro := createACLTokenWithAgentReadPolicy(t, a.srv) req, _ := http.NewRequest("PUT", fmt.Sprintf("/v1/agent/reload?token=%s", ro), nil) - if _, err := a.srv.AgentReload(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) // This proves we call the ACL function, and we've got the other reload @@ -1884,17 +1924,21 @@ func TestAgent_Members(t *testing.T) { testrpc.WaitForTestAgent(t, a.RPC, "dc1") req, _ := http.NewRequest("GET", "/v1/agent/members", nil) - obj, err := a.srv.AgentMembers(nil, req) - if err != nil { + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + + dec := json.NewDecoder(resp.Body) + val := make([]serf.Member, 0) + if err := dec.Decode(&val); err != nil { t.Fatalf("Err: %v", err) } - val := obj.([]serf.Member) + if len(val) == 0 { - t.Fatalf("bad members: %v", obj) + t.Fatalf("bad members: %v", val) } if int(val[0].Port) != a.Config.SerfPortLAN { - t.Fatalf("not lan: %v", obj) + t.Fatalf("not lan: %v", val) } } @@ -1909,17 +1953,21 @@ func TestAgent_Members_WAN(t *testing.T) { testrpc.WaitForTestAgent(t, a.RPC, "dc1") req, _ := http.NewRequest("GET", "/v1/agent/members?wan=true", nil) - obj, err := a.srv.AgentMembers(nil, req) - if err != nil { + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + + dec := json.NewDecoder(resp.Body) + val := make([]serf.Member, 0) + if err := dec.Decode(&val); err != nil { t.Fatalf("Err: %v", err) } - val := obj.([]serf.Member) + if len(val) == 0 { - t.Fatalf("bad members: %v", obj) + t.Fatalf("bad members: %v", val) } if int(val[0].Port) != a.Config.SerfPortWAN { - t.Fatalf("not wan: %v", obj) + t.Fatalf("not wan: %v", val) } } @@ -1941,21 +1989,22 @@ func TestAgent_Members_ACLFilter(t *testing.T) { testrpc.WaitForLeader(t, b.RPC, "dc1") joinPath := fmt.Sprintf("/v1/agent/join/127.0.0.1:%d?token=root", b.Config.SerfPortLAN) - _, err := a.srv.AgentJoin(nil, httptest.NewRequest(http.MethodPut, joinPath, nil)) - require.NoError(t, err) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, httptest.NewRequest(http.MethodPut, joinPath, nil)) + require.Equal(t, http.StatusOK, resp.Code) t.Run("no token", func(t *testing.T) { - require := require.New(t) + req, _ := http.NewRequest("GET", "/v1/agent/members", nil) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) - req := httptest.NewRequest("GET", "/v1/agent/members", nil) - rsp := httptest.NewRecorder() - - obj, err := a.srv.AgentMembers(rsp, req) - require.NoError(err) - - val := obj.([]serf.Member) - require.Empty(val) - require.Empty(rsp.Header().Get("X-Consul-Results-Filtered-By-ACLs")) + dec := json.NewDecoder(resp.Body) + val := make([]serf.Member, 0) + if err := dec.Decode(&val); err != nil { + t.Fatalf("Err: %v", err) + } + require.Len(t, val, 0) + require.Empty(t, resp.Header().Get("X-Consul-Results-Filtered-By-ACLs")) }) t.Run("limited token", func(t *testing.T) { @@ -1968,28 +2017,30 @@ func TestAgent_Members_ACLFilter(t *testing.T) { `, b.Config.NodeName)) req := httptest.NewRequest("GET", fmt.Sprintf("/v1/agent/members?token=%s", token), nil) - rsp := httptest.NewRecorder() + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) - obj, err := a.srv.AgentMembers(rsp, req) - require.NoError(err) - - val := obj.([]serf.Member) + dec := json.NewDecoder(resp.Body) + val := make([]serf.Member, 0) + if err := dec.Decode(&val); err != nil { + t.Fatalf("Err: %v", err) + } require.Len(val, 1) - require.NotEmpty(rsp.Header().Get("X-Consul-Results-Filtered-By-ACLs")) + require.NotEmpty(resp.Header().Get("X-Consul-Results-Filtered-By-ACLs")) }) t.Run("root token", func(t *testing.T) { - require := require.New(t) + req, _ := http.NewRequest("GET", "/v1/agent/members?token=root", nil) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) - req := httptest.NewRequest("GET", "/v1/agent/members?token=root", nil) - rsp := httptest.NewRecorder() - - obj, err := a.srv.AgentMembers(rsp, req) - require.NoError(err) - - val := obj.([]serf.Member) - require.Len(val, 2) - require.Empty(rsp.Header().Get("X-Consul-Results-Filtered-By-ACLs")) + dec := json.NewDecoder(resp.Body) + val := make([]serf.Member, 0) + if err := dec.Decode(&val); err != nil { + t.Fatalf("Err: %v", err) + } + require.Len(t, val, 2) + require.Empty(t, resp.Header().Get("X-Consul-Results-Filtered-By-ACLs")) }) } @@ -2008,13 +2059,8 @@ func TestAgent_Join(t *testing.T) { addr := fmt.Sprintf("127.0.0.1:%d", a2.Config.SerfPortLAN) req, _ := http.NewRequest("PUT", fmt.Sprintf("/v1/agent/join/%s", addr), nil) - obj, err := a1.srv.AgentJoin(nil, req) - if err != nil { - t.Fatalf("Err: %v", err) - } - if obj != nil { - t.Fatalf("Err: %v", obj) - } + resp := httptest.NewRecorder() + a1.srv.h.ServeHTTP(resp, req) if len(a1.LANMembersInAgentPartition()) != 2 { t.Fatalf("should have 2 members") @@ -2042,13 +2088,8 @@ func TestAgent_Join_WAN(t *testing.T) { addr := fmt.Sprintf("127.0.0.1:%d", a2.Config.SerfPortWAN) req, _ := http.NewRequest("PUT", fmt.Sprintf("/v1/agent/join/%s?wan=true", addr), nil) - obj, err := a1.srv.AgentJoin(nil, req) - if err != nil { - t.Fatalf("Err: %v", err) - } - if obj != nil { - t.Fatalf("Err: %v", obj) - } + resp := httptest.NewRecorder() + a1.srv.h.ServeHTTP(resp, req) if len(a1.WANMembers()) != 2 { t.Fatalf("should have 2 members") @@ -2078,25 +2119,27 @@ func TestAgent_Join_ACLDeny(t *testing.T) { t.Run("no token", func(t *testing.T) { req, _ := http.NewRequest("PUT", fmt.Sprintf("/v1/agent/join/%s", addr), nil) - if _, err := a1.srv.AgentJoin(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a1.srv.h.ServeHTTP(resp, req) + + require.Equal(t, http.StatusForbidden, resp.Code) }) t.Run("agent master token", func(t *testing.T) { req, _ := http.NewRequest("PUT", fmt.Sprintf("/v1/agent/join/%s?token=towel", addr), nil) - _, err := a1.srv.AgentJoin(nil, req) - if err != nil { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a1.srv.h.ServeHTTP(resp, req) + + require.Equal(t, http.StatusOK, resp.Code) }) t.Run("read-only token", func(t *testing.T) { ro := createACLTokenWithAgentReadPolicy(t, a1.srv) req, _ := http.NewRequest("PUT", fmt.Sprintf("/v1/agent/join/%s?token=%s", addr, ro), nil) - if _, err := a1.srv.AgentJoin(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a1.srv.h.ServeHTTP(resp, req) + + require.Equal(t, http.StatusForbidden, resp.Code) }) } @@ -2162,13 +2205,10 @@ func TestAgent_Leave(t *testing.T) { // Graceful leave now req, _ := http.NewRequest("PUT", "/v1/agent/leave", nil) - obj, err := a2.srv.AgentLeave(nil, req) - if err != nil { - t.Fatalf("Err: %v", err) - } - if obj != nil { - t.Fatalf("Err: %v", obj) - } + resp := httptest.NewRecorder() + a2.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) + retry.Run(t, func(r *retry.R) { m := a1.LANMembersInAgentPartition() if got, want := m[1].Status, serf.StatusLeft; got != want { @@ -2189,26 +2229,29 @@ func TestAgent_Leave_ACLDeny(t *testing.T) { t.Run("no token", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/leave", nil) - if _, err := a.srv.AgentLeave(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + + require.Equal(t, http.StatusForbidden, resp.Code) }) t.Run("read-only token", func(t *testing.T) { ro := createACLTokenWithAgentReadPolicy(t, a.srv) req, _ := http.NewRequest("PUT", fmt.Sprintf("/v1/agent/leave?token=%s", ro), nil) - if _, err := a.srv.AgentLeave(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + + require.Equal(t, http.StatusForbidden, resp.Code) }) // this sub-test will change the state so that there is no leader. // it must therefore be the last one in this list. t.Run("agent master token", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/leave?token=towel", nil) - if _, err := a.srv.AgentLeave(nil, req); err != nil { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + + require.Equal(t, http.StatusOK, resp.Code) }) } @@ -2243,13 +2286,10 @@ func TestAgent_ForceLeave(t *testing.T) { // Force leave now req, _ := http.NewRequest("PUT", fmt.Sprintf("/v1/agent/force-leave/%s", a2.Config.NodeName), nil) - obj, err := a1.srv.AgentForceLeave(nil, req) - if err != nil { - t.Fatalf("Err: %v", err) - } - if obj != nil { - t.Fatalf("Err: %v", obj) - } + resp := httptest.NewRecorder() + a1.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) + retry.Run(t, func(r *retry.R) { m := a1.LANMembersInAgentPartition() if got, want := m[1].Status, serf.StatusLeft; got != want { @@ -2287,24 +2327,24 @@ func TestAgent_ForceLeave_ACLDeny(t *testing.T) { t.Run("no token", func(t *testing.T) { req, _ := http.NewRequest("PUT", uri, nil) - if _, err := a.srv.AgentForceLeave(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) t.Run("agent master token", func(t *testing.T) { req, _ := http.NewRequest("PUT", uri+"?token=towel", nil) - if _, err := a.srv.AgentForceLeave(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) t.Run("read-only token", func(t *testing.T) { ro := createACLTokenWithAgentReadPolicy(t, a.srv) req, _ := http.NewRequest("PUT", fmt.Sprintf(uri+"?token=%s", ro), nil) - if _, err := a.srv.AgentForceLeave(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) t.Run("operator write token", func(t *testing.T) { @@ -2315,9 +2355,9 @@ func TestAgent_ForceLeave_ACLDeny(t *testing.T) { opToken := testCreateToken(t, a, rules) req, _ := http.NewRequest("PUT", fmt.Sprintf(uri+"?token=%s", opToken), nil) - if _, err := a.srv.AgentForceLeave(nil, req); err != nil { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) }) } @@ -2356,13 +2396,9 @@ func TestAgent_ForceLeavePrune(t *testing.T) { // Force leave now req, _ := http.NewRequest("PUT", fmt.Sprintf("/v1/agent/force-leave/%s?prune=true", a2.Config.NodeName), nil) - obj, err := a1.srv.AgentForceLeave(nil, req) - if err != nil { - t.Fatalf("Err: %v", err) - } - if obj != nil { - t.Fatalf("Err: %v", obj) - } + resp := httptest.NewRecorder() + a1.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) retry.Run(t, func(r *retry.R) { m := len(a1.LANMembersInAgentPartition()) if m != 1 { @@ -2454,13 +2490,9 @@ func TestAgent_RegisterCheck(t *testing.T) { TTL: 15 * time.Second, } req, _ := http.NewRequest("PUT", "/v1/agent/check/register?token=abc123", jsonReader(args)) - obj, err := a.srv.AgentRegisterCheck(nil, req) - if err != nil { - t.Fatalf("err: %v", err) - } - if obj != nil { - t.Fatalf("bad: %v", obj) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) // Ensure we have a check mapping checkID := structs.NewCheckID("test", nil) @@ -2531,9 +2563,7 @@ func TestAgent_RegisterCheck_Scripts(t *testing.T) { t.Run(tt.name+" as node check", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/check/register", jsonReader(tt.check)) resp := httptest.NewRecorder() - if _, err := a.srv.AgentRegisterCheck(resp, req); err != nil { - t.Fatalf("err: %v", err) - } + a.srv.h.ServeHTTP(resp, req) if resp.Code != http.StatusOK { t.Fatalf("bad: %d", resp.Code) } @@ -2548,9 +2578,7 @@ func TestAgent_RegisterCheck_Scripts(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/service/register", jsonReader(args)) resp := httptest.NewRecorder() - if _, err := a.srv.AgentRegisterService(resp, req); err != nil { - t.Fatalf("err: %v", err) - } + a.srv.h.ServeHTTP(resp, req) if resp.Code != http.StatusOK { t.Fatalf("bad: %d", resp.Code) } @@ -2565,9 +2593,7 @@ func TestAgent_RegisterCheck_Scripts(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/service/register", jsonReader(args)) resp := httptest.NewRecorder() - if _, err := a.srv.AgentRegisterService(resp, req); err != nil { - t.Fatalf("err: %v", err) - } + a.srv.h.ServeHTTP(resp, req) if resp.Code != http.StatusOK { t.Fatalf("bad: %d", resp.Code) } @@ -2592,12 +2618,12 @@ func TestAgent_RegisterCheckScriptsExecDisable(t *testing.T) { } req, _ := http.NewRequest("PUT", "/v1/agent/check/register?token=abc123", jsonReader(args)) res := httptest.NewRecorder() - _, err := a.srv.AgentRegisterCheck(res, req) - if err == nil { - t.Fatalf("expected error but got nil") + a.srv.h.ServeHTTP(res, req) + if http.StatusInternalServerError != res.Code { + t.Fatalf("expected 500 code error but got %v", res.Code) } - if !strings.Contains(err.Error(), "Scripts are disabled on this agent") { - t.Fatalf("expected script disabled error, got: %s", err) + if !strings.Contains(res.Body.String(), "Scripts are disabled on this agent") { + t.Fatalf("expected script disabled error, got: %s", res.Body.String()) } checkID := structs.NewCheckID("test", nil) require.Nil(t, a.State.Check(checkID), "check registered with exec disabled") @@ -2622,12 +2648,12 @@ func TestAgent_RegisterCheckScriptsExecRemoteDisable(t *testing.T) { } req, _ := http.NewRequest("PUT", "/v1/agent/check/register?token=abc123", jsonReader(args)) res := httptest.NewRecorder() - _, err := a.srv.AgentRegisterCheck(res, req) - if err == nil { - t.Fatalf("expected error but got nil") + a.srv.h.ServeHTTP(res, req) + if http.StatusInternalServerError != res.Code { + t.Fatalf("expected 500 code error but got %v", res.Code) } - if !strings.Contains(err.Error(), "Scripts are disabled on this agent") { - t.Fatalf("expected script disabled error, got: %s", err) + if !strings.Contains(res.Body.String(), "Scripts are disabled on this agent") { + t.Fatalf("expected script disabled error, got: %s", res.Body.String()) } checkID := structs.NewCheckID("test", nil) require.Nil(t, a.State.Check(checkID), "check registered with exec disabled") @@ -2649,12 +2675,10 @@ func TestAgent_RegisterCheck_Passing(t *testing.T) { Status: api.HealthPassing, } req, _ := http.NewRequest("PUT", "/v1/agent/check/register", jsonReader(args)) - obj, err := a.srv.AgentRegisterCheck(nil, req) - if err != nil { - t.Fatalf("err: %v", err) - } - if obj != nil { - t.Fatalf("bad: %v", obj) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + if http.StatusOK != resp.Code { + t.Fatalf("expcted 200 but got %v", resp.Code) } // Ensure we have a check mapping @@ -2691,8 +2715,9 @@ func TestAgent_RegisterCheck_BadStatus(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/check/register", jsonReader(args)) resp := httptest.NewRecorder() a.srv.h.ServeHTTP(resp, req) - require.Equalf(t, http.StatusBadRequest, resp.Code, "resp: %v", resp.Body.String()) - require.Contains(t, resp.Body.String(), "Bad check status") + if resp.Code != http.StatusBadRequest { + t.Fatalf("accepted bad status") + } } func TestAgent_RegisterCheck_ACLDeny(t *testing.T) { @@ -2725,8 +2750,8 @@ func TestAgent_RegisterCheck_ACLDeny(t *testing.T) { // ensure the service is ready for registering a check for it. req, _ := http.NewRequest("PUT", "/v1/agent/service/register?token=root", jsonReader(svc)) resp := httptest.NewRecorder() - _, err := a.srv.AgentRegisterService(resp, req) - require.NoError(t, err) + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) // create a policy that has write on service foo policyReq := &structs.ACLPolicy{ @@ -2736,8 +2761,8 @@ func TestAgent_RegisterCheck_ACLDeny(t *testing.T) { req, _ = http.NewRequest("PUT", "/v1/acl/policy?token=root", jsonReader(policyReq)) resp = httptest.NewRecorder() - _, err = a.srv.ACLPolicyCreate(resp, req) - require.NoError(t, err) + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) // create a policy that has write on the node name of the agent policyReq = &structs.ACLPolicy{ @@ -2747,8 +2772,8 @@ func TestAgent_RegisterCheck_ACLDeny(t *testing.T) { req, _ = http.NewRequest("PUT", "/v1/acl/policy?token=root", jsonReader(policyReq)) resp = httptest.NewRecorder() - _, err = a.srv.ACLPolicyCreate(resp, req) - require.NoError(t, err) + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) // create a token using the write-foo policy tokenReq := &structs.ACLToken{ @@ -2762,10 +2787,14 @@ func TestAgent_RegisterCheck_ACLDeny(t *testing.T) { req, _ = http.NewRequest("PUT", "/v1/acl/token?token=root", jsonReader(tokenReq)) resp = httptest.NewRecorder() - tokInf, err := a.srv.ACLTokenCreate(resp, req) - require.NoError(t, err) - svcToken, ok := tokInf.(*structs.ACLToken) - require.True(t, ok) + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) + + dec := json.NewDecoder(resp.Body) + svcToken := &structs.ACLToken{} + if err := dec.Decode(svcToken); err != nil { + t.Fatalf("err: %v", err) + } require.NotNil(t, svcToken) // create a token using the write-node policy @@ -2780,57 +2809,67 @@ func TestAgent_RegisterCheck_ACLDeny(t *testing.T) { req, _ = http.NewRequest("PUT", "/v1/acl/token?token=root", jsonReader(tokenReq)) resp = httptest.NewRecorder() - tokInf, err = a.srv.ACLTokenCreate(resp, req) - require.NoError(t, err) - nodeToken, ok := tokInf.(*structs.ACLToken) - require.True(t, ok) + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) + + dec = json.NewDecoder(resp.Body) + nodeToken := &structs.ACLToken{} + if err := dec.Decode(nodeToken); err != nil { + t.Fatalf("err: %v", err) + } require.NotNil(t, nodeToken) t.Run("no token - node check", func(t *testing.T) { retry.Run(t, func(r *retry.R) { req, _ := http.NewRequest("PUT", "/v1/agent/check/register", jsonReader(nodeCheck)) - _, err := a.srv.AgentRegisterCheck(nil, req) - require.True(r, acl.IsErrPermissionDenied(err)) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) }) t.Run("svc token - node check", func(t *testing.T) { retry.Run(t, func(r *retry.R) { req, _ := http.NewRequest("PUT", "/v1/agent/check/register?token="+svcToken.SecretID, jsonReader(nodeCheck)) - _, err := a.srv.AgentRegisterCheck(nil, req) - require.True(r, acl.IsErrPermissionDenied(err)) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) }) t.Run("node token - node check", func(t *testing.T) { retry.Run(t, func(r *retry.R) { req, _ := http.NewRequest("PUT", "/v1/agent/check/register?token="+nodeToken.SecretID, jsonReader(nodeCheck)) - _, err := a.srv.AgentRegisterCheck(nil, req) - require.NoError(r, err) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) }) }) t.Run("no token - svc check", func(t *testing.T) { retry.Run(t, func(r *retry.R) { req, _ := http.NewRequest("PUT", "/v1/agent/check/register", jsonReader(svcCheck)) - _, err := a.srv.AgentRegisterCheck(nil, req) - require.True(r, acl.IsErrPermissionDenied(err)) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) }) t.Run("node token - svc check", func(t *testing.T) { retry.Run(t, func(r *retry.R) { req, _ := http.NewRequest("PUT", "/v1/agent/check/register?token="+nodeToken.SecretID, jsonReader(svcCheck)) - _, err := a.srv.AgentRegisterCheck(nil, req) - require.True(r, acl.IsErrPermissionDenied(err)) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) }) t.Run("svc token - svc check", func(t *testing.T) { retry.Run(t, func(r *retry.R) { req, _ := http.NewRequest("PUT", "/v1/agent/check/register?token="+svcToken.SecretID, jsonReader(svcCheck)) - _, err := a.srv.AgentRegisterCheck(nil, req) - require.NoError(r, err) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) }) }) } @@ -2851,12 +2890,10 @@ func TestAgent_DeregisterCheck(t *testing.T) { } req, _ := http.NewRequest("PUT", "/v1/agent/check/deregister/test", nil) - obj, err := a.srv.AgentDeregisterCheck(nil, req) - if err != nil { - t.Fatalf("err: %v", err) - } - if obj != nil { - t.Fatalf("bad: %v", obj) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + if http.StatusOK != resp.Code { + t.Fatalf("expected 200 but got %v", resp.Code) } // Ensure we have a check mapping @@ -2880,16 +2917,16 @@ func TestAgent_DeregisterCheckACLDeny(t *testing.T) { t.Run("no token", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/check/deregister/test", nil) - if _, err := a.srv.AgentDeregisterCheck(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) t.Run("root token", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/check/deregister/test?token=root", nil) - if _, err := a.srv.AgentDeregisterCheck(nil, req); err != nil { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) }) } @@ -2910,12 +2947,11 @@ func TestAgent_PassCheck(t *testing.T) { } req, _ := http.NewRequest("PUT", "/v1/agent/check/pass/test", nil) - obj, err := a.srv.AgentCheckPass(nil, req) - if err != nil { - t.Fatalf("err: %v", err) - } - if obj != nil { - t.Fatalf("bad: %v", obj) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + + if http.StatusOK != resp.Code { + t.Fatalf("expected 200 by got %v", resp.Code) } // Ensure we have a check mapping @@ -2943,16 +2979,16 @@ func TestAgent_PassCheck_ACLDeny(t *testing.T) { t.Run("no token", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/check/pass/test", nil) - if _, err := a.srv.AgentCheckPass(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) t.Run("root token", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/check/pass/test?token=root", nil) - if _, err := a.srv.AgentCheckPass(nil, req); err != nil { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) }) } @@ -2973,12 +3009,11 @@ func TestAgent_WarnCheck(t *testing.T) { } req, _ := http.NewRequest("PUT", "/v1/agent/check/warn/test", nil) - obj, err := a.srv.AgentCheckWarn(nil, req) - if err != nil { - t.Fatalf("err: %v", err) - } - if obj != nil { - t.Fatalf("bad: %v", obj) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + + if http.StatusOK != resp.Code { + t.Fatalf("expected 200 by got %v", resp.Code) } // Ensure we have a check mapping @@ -3006,16 +3041,16 @@ func TestAgent_WarnCheck_ACLDeny(t *testing.T) { t.Run("no token", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/check/warn/test", nil) - if _, err := a.srv.AgentCheckWarn(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) t.Run("root token", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/check/warn/test?token=root", nil) - if _, err := a.srv.AgentCheckWarn(nil, req); err != nil { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) }) } @@ -3036,12 +3071,11 @@ func TestAgent_FailCheck(t *testing.T) { } req, _ := http.NewRequest("PUT", "/v1/agent/check/fail/test", nil) - obj, err := a.srv.AgentCheckFail(nil, req) - if err != nil { - t.Fatalf("err: %v", err) - } - if obj != nil { - t.Fatalf("bad: %v", obj) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + + if http.StatusOK != resp.Code { + t.Fatalf("expected 200 by got %v", resp.Code) } // Ensure we have a check mapping @@ -3069,16 +3103,16 @@ func TestAgent_FailCheck_ACLDeny(t *testing.T) { t.Run("no token", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/check/fail/test", nil) - if _, err := a.srv.AgentCheckFail(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) t.Run("root token", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/check/fail/test?token=root", nil) - if _, err := a.srv.AgentCheckFail(nil, req); err != nil { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) }) } @@ -3109,14 +3143,8 @@ func TestAgent_UpdateCheck(t *testing.T) { t.Run(c.Status, func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/check/update/test", jsonReader(c)) resp := httptest.NewRecorder() - obj, err := a.srv.AgentCheckUpdate(resp, req) - if err != nil { - t.Fatalf("err: %v", err) - } - if obj != nil { - t.Fatalf("bad: %v", obj) - } - if resp.Code != 200 { + a.srv.h.ServeHTTP(resp, req) + if resp.Code != http.StatusOK { t.Fatalf("expected 200, got %d", resp.Code) } @@ -3134,14 +3162,8 @@ func TestAgent_UpdateCheck(t *testing.T) { } req, _ := http.NewRequest("PUT", "/v1/agent/check/update/test", jsonReader(args)) resp := httptest.NewRecorder() - obj, err := a.srv.AgentCheckUpdate(resp, req) - if err != nil { - t.Fatalf("err: %v", err) - } - if obj != nil { - t.Fatalf("bad: %v", obj) - } - if resp.Code != 200 { + a.srv.h.ServeHTTP(resp, req) + if resp.Code != http.StatusOK { t.Fatalf("expected 200, got %d", resp.Code) } @@ -3158,14 +3180,8 @@ func TestAgent_UpdateCheck(t *testing.T) { args := checkUpdate{Status: "itscomplicated"} req, _ := http.NewRequest("PUT", "/v1/agent/check/update/test", jsonReader(args)) resp := httptest.NewRecorder() - obj, err := a.srv.AgentCheckUpdate(resp, req) - if err != nil { - t.Fatalf("err: %v", err) - } - if obj != nil { - t.Fatalf("bad: %v", obj) - } - if resp.Code != 400 { + a.srv.h.ServeHTTP(resp, req) + if resp.Code != http.StatusBadRequest { t.Fatalf("expected 400, got %d", resp.Code) } }) @@ -3190,17 +3206,17 @@ func TestAgent_UpdateCheck_ACLDeny(t *testing.T) { t.Run("no token", func(t *testing.T) { args := checkUpdate{api.HealthPassing, "hello-passing"} req, _ := http.NewRequest("PUT", "/v1/agent/check/update/test", jsonReader(args)) - if _, err := a.srv.AgentCheckUpdate(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) t.Run("root token", func(t *testing.T) { args := checkUpdate{api.HealthPassing, "hello-passing"} req, _ := http.NewRequest("PUT", "/v1/agent/check/update/test?token=root", jsonReader(args)) - if _, err := a.srv.AgentCheckUpdate(nil, req); err != nil { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) }) } @@ -3248,13 +3264,10 @@ func testAgent_RegisterService(t *testing.T, extraHCL string) { }, } req, _ := http.NewRequest("PUT", "/v1/agent/service/register?token=abc123", jsonReader(args)) - - obj, err := a.srv.AgentRegisterService(nil, req) - if err != nil { - t.Fatalf("err: %v", err) - } - if obj != nil { - t.Fatalf("bad: %v", obj) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + if http.StatusOK != resp.Code { + t.Fatalf("expected 200 but got %v", resp.Code) } // Ensure the service @@ -3337,8 +3350,9 @@ func testAgent_RegisterService_ReRegister(t *testing.T, extraHCL string) { }, } req, _ := http.NewRequest("PUT", "/v1/agent/service/register", jsonReader(args)) - _, err := a.srv.AgentRegisterService(nil, req) - require.NoError(t, err) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) args = &structs.ServiceDefinition{ Name: "test", @@ -3361,8 +3375,9 @@ func testAgent_RegisterService_ReRegister(t *testing.T, extraHCL string) { }, } req, _ = http.NewRequest("PUT", "/v1/agent/service/register", jsonReader(args)) - _, err = a.srv.AgentRegisterService(nil, req) - require.NoError(t, err) + resp = httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) checks := a.State.Checks(structs.DefaultEnterpriseMetaInDefaultPartition()) require.Equal(t, 3, len(checks)) @@ -3417,8 +3432,9 @@ func testAgent_RegisterService_ReRegister_ReplaceExistingChecks(t *testing.T, ex }, } req, _ := http.NewRequest("PUT", "/v1/agent/service/register?replace-existing-checks", jsonReader(args)) - _, err := a.srv.AgentRegisterService(nil, req) - require.NoError(t, err) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) args = &structs.ServiceDefinition{ Name: "test", @@ -3440,8 +3456,9 @@ func testAgent_RegisterService_ReRegister_ReplaceExistingChecks(t *testing.T, ex }, } req, _ = http.NewRequest("PUT", "/v1/agent/service/register?replace-existing-checks", jsonReader(args)) - _, err = a.srv.AgentRegisterService(nil, req) - require.NoError(t, err) + resp = httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) checks := a.State.Checks(structs.DefaultEnterpriseMetaInDefaultPartition()) require.Len(t, checks, 2) @@ -3569,9 +3586,7 @@ func testAgent_RegisterService_TranslateKeys(t *testing.T, extraHCL string) { req, _ := http.NewRequest("PUT", "/v1/agent/service/register", strings.NewReader(json)) rr := httptest.NewRecorder() - obj, err := a.srv.AgentRegisterService(rr, req) - require.NoError(t, err) - require.Nil(t, obj) + a.srv.h.ServeHTTP(rr, req) require.Equal(t, 200, rr.Code, "body: %s", rr.Body) svc := &structs.NodeService{ @@ -3721,16 +3736,16 @@ func testAgent_RegisterService_ACLDeny(t *testing.T, extraHCL string) { t.Run("no token", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/service/register", jsonReader(args)) - if _, err := a.srv.AgentRegisterService(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) t.Run("root token", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/service/register?token=root", jsonReader(args)) - if _, err := a.srv.AgentRegisterService(nil, req); err != nil { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) }) } @@ -3765,10 +3780,7 @@ func testAgent_RegisterService_InvalidAddress(t *testing.T, extraHCL string) { } req, _ := http.NewRequest("PUT", "/v1/agent/service/register?token=abc123", jsonReader(args)) resp := httptest.NewRecorder() - _, err := a.srv.AgentRegisterService(resp, req) - if err != nil { - t.Fatalf("got error %v want nil", err) - } + a.srv.h.ServeHTTP(resp, req) if got, want := resp.Code, 400; got != want { t.Fatalf("got code %d want %d", got, want) } @@ -3833,9 +3845,8 @@ func testAgent_RegisterService_UnmanagedConnectProxy(t *testing.T, extraHCL stri req, _ := http.NewRequest("PUT", "/v1/agent/service/register?token=abc123", jsonReader(args)) resp := httptest.NewRecorder() - obj, err := a.srv.AgentRegisterService(resp, req) - require.NoError(t, err) - require.Nil(t, obj) + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) // Ensure the service sid := structs.NewServiceID("connect-proxy", nil) @@ -3912,10 +3923,14 @@ func testCreateToken(t *testing.T, a *TestAgent, rules string) string { } req, _ := http.NewRequest("PUT", "/v1/acl/token?token=root", jsonReader(args)) resp := httptest.NewRecorder() - obj, err := a.srv.ACLTokenCreate(resp, req) - require.NoError(t, err) - require.NotNil(t, obj) - aclResp := obj.(*structs.ACLToken) + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) + + dec := json.NewDecoder(resp.Body) + aclResp := &structs.ACLToken{} + if err := dec.Decode(aclResp); err != nil { + t.Fatalf("err: %v", err) + } return aclResp.SecretID } @@ -3926,10 +3941,14 @@ func testCreatePolicy(t *testing.T, a *TestAgent, name, rules string) string { } req, _ := http.NewRequest("PUT", "/v1/acl/policy?token=root", jsonReader(args)) resp := httptest.NewRecorder() - obj, err := a.srv.ACLPolicyCreate(resp, req) - require.NoError(t, err) - require.NotNil(t, obj) - aclResp := obj.(*structs.ACLPolicy) + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) + + dec := json.NewDecoder(resp.Body) + aclResp := &structs.ACLPolicy{} + if err := dec.Decode(aclResp); err != nil { + t.Fatalf("err: %v", err) + } return aclResp.ID } @@ -4361,15 +4380,11 @@ func testAgent_RegisterServiceDeregisterService_Sidecar(t *testing.T, extraHCL s req, _ := http.NewRequest("PUT", "/v1/agent/service/register?token="+token, br) resp := httptest.NewRecorder() - obj, err := a.srv.AgentRegisterService(resp, req) + a.srv.h.ServeHTTP(resp, req) if tt.wantErr != "" { - require.Error(err, "response code=%d, body:\n%s", - resp.Code, resp.Body.String()) - require.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.wantErr)) + require.Contains(strings.ToLower(resp.Body.String()), strings.ToLower(tt.wantErr)) return } - require.NoError(err) - assert.Nil(obj) require.Equal(200, resp.Code, "request failed with body: %s", resp.Body.String()) @@ -4378,7 +4393,7 @@ func testAgent_RegisterServiceDeregisterService_Sidecar(t *testing.T, extraHCL s // Parse the expected definition into a ServiceDefinition var sd structs.ServiceDefinition - err = json.Unmarshal([]byte(tt.json), &sd) + err := json.Unmarshal([]byte(tt.json), &sd) require.NoError(err) require.NotEmpty(sd.Name) @@ -4419,9 +4434,8 @@ func testAgent_RegisterServiceDeregisterService_Sidecar(t *testing.T, extraHCL s req := httptest.NewRequest("PUT", "/v1/agent/service/deregister/"+svcID+"?token="+token, nil) resp := httptest.NewRecorder() - obj, err := a.srv.AgentDeregisterService(resp, req) - require.NoError(err) - require.Nil(obj) + a.srv.h.ServeHTTP(resp, req) + require.Equal(http.StatusOK, resp.Code) svcs := a.State.AllServices() _, ok = svcs[structs.NewServiceID(tt.wantNS.ID, nil)] @@ -4474,9 +4488,7 @@ func testAgent_RegisterService_UnmanagedConnectProxyInvalid(t *testing.T, extraH req, _ := http.NewRequest("PUT", "/v1/agent/service/register?token=abc123", jsonReader(args)) resp := httptest.NewRecorder() - obj, err := a.srv.AgentRegisterService(resp, req) - assert.Nil(err) - assert.Nil(obj) + a.srv.h.ServeHTTP(resp, req) assert.Equal(http.StatusBadRequest, resp.Code) assert.Contains(resp.Body.String(), "Port") @@ -4524,9 +4536,8 @@ func testAgent_RegisterService_ConnectNative(t *testing.T, extraHCL string) { req, _ := http.NewRequest("PUT", "/v1/agent/service/register", jsonReader(args)) resp := httptest.NewRecorder() - obj, err := a.srv.AgentRegisterService(resp, req) - assert.Nil(err) - assert.Nil(obj) + a.srv.h.ServeHTTP(resp, req) + assert.Equal(http.StatusOK, resp.Code) // Ensure the service svc := a.State.Service(structs.NewServiceID("web", nil)) @@ -4572,13 +4583,13 @@ func testAgent_RegisterService_ScriptCheck_ExecDisable(t *testing.T, extraHCL st }, } req, _ := http.NewRequest("PUT", "/v1/agent/service/register?token=abc123", jsonReader(args)) - - _, err := a.srv.AgentRegisterService(nil, req) - if err == nil { - t.Fatalf("expected error but got nil") + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + if http.StatusInternalServerError != resp.Code { + t.Fatalf("expected 500 but got %v", resp.Code) } - if !strings.Contains(err.Error(), "Scripts are disabled on this agent") { - t.Fatalf("expected script disabled error, got: %s", err) + if !strings.Contains(resp.Body.String(), "Scripts are disabled on this agent") { + t.Fatalf("expected script disabled error, got: %s", resp.Body.String()) } checkID := types.CheckID("test-check") require.Nil(t, a.State.Check(structs.NewCheckID(checkID, nil)), "check registered with exec disabled") @@ -4624,13 +4635,13 @@ func testAgent_RegisterService_ScriptCheck_ExecRemoteDisable(t *testing.T, extra }, } req, _ := http.NewRequest("PUT", "/v1/agent/service/register?token=abc123", jsonReader(args)) - - _, err := a.srv.AgentRegisterService(nil, req) - if err == nil { - t.Fatalf("expected error but got nil") + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + if http.StatusInternalServerError != resp.Code { + t.Fatalf("expected 500 but got %v", resp.Code) } - if !strings.Contains(err.Error(), "Scripts are disabled on this agent") { - t.Fatalf("expected script disabled error, got: %s", err) + if !strings.Contains(resp.Body.String(), "Scripts are disabled on this agent") { + t.Fatalf("expected script disabled error, got: %s", resp.Body.String()) } checkID := types.CheckID("test-check") require.Nil(t, a.State.Check(structs.NewCheckID(checkID, nil)), "check registered with exec disabled") @@ -4646,21 +4657,23 @@ func TestAgent_DeregisterService(t *testing.T) { defer a.Shutdown() testrpc.WaitForTestAgent(t, a.RPC, "dc1") - service := &structs.NodeService{ - ID: "test", - Service: "test", - } - if err := a.addServiceFromSource(service, nil, false, "", ConfigSourceLocal); err != nil { - t.Fatalf("err: %v", err) + serviceReq := AddServiceRequest{ + Service: &structs.NodeService{ + ID: "test", + Service: "test", + }, + chkTypes: nil, + persist: false, + token: "", + Source: ConfigSourceLocal, } + require.NoError(t, a.AddService(serviceReq)) req, _ := http.NewRequest("PUT", "/v1/agent/service/deregister/test", nil) - obj, err := a.srv.AgentDeregisterService(nil, req) - if err != nil { - t.Fatalf("err: %v", err) - } - if obj != nil { - t.Fatalf("bad: %v", obj) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + if http.StatusOK != resp.Code { + t.Fatalf("expected 200 but got %v", resp.Code) } // Ensure we have a check mapping @@ -4678,26 +4691,30 @@ func TestAgent_DeregisterService_ACLDeny(t *testing.T) { defer a.Shutdown() testrpc.WaitForLeader(t, a.RPC, "dc1") - service := &structs.NodeService{ - ID: "test", - Service: "test", - } - if err := a.addServiceFromSource(service, nil, false, "", ConfigSourceLocal); err != nil { - t.Fatalf("err: %v", err) + serviceReq := AddServiceRequest{ + Service: &structs.NodeService{ + ID: "test", + Service: "test", + }, + chkTypes: nil, + persist: false, + token: "", + Source: ConfigSourceLocal, } + require.NoError(t, a.AddService(serviceReq)) t.Run("no token", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/service/deregister/test", nil) - if _, err := a.srv.AgentDeregisterService(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) t.Run("root token", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/service/deregister/test?token=root", nil) - if _, err := a.srv.AgentDeregisterService(nil, req); err != nil { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) }) } @@ -4714,9 +4731,7 @@ func TestAgent_ServiceMaintenance_BadRequest(t *testing.T) { t.Run("not enabled", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/service/maintenance/test", nil) resp := httptest.NewRecorder() - if _, err := a.srv.AgentServiceMaintenance(resp, req); err != nil { - t.Fatalf("err: %s", err) - } + a.srv.h.ServeHTTP(resp, req) if resp.Code != 400 { t.Fatalf("expected 400, got %d", resp.Code) } @@ -4725,9 +4740,7 @@ func TestAgent_ServiceMaintenance_BadRequest(t *testing.T) { t.Run("no service id", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/service/maintenance/?enable=true", nil) resp := httptest.NewRecorder() - if _, err := a.srv.AgentServiceMaintenance(resp, req); err != nil { - t.Fatalf("err: %s", err) - } + a.srv.h.ServeHTTP(resp, req) if resp.Code != 400 { t.Fatalf("expected 400, got %d", resp.Code) } @@ -4753,20 +4766,22 @@ func TestAgent_ServiceMaintenance_Enable(t *testing.T) { testrpc.WaitForTestAgent(t, a.RPC, "dc1") // Register the service - service := &structs.NodeService{ - ID: "test", - Service: "test", - } - if err := a.addServiceFromSource(service, nil, false, "", ConfigSourceLocal); err != nil { - t.Fatalf("err: %v", err) + serviceReq := AddServiceRequest{ + Service: &structs.NodeService{ + ID: "test", + Service: "test", + }, + chkTypes: nil, + persist: false, + token: "", + Source: ConfigSourceLocal, } + require.NoError(t, a.AddService(serviceReq)) // Force the service into maintenance mode req, _ := http.NewRequest("PUT", "/v1/agent/service/maintenance/test?enable=true&reason=broken&token=mytoken", nil) resp := httptest.NewRecorder() - if _, err := a.srv.AgentServiceMaintenance(resp, req); err != nil { - t.Fatalf("err: %s", err) - } + a.srv.h.ServeHTTP(resp, req) if resp.Code != 200 { t.Fatalf("expected 200, got %d", resp.Code) } @@ -4800,13 +4815,17 @@ func TestAgent_ServiceMaintenance_Disable(t *testing.T) { testrpc.WaitForTestAgent(t, a.RPC, "dc1") // Register the service - service := &structs.NodeService{ - ID: "test", - Service: "test", - } - if err := a.addServiceFromSource(service, nil, false, "", ConfigSourceLocal); err != nil { - t.Fatalf("err: %v", err) + serviceReq := AddServiceRequest{ + Service: &structs.NodeService{ + ID: "test", + Service: "test", + }, + chkTypes: nil, + persist: false, + token: "", + Source: ConfigSourceLocal, } + require.NoError(t, a.AddService(serviceReq)) // Force the service into maintenance mode if err := a.EnableServiceMaintenance(structs.NewServiceID("test", nil), "", ""); err != nil { @@ -4816,9 +4835,7 @@ func TestAgent_ServiceMaintenance_Disable(t *testing.T) { // Leave maintenance mode req, _ := http.NewRequest("PUT", "/v1/agent/service/maintenance/test?enable=false", nil) resp := httptest.NewRecorder() - if _, err := a.srv.AgentServiceMaintenance(resp, req); err != nil { - t.Fatalf("err: %s", err) - } + a.srv.h.ServeHTTP(resp, req) if resp.Code != 200 { t.Fatalf("expected 200, got %d", resp.Code) } @@ -4841,26 +4858,30 @@ func TestAgent_ServiceMaintenance_ACLDeny(t *testing.T) { testrpc.WaitForLeader(t, a.RPC, "dc1") // Register the service. - service := &structs.NodeService{ - ID: "test", - Service: "test", - } - if err := a.addServiceFromSource(service, nil, false, "", ConfigSourceLocal); err != nil { - t.Fatalf("err: %v", err) + serviceReq := AddServiceRequest{ + Service: &structs.NodeService{ + ID: "test", + Service: "test", + }, + chkTypes: nil, + persist: false, + token: "", + Source: ConfigSourceLocal, } + require.NoError(t, a.AddService(serviceReq)) t.Run("no token", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/service/maintenance/test?enable=true&reason=broken", nil) - if _, err := a.srv.AgentServiceMaintenance(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) t.Run("root token", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/service/maintenance/test?enable=true&reason=broken&token=root", nil) - if _, err := a.srv.AgentServiceMaintenance(nil, req); err != nil { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) }) } @@ -4877,9 +4898,7 @@ func TestAgent_NodeMaintenance_BadRequest(t *testing.T) { // Fails when no enable flag provided req, _ := http.NewRequest("PUT", "/v1/agent/maintenance", nil) resp := httptest.NewRecorder() - if _, err := a.srv.AgentNodeMaintenance(resp, req); err != nil { - t.Fatalf("err: %s", err) - } + a.srv.h.ServeHTTP(resp, req) if resp.Code != 400 { t.Fatalf("expected 400, got %d", resp.Code) } @@ -4898,9 +4917,7 @@ func TestAgent_NodeMaintenance_Enable(t *testing.T) { // Force the node into maintenance mode req, _ := http.NewRequest("PUT", "/v1/agent/maintenance?enable=true&reason=broken&token=mytoken", nil) resp := httptest.NewRecorder() - if _, err := a.srv.AgentNodeMaintenance(resp, req); err != nil { - t.Fatalf("err: %s", err) - } + a.srv.h.ServeHTTP(resp, req) if resp.Code != 200 { t.Fatalf("expected 200, got %d", resp.Code) } @@ -4938,9 +4955,7 @@ func TestAgent_NodeMaintenance_Disable(t *testing.T) { // Leave maintenance mode req, _ := http.NewRequest("PUT", "/v1/agent/maintenance?enable=false", nil) resp := httptest.NewRecorder() - if _, err := a.srv.AgentNodeMaintenance(resp, req); err != nil { - t.Fatalf("err: %s", err) - } + a.srv.h.ServeHTTP(resp, req) if resp.Code != 200 { t.Fatalf("expected 200, got %d", resp.Code) } @@ -4963,16 +4978,16 @@ func TestAgent_NodeMaintenance_ACLDeny(t *testing.T) { t.Run("no token", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/maintenance?enable=true&reason=broken", nil) - if _, err := a.srv.AgentNodeMaintenance(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusForbidden, resp.Code) }) t.Run("root token", func(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/maintenance?enable=true&reason=broken&token=root", nil) - if _, err := a.srv.AgentNodeMaintenance(nil, req); err != nil { - t.Fatalf("err: %v", err) - } + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) }) } @@ -4996,8 +5011,10 @@ func TestAgent_RegisterCheck_Service(t *testing.T) { // First register the service req, _ := http.NewRequest("PUT", "/v1/agent/service/register", jsonReader(args)) - if _, err := a.srv.AgentRegisterService(nil, req); err != nil { - t.Fatalf("err: %v", err) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + if resp.Code != 200 { + t.Fatalf("expected 200, got %d", resp.Code) } // Now register an additional check @@ -5007,8 +5024,10 @@ func TestAgent_RegisterCheck_Service(t *testing.T) { TTL: 15 * time.Second, } req, _ = http.NewRequest("PUT", "/v1/agent/check/register", jsonReader(checkArgs)) - if _, err := a.srv.AgentRegisterCheck(nil, req); err != nil { - t.Fatalf("err: %v", err) + resp = httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + if resp.Code != 200 { + t.Fatalf("expected 200, got %d", resp.Code) } // Ensure we have a check mapping @@ -5045,20 +5064,14 @@ func TestAgent_Monitor(t *testing.T) { // Try passing an invalid log level req, _ := http.NewRequest("GET", "/v1/agent/monitor?loglevel=invalid", nil) resp := httptest.NewRecorder() - _, err := a.srv.AgentMonitor(resp, req) - if err == nil { - t.Fatal("expected BadRequestError to have occurred, got nil") - } - - // Note that BadRequestError is handled outside the endpoint handler so we - // still see a 200 if we check here. - if _, ok := err.(BadRequestError); !ok { - t.Fatalf("expected BadRequestError to have occurred, got %#v", err) + a.srv.h.ServeHTTP(resp, req) + if http.StatusBadRequest != resp.Code { + t.Fatalf("expected 400 but got %v", resp.Code) } substring := "Unknown log level" - if !strings.Contains(err.Error(), substring) { - t.Fatalf("got: %s, wanted message containing: %s", err.Error(), substring) + if !strings.Contains(resp.Body.String(), substring) { + t.Fatalf("got: %s, wanted message containing: %s", resp.Body.String(), substring) } }) @@ -5070,10 +5083,10 @@ func TestAgent_Monitor(t *testing.T) { req = req.WithContext(cancelCtx) resp := httptest.NewRecorder() - errCh := make(chan error) + codeCh := make(chan int) go func() { - _, err := a.srv.AgentMonitor(resp, req) - errCh <- err + a.srv.h.ServeHTTP(resp, req) + codeCh <- resp.Code }() args := &structs.ServiceDefinition{ @@ -5085,8 +5098,10 @@ func TestAgent_Monitor(t *testing.T) { } registerReq, _ := http.NewRequest("PUT", "/v1/agent/service/register", jsonReader(args)) - if _, err := a.srv.AgentRegisterService(nil, registerReq); err != nil { - t.Fatalf("err: %v", err) + res := httptest.NewRecorder() + a.srv.h.ServeHTTP(res, registerReq) + if http.StatusOK != res.Code { + t.Fatalf("expected 200 but got %v", res.Code) } // Wait until we have received some type of logging output @@ -5095,9 +5110,8 @@ func TestAgent_Monitor(t *testing.T) { }, 3*time.Second, 100*time.Millisecond) cancelFunc() - err := <-errCh - require.NoError(t, err) - + code := <-codeCh + require.Equal(t, http.StatusOK, code) got := resp.Body.String() // Only check a substring that we are highly confident in finding @@ -5134,8 +5148,10 @@ func TestAgent_Monitor(t *testing.T) { } registerReq, _ := http.NewRequest("PUT", "/v1/agent/service/register", jsonReader(args)) - if _, err := a.srv.AgentRegisterService(nil, registerReq); err != nil { - t.Fatalf("err: %v", err) + res := httptest.NewRecorder() + a.srv.h.ServeHTTP(res, registerReq) + if http.StatusOK != res.Code { + t.Fatalf("expected 200 but got %v", res.Code) } // Wait until we have received some type of logging output @@ -5154,10 +5170,10 @@ func TestAgent_Monitor(t *testing.T) { req = req.WithContext(cancelCtx) resp := httptest.NewRecorder() - errCh := make(chan error) + codeCh := make(chan int) go func() { - _, err := a.srv.AgentMonitor(resp, req) - errCh <- err + a.srv.h.ServeHTTP(resp, req) + codeCh <- resp.Code }() args := &structs.ServiceDefinition{ @@ -5169,8 +5185,10 @@ func TestAgent_Monitor(t *testing.T) { } registerReq, _ := http.NewRequest("PUT", "/v1/agent/service/register", jsonReader(args)) - if _, err := a.srv.AgentRegisterService(nil, registerReq); err != nil { - t.Fatalf("err: %v", err) + res := httptest.NewRecorder() + a.srv.h.ServeHTTP(res, registerReq) + if http.StatusOK != res.Code { + t.Fatalf("expected 200 but got %v", res.Code) } // Wait until we have received some type of logging output @@ -5179,8 +5197,8 @@ func TestAgent_Monitor(t *testing.T) { }, 3*time.Second, 100*time.Millisecond) cancelFunc() - err := <-errCh - require.NoError(t, err) + code := <-codeCh + require.Equal(t, http.StatusOK, code) // Each line is output as a separate JSON object, we grab the first and // make sure it can be unmarshalled. @@ -5199,12 +5217,12 @@ func TestAgent_Monitor(t *testing.T) { req = req.WithContext(cancelCtx) resp := httptest.NewRecorder() - chErr := make(chan error) + codeCh := make(chan int) chStarted := make(chan struct{}) go func() { close(chStarted) - _, err := a.srv.AgentMonitor(resp, req) - chErr <- err + a.srv.h.ServeHTTP(resp, req) + codeCh <- resp.Code }() <-chStarted @@ -5216,8 +5234,8 @@ func TestAgent_Monitor(t *testing.T) { }, 3*time.Second, 100*time.Millisecond) cancelFunc() - err := <-chErr - require.NoError(t, err) + code := <-codeCh + require.Equal(t, http.StatusOK, code) got := resp.Body.String() want := "serf: Shutdown without a Leave" @@ -5239,8 +5257,10 @@ func TestAgent_Monitor_ACLDeny(t *testing.T) { // Try without a token. req, _ := http.NewRequest("GET", "/v1/agent/monitor", nil) - if _, err := a.srv.AgentMonitor(nil, req); !acl.IsErrPermissionDenied(err) { - t.Fatalf("err: %v", err) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + if http.StatusForbidden != resp.Code { + t.Fatalf("expected 403 but got %v", resp.Code) } // This proves we call the ACL function, and we've got the other monitor @@ -5270,11 +5290,12 @@ func TestAgent_TokenTriggersFullSync(t *testing.T) { require.NoError(t, err) resp := httptest.NewRecorder() - obj, err := a.srv.ACLPolicyCreate(resp, req) - require.NoError(t, err) + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) - policy, ok := obj.(*structs.ACLPolicy) - require.True(t, ok) + dec := json.NewDecoder(resp.Body) + policy = &structs.ACLPolicy{} + require.NoError(t, dec.Decode(policy)) return policy } @@ -5292,11 +5313,12 @@ func TestAgent_TokenTriggersFullSync(t *testing.T) { require.NoError(t, err) resp := httptest.NewRecorder() - obj, err := a.srv.ACLTokenCreate(resp, req) - require.NoError(t, err) + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) - token, ok := obj.(*structs.ACLToken) - require.True(t, ok) + dec := json.NewDecoder(resp.Body) + token = &structs.ACLToken{} + require.NoError(t, dec.Decode(token)) return token } @@ -5674,9 +5696,9 @@ func TestAgentConnectCARoots_empty(t *testing.T) { req, _ := http.NewRequest("GET", "/v1/agent/connect/ca/roots", nil) resp := httptest.NewRecorder() - _, err := a.srv.AgentConnectCARoots(resp, req) - require.Error(err) - require.Contains(err.Error(), "Connect must be enabled") + a.srv.h.ServeHTTP(resp, req) + require.Equal(http.StatusInternalServerError, resp.Code) + require.Contains(resp.Body.String(), "Connect must be enabled") } func TestAgentConnectCARoots_list(t *testing.T) { @@ -5699,10 +5721,12 @@ func TestAgentConnectCARoots_list(t *testing.T) { // List req, _ := http.NewRequest("GET", "/v1/agent/connect/ca/roots", nil) resp := httptest.NewRecorder() - obj, err := a.srv.AgentConnectCARoots(resp, req) - require.NoError(err) + a.srv.h.ServeHTTP(resp, req) + + dec := json.NewDecoder(resp.Body) + value := &structs.IndexedCARoots{} + require.NoError(dec.Decode(value)) - value := obj.(structs.IndexedCARoots) assert.Equal(value.ActiveRootID, ca2.ID) // Would like to assert that it's the same as the TestAgent domain but the // only way to access that state via this package is by RPC to the server @@ -5722,9 +5746,12 @@ func TestAgentConnectCARoots_list(t *testing.T) { { // List it again resp2 := httptest.NewRecorder() - obj2, err := a.srv.AgentConnectCARoots(resp2, req) - require.NoError(err) - assert.Equal(obj, obj2) + a.srv.h.ServeHTTP(resp2, req) + + dec := json.NewDecoder(resp2.Body) + value2 := &structs.IndexedCARoots{} + require.NoError(dec.Decode(value2)) + assert.Equal(value, value2) // Should cache hit this time and not make request assert.Equal("HIT", resp2.Header().Get("X-Cache")) @@ -5738,10 +5765,11 @@ func TestAgentConnectCARoots_list(t *testing.T) { retry.Run(t, func(r *retry.R) { // List it again resp := httptest.NewRecorder() - obj, err := a.srv.AgentConnectCARoots(resp, req) - r.Check(err) + a.srv.h.ServeHTTP(resp, req) - value := obj.(structs.IndexedCARoots) + dec := json.NewDecoder(resp.Body) + value := &structs.IndexedCARoots{} + require.NoError(dec.Decode(value)) if ca.ID != value.ActiveRootID { r.Fatalf("%s != %s", ca.ID, value.ActiveRootID) } @@ -5789,16 +5817,14 @@ func TestAgentConnectCALeafCert_aclDefaultDeny(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/service/register?token=root", jsonReader(reg)) resp := httptest.NewRecorder() - _, err := a.srv.AgentRegisterService(resp, req) - require.NoError(err) + a.srv.h.ServeHTTP(resp, req) require.Equal(200, resp.Code, "body: %s", resp.Body.String()) } req, _ := http.NewRequest("GET", "/v1/agent/connect/ca/leaf/test", nil) resp := httptest.NewRecorder() - _, err := a.srv.AgentConnectCALeafCert(resp, req) - require.Error(err) - require.True(acl.IsErrPermissionDenied(err)) + a.srv.h.ServeHTTP(resp, req) + require.Equal(http.StatusForbidden, resp.Code) } func TestAgentConnectCALeafCert_aclServiceWrite(t *testing.T) { @@ -5829,8 +5855,7 @@ func TestAgentConnectCALeafCert_aclServiceWrite(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/service/register?token=root", jsonReader(reg)) resp := httptest.NewRecorder() - _, err := a.srv.AgentRegisterService(resp, req) - require.NoError(err) + a.srv.h.ServeHTTP(resp, req) require.Equal(200, resp.Code, "body: %s", resp.Body.String()) } @@ -5838,12 +5863,13 @@ func TestAgentConnectCALeafCert_aclServiceWrite(t *testing.T) { req, _ := http.NewRequest("GET", "/v1/agent/connect/ca/leaf/test?token="+token, nil) resp := httptest.NewRecorder() - obj, err := a.srv.AgentConnectCALeafCert(resp, req) - require.NoError(err) + a.srv.h.ServeHTTP(resp, req) // Get the issued cert - _, ok := obj.(*structs.IssuedCert) - require.True(ok) + dec := json.NewDecoder(resp.Body) + value := &structs.IssuedCert{} + require.NoError(dec.Decode(value)) + require.NotNil(value) } func createACLTokenWithServicePolicy(t *testing.T, srv *HTTPHandlers, policy string) string { @@ -5864,10 +5890,11 @@ func createACLTokenWithServicePolicy(t *testing.T, srv *HTTPHandlers, policy str req, _ = http.NewRequest("PUT", "/v1/acl/token?token=root", jsonReader(tokenReq)) resp = httptest.NewRecorder() - tokInf, err := srv.ACLTokenCreate(resp, req) - require.NoError(t, err) - svcToken, ok := tokInf.(*structs.ACLToken) - require.True(t, ok) + srv.h.ServeHTTP(resp, req) + + dec := json.NewDecoder(resp.Body) + svcToken := &structs.ACLToken{} + require.NoError(t, dec.Decode(svcToken)) return svcToken.SecretID } @@ -5899,8 +5926,7 @@ func TestAgentConnectCALeafCert_aclServiceReadDeny(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/agent/service/register?token=root", jsonReader(reg)) resp := httptest.NewRecorder() - _, err := a.srv.AgentRegisterService(resp, req) - require.NoError(err) + a.srv.h.ServeHTTP(resp, req) require.Equal(200, resp.Code, "body: %s", resp.Body.String()) } @@ -5908,9 +5934,8 @@ func TestAgentConnectCALeafCert_aclServiceReadDeny(t *testing.T) { req, _ := http.NewRequest("GET", "/v1/agent/connect/ca/leaf/test?token="+token, nil) resp := httptest.NewRecorder() - _, err := a.srv.AgentConnectCALeafCert(resp, req) - require.Error(err) - require.True(acl.IsErrPermissionDenied(err)) + a.srv.h.ServeHTTP(resp, req) + require.Equal(http.StatusForbidden, resp.Code) } func TestAgentConnectCALeafCert_good(t *testing.T) { @@ -5948,8 +5973,7 @@ func TestAgentConnectCALeafCert_good(t *testing.T) { } req, _ := http.NewRequest("PUT", "/v1/agent/service/register", jsonReader(args)) resp := httptest.NewRecorder() - _, err := a.srv.AgentRegisterService(resp, req) - require.NoError(err) + a.srv.h.ServeHTTP(resp, req) if !assert.Equal(200, resp.Code) { t.Log("Body: ", resp.Body.String()) } @@ -5958,13 +5982,13 @@ func TestAgentConnectCALeafCert_good(t *testing.T) { // List req, _ := http.NewRequest("GET", "/v1/agent/connect/ca/leaf/test", nil) resp := httptest.NewRecorder() - obj, err := a.srv.AgentConnectCALeafCert(resp, req) - require.NoError(err) + a.srv.h.ServeHTTP(resp, req) require.Equal("MISS", resp.Header().Get("X-Cache")) // Get the issued cert - issued, ok := obj.(*structs.IssuedCert) - assert.True(ok) + dec := json.NewDecoder(resp.Body) + issued := &structs.IssuedCert{} + require.NoError(dec.Decode(issued)) // Verify that the cert is signed by the CA requireLeafValidUnderCA(t, issued, ca1) @@ -5980,9 +6004,11 @@ func TestAgentConnectCALeafCert_good(t *testing.T) { { // Fetch it again resp := httptest.NewRecorder() - obj2, err := a.srv.AgentConnectCALeafCert(resp, req) - require.NoError(err) - require.Equal(obj, obj2) + a.srv.h.ServeHTTP(resp, req) + dec := json.NewDecoder(resp.Body) + issued2 := &structs.IssuedCert{} + require.NoError(dec.Decode(issued2)) + require.Equal(issued, issued2) } // Set a new CA @@ -5992,9 +6018,10 @@ func TestAgentConnectCALeafCert_good(t *testing.T) { { resp := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/v1/agent/connect/ca/leaf/test?index="+index, nil) - obj, err := a.srv.AgentConnectCALeafCert(resp, req) - require.NoError(err) - issued2 := obj.(*structs.IssuedCert) + a.srv.h.ServeHTTP(resp, req) + dec := json.NewDecoder(resp.Body) + issued2 := &structs.IssuedCert{} + require.NoError(dec.Decode(issued2)) require.NotEqual(issued.CertPEM, issued2.CertPEM) require.NotEqual(issued.PrivateKeyPEM, issued2.PrivateKeyPEM) @@ -6091,8 +6118,7 @@ func TestAgentConnectCALeafCert_goodNotLocal(t *testing.T) { } req, _ := http.NewRequest("PUT", "/v1/catalog/register", jsonReader(args)) resp := httptest.NewRecorder() - _, err := a.srv.CatalogRegister(resp, req) - require.NoError(err) + a.srv.h.ServeHTTP(resp, req) if !assert.Equal(200, resp.Code) { t.Log("Body: ", resp.Body.String()) } @@ -6101,13 +6127,13 @@ func TestAgentConnectCALeafCert_goodNotLocal(t *testing.T) { // List req, _ := http.NewRequest("GET", "/v1/agent/connect/ca/leaf/test", nil) resp := httptest.NewRecorder() - obj, err := a.srv.AgentConnectCALeafCert(resp, req) - require.NoError(err) + a.srv.h.ServeHTTP(resp, req) require.Equal("MISS", resp.Header().Get("X-Cache")) // Get the issued cert - issued, ok := obj.(*structs.IssuedCert) - assert.True(ok) + dec := json.NewDecoder(resp.Body) + issued := &structs.IssuedCert{} + require.NoError(dec.Decode(issued)) // Verify that the cert is signed by the CA requireLeafValidUnderCA(t, issued, ca1) @@ -6121,9 +6147,11 @@ func TestAgentConnectCALeafCert_goodNotLocal(t *testing.T) { { // Fetch it again resp := httptest.NewRecorder() - obj2, err := a.srv.AgentConnectCALeafCert(resp, req) - require.NoError(err) - require.Equal(obj, obj2) + a.srv.h.ServeHTTP(resp, req) + dec := json.NewDecoder(resp.Body) + issued2 := &structs.IssuedCert{} + require.NoError(dec.Decode(issued2)) + require.Equal(issued, issued2) } // Test Blocking - see https://github.com/hashicorp/consul/issues/4462 @@ -6133,7 +6161,7 @@ func TestAgentConnectCALeafCert_goodNotLocal(t *testing.T) { blockingReq, _ := http.NewRequest("GET", fmt.Sprintf("/v1/agent/connect/ca/leaf/test?wait=125ms&index=%d", issued.ModifyIndex), nil) doneCh := make(chan struct{}) go func() { - a.srv.AgentConnectCALeafCert(resp, blockingReq) + a.srv.h.ServeHTTP(resp, blockingReq) close(doneCh) }() @@ -6154,10 +6182,11 @@ func TestAgentConnectCALeafCert_goodNotLocal(t *testing.T) { resp := httptest.NewRecorder() // Try and sign again (note no index/wait arg since cache should update in // background even if we aren't actively blocking) - obj, err := a.srv.AgentConnectCALeafCert(resp, req) - r.Check(err) + a.srv.h.ServeHTTP(resp, req) - issued2 := obj.(*structs.IssuedCert) + dec := json.NewDecoder(resp.Body) + issued2 := &structs.IssuedCert{} + require.NoError(dec.Decode(issued2)) if issued.CertPEM == issued2.CertPEM { r.Fatalf("leaf has not updated") } @@ -6233,8 +6262,7 @@ func TestAgentConnectCALeafCert_Vault_doesNotChurnLeafCertsAtIdle(t *testing.T) } req, _ := http.NewRequest("PUT", "/v1/agent/service/register", jsonReader(args)) resp := httptest.NewRecorder() - _, err := a.srv.AgentRegisterService(resp, req) - require.NoError(err) + a.srv.h.ServeHTTP(resp, req) if !assert.Equal(200, resp.Code) { t.Log("Body: ", resp.Body.String()) } @@ -6243,13 +6271,13 @@ func TestAgentConnectCALeafCert_Vault_doesNotChurnLeafCertsAtIdle(t *testing.T) // List req, _ := http.NewRequest("GET", "/v1/agent/connect/ca/leaf/test", nil) resp := httptest.NewRecorder() - obj, err := a.srv.AgentConnectCALeafCert(resp, req) - require.NoError(err) + a.srv.h.ServeHTTP(resp, req) require.Equal("MISS", resp.Header().Get("X-Cache")) // Get the issued cert - issued, ok := obj.(*structs.IssuedCert) - assert.True(ok) + dec := json.NewDecoder(resp.Body) + issued := &structs.IssuedCert{} + require.NoError(dec.Decode(issued)) // Verify that the cert is signed by the CA requireLeafValidUnderCA(t, issued, ca1) @@ -6263,9 +6291,11 @@ func TestAgentConnectCALeafCert_Vault_doesNotChurnLeafCertsAtIdle(t *testing.T) { // Fetch it again resp := httptest.NewRecorder() - obj2, err := a.srv.AgentConnectCALeafCert(resp, req) - require.NoError(err) - require.Equal(obj, obj2) + a.srv.h.ServeHTTP(resp, req) + dec := json.NewDecoder(resp.Body) + issued2 := &structs.IssuedCert{} + require.NoError(dec.Decode(issued2)) + require.Equal(issued, issued2) } // Test that we aren't churning leaves for no reason at idle. @@ -6274,11 +6304,17 @@ func TestAgentConnectCALeafCert_Vault_doesNotChurnLeafCertsAtIdle(t *testing.T) go func() { req, _ := http.NewRequest("GET", "/v1/agent/connect/ca/leaf/test?index="+strconv.Itoa(int(issued.ModifyIndex)), nil) resp := httptest.NewRecorder() - obj, err := a.srv.AgentConnectCALeafCert(resp, req) - if err != nil { + a.srv.h.ServeHTTP(resp, req) + if resp.Code != http.StatusOK { + ch <- fmt.Errorf(resp.Body.String()) + return + } + + dec := json.NewDecoder(resp.Body) + issued2 := &structs.IssuedCert{} + if err := dec.Decode(issued2); err != nil { ch <- err } else { - issued2 := obj.(*structs.IssuedCert) if issued.CertPEM == issued2.CertPEM { ch <- fmt.Errorf("leaf woke up unexpectedly with same cert") } else { @@ -6288,7 +6324,6 @@ func TestAgentConnectCALeafCert_Vault_doesNotChurnLeafCertsAtIdle(t *testing.T) }() start := time.Now() - select { case <-time.After(5 * time.Second): case err := <-ch: @@ -6364,8 +6399,7 @@ func TestAgentConnectCALeafCert_secondaryDC_good(t *testing.T) { } req, _ := http.NewRequest("PUT", "/v1/agent/service/register", jsonReader(args)) resp := httptest.NewRecorder() - _, err := a2.srv.AgentRegisterService(resp, req) - require.NoError(err) + a2.srv.h.ServeHTTP(resp, req) if !assert.Equal(200, resp.Code) { t.Log("Body: ", resp.Body.String()) } @@ -6375,13 +6409,14 @@ func TestAgentConnectCALeafCert_secondaryDC_good(t *testing.T) { req, err := http.NewRequest("GET", "/v1/agent/connect/ca/leaf/test", nil) require.NoError(err) resp := httptest.NewRecorder() - obj, err := a2.srv.AgentConnectCALeafCert(resp, req) - require.NoError(err) + a2.srv.h.ServeHTTP(resp, req) + require.Equal(http.StatusOK, resp.Code) require.Equal("MISS", resp.Header().Get("X-Cache")) // Get the issued cert - issued, ok := obj.(*structs.IssuedCert) - assert.True(ok) + dec := json.NewDecoder(resp.Body) + issued := &structs.IssuedCert{} + require.NoError(dec.Decode(issued)) // Verify that the cert is signed by the CA requireLeafValidUnderCA(t, issued, dc1_ca1) @@ -6395,9 +6430,11 @@ func TestAgentConnectCALeafCert_secondaryDC_good(t *testing.T) { { // Fetch it again resp := httptest.NewRecorder() - obj2, err := a2.srv.AgentConnectCALeafCert(resp, req) - require.NoError(err) - require.Equal(obj, obj2) + a2.srv.h.ServeHTTP(resp, req) + dec := json.NewDecoder(resp.Body) + issued2 := &structs.IssuedCert{} + require.NoError(dec.Decode(issued2)) + require.Equal(issued, issued2) } // Test that we aren't churning leaves for no reason at idle. @@ -6406,11 +6443,17 @@ func TestAgentConnectCALeafCert_secondaryDC_good(t *testing.T) { go func() { req, _ := http.NewRequest("GET", "/v1/agent/connect/ca/leaf/test?index="+strconv.Itoa(int(issued.ModifyIndex)), nil) resp := httptest.NewRecorder() - obj, err := a2.srv.AgentConnectCALeafCert(resp, req) - if err != nil { + a2.srv.h.ServeHTTP(resp, req) + if resp.Code != http.StatusOK { + ch <- fmt.Errorf(resp.Body.String()) + return + } + + dec := json.NewDecoder(resp.Body) + issued2 := &structs.IssuedCert{} + if err := dec.Decode(issued2); err != nil { ch <- err } else { - issued2 := obj.(*structs.IssuedCert) if issued.CertPEM == issued2.CertPEM { ch <- fmt.Errorf("leaf woke up unexpectedly with same cert") } else { @@ -6445,10 +6488,12 @@ func TestAgentConnectCALeafCert_secondaryDC_good(t *testing.T) { resp := httptest.NewRecorder() // Try and sign again (note no index/wait arg since cache should update in // background even if we aren't actively blocking) - obj, err := a2.srv.AgentConnectCALeafCert(resp, req) - r.Check(err) + a2.srv.h.ServeHTTP(resp, req) + require.Equal(http.StatusOK, resp.Code) - issued2 := obj.(*structs.IssuedCert) + dec := json.NewDecoder(resp.Body) + issued2 := &structs.IssuedCert{} + require.NoError(dec.Decode(issued2)) if issued.CertPEM == issued2.CertPEM { r.Fatalf("leaf has not updated") } @@ -6470,15 +6515,14 @@ func waitForActiveCARoot(t *testing.T, srv *HTTPHandlers, expect *structs.CARoot retry.Run(t, func(r *retry.R) { req, _ := http.NewRequest("GET", "/v1/agent/connect/ca/roots", nil) resp := httptest.NewRecorder() - obj, err := srv.AgentConnectCARoots(resp, req) - if err != nil { - r.Fatalf("err: %v", err) + srv.h.ServeHTTP(resp, req) + if http.StatusOK != resp.Code { + t.Fatalf("expected 200 but got %v", resp.Code) } - roots, ok := obj.(structs.IndexedCARoots) - if !ok { - r.Fatalf("response is wrong type %T", obj) - } + dec := json.NewDecoder(resp.Body) + roots := &structs.IndexedCARoots{} + require.NoError(t, dec.Decode(roots)) var root *structs.CARoot for _, r := range roots.Roots { @@ -6530,12 +6574,9 @@ func TestAgentConnectAuthorize_badBody(t *testing.T) { args := []string{} req, _ := http.NewRequest("POST", "/v1/agent/connect/authorize", jsonReader(args)) resp := httptest.NewRecorder() - respRaw, err := a.srv.AgentConnectAuthorize(resp, req) - require.Error(err) - assert.Nil(respRaw) - // Note that BadRequestError is handled outside the endpoint handler so we - // still see a 200 if we check here. - assert.Contains(err.Error(), "decode failed") + a.srv.h.ServeHTTP(resp, req) + require.Equal(http.StatusBadRequest, resp.Code) + assert.Contains(resp.Body.String(), "decode failed") } func TestAgentConnectAuthorize_noTarget(t *testing.T) { @@ -6554,12 +6595,9 @@ func TestAgentConnectAuthorize_noTarget(t *testing.T) { args := &structs.ConnectAuthorizeRequest{} req, _ := http.NewRequest("POST", "/v1/agent/connect/authorize", jsonReader(args)) resp := httptest.NewRecorder() - respRaw, err := a.srv.AgentConnectAuthorize(resp, req) - require.Error(err) - assert.Nil(respRaw) - // Note that BadRequestError is handled outside the endpoint handler so we - // still see a 200 if we check here. - assert.Contains(err.Error(), "Target service must be specified") + a.srv.h.ServeHTTP(resp, req) + require.Equal(http.StatusBadRequest, resp.Code) + assert.Contains(resp.Body.String(), "Target service must be specified") } // Client ID is not in the valid URI format @@ -6582,12 +6620,9 @@ func TestAgentConnectAuthorize_idInvalidFormat(t *testing.T) { } req, _ := http.NewRequest("POST", "/v1/agent/connect/authorize", jsonReader(args)) resp := httptest.NewRecorder() - respRaw, err := a.srv.AgentConnectAuthorize(resp, req) - require.Error(err) - assert.Nil(respRaw) - // Note that BadRequestError is handled outside the endpoint handler so we - // still see a 200 if we check here. - assert.Contains(err.Error(), "ClientCertURI not a valid Connect identifier") + a.srv.h.ServeHTTP(resp, req) + require.Equal(http.StatusBadRequest, resp.Code) + assert.Contains(resp.Body.String(), "ClientCertURI not a valid Connect identifier") } // Client ID is a valid URI but its not a service URI @@ -6610,12 +6645,9 @@ func TestAgentConnectAuthorize_idNotService(t *testing.T) { } req, _ := http.NewRequest("POST", "/v1/agent/connect/authorize", jsonReader(args)) resp := httptest.NewRecorder() - respRaw, err := a.srv.AgentConnectAuthorize(resp, req) - require.Error(err) - assert.Nil(respRaw) - // Note that BadRequestError is handled outside the endpoint handler so we - // still see a 200 if we check here. - assert.Contains(err.Error(), "ClientCertURI not a valid Service identifier") + a.srv.h.ServeHTTP(resp, req) + require.Equal(http.StatusBadRequest, resp.Code) + assert.Contains(resp.Body.String(), "ClientCertURI not a valid Service identifier") } // Test when there is an intention allowing the connection @@ -6656,12 +6688,13 @@ func TestAgentConnectAuthorize_allow(t *testing.T) { } req, _ := http.NewRequest("POST", "/v1/agent/connect/authorize", jsonReader(args)) resp := httptest.NewRecorder() - respRaw, err := a.srv.AgentConnectAuthorize(resp, req) - require.Nil(err) + a.srv.h.ServeHTTP(resp, req) require.Equal(200, resp.Code) require.Equal("MISS", resp.Header().Get("X-Cache")) - obj := respRaw.(*connectAuthorizeResp) + dec := json.NewDecoder(resp.Body) + obj := &connectAuthorizeResp{} + require.NoError(dec.Decode(obj)) require.True(obj.Authorized) require.Contains(obj.Reason, "Matched") @@ -6669,11 +6702,12 @@ func TestAgentConnectAuthorize_allow(t *testing.T) { { req, _ := http.NewRequest("POST", "/v1/agent/connect/authorize", jsonReader(args)) resp := httptest.NewRecorder() - respRaw, err := a.srv.AgentConnectAuthorize(resp, req) - require.Nil(err) + a.srv.h.ServeHTTP(resp, req) require.Equal(200, resp.Code) - obj := respRaw.(*connectAuthorizeResp) + dec := json.NewDecoder(resp.Body) + obj := &connectAuthorizeResp{} + require.NoError(dec.Decode(obj)) require.True(obj.Authorized) require.Contains(obj.Reason, "Matched") @@ -6705,11 +6739,12 @@ func TestAgentConnectAuthorize_allow(t *testing.T) { { req, _ := http.NewRequest("POST", "/v1/agent/connect/authorize", jsonReader(args)) resp := httptest.NewRecorder() - respRaw, err := a.srv.AgentConnectAuthorize(resp, req) - require.Nil(err) + a.srv.h.ServeHTTP(resp, req) require.Equal(200, resp.Code) - obj := respRaw.(*connectAuthorizeResp) + dec := json.NewDecoder(resp.Body) + obj := &connectAuthorizeResp{} + require.NoError(dec.Decode(obj)) require.False(obj.Authorized) require.Contains(obj.Reason, "Matched") @@ -6757,11 +6792,12 @@ func TestAgentConnectAuthorize_deny(t *testing.T) { } req, _ := http.NewRequest("POST", "/v1/agent/connect/authorize", jsonReader(args)) resp := httptest.NewRecorder() - respRaw, err := a.srv.AgentConnectAuthorize(resp, req) - assert.Nil(err) + a.srv.h.ServeHTTP(resp, req) assert.Equal(200, resp.Code) - obj := respRaw.(*connectAuthorizeResp) + dec := json.NewDecoder(resp.Body) + obj := &connectAuthorizeResp{} + require.NoError(t, dec.Decode(obj)) assert.False(obj.Authorized) assert.Contains(obj.Reason, "Matched") } @@ -6812,11 +6848,12 @@ func TestAgentConnectAuthorize_allowTrustDomain(t *testing.T) { } req, _ := http.NewRequest("POST", "/v1/agent/connect/authorize", jsonReader(args)) resp := httptest.NewRecorder() - respRaw, err := a.srv.AgentConnectAuthorize(resp, req) - require.NoError(err) + a.srv.h.ServeHTTP(resp, req) assert.Equal(200, resp.Code) - obj := respRaw.(*connectAuthorizeResp) + dec := json.NewDecoder(resp.Body) + obj := &connectAuthorizeResp{} + require.NoError(dec.Decode(obj)) require.True(obj.Authorized) require.Contains(obj.Reason, "Matched") } @@ -6879,11 +6916,12 @@ func TestAgentConnectAuthorize_denyWildcard(t *testing.T) { } req, _ := http.NewRequest("POST", "/v1/agent/connect/authorize", jsonReader(args)) resp := httptest.NewRecorder() - respRaw, err := a.srv.AgentConnectAuthorize(resp, req) - require.NoError(err) + a.srv.h.ServeHTTP(resp, req) assert.Equal(200, resp.Code) - obj := respRaw.(*connectAuthorizeResp) + dec := json.NewDecoder(resp.Body) + obj := &connectAuthorizeResp{} + require.NoError(dec.Decode(obj)) assert.True(obj.Authorized) assert.Contains(obj.Reason, "Matched") } @@ -6896,11 +6934,12 @@ func TestAgentConnectAuthorize_denyWildcard(t *testing.T) { } req, _ := http.NewRequest("POST", "/v1/agent/connect/authorize", jsonReader(args)) resp := httptest.NewRecorder() - respRaw, err := a.srv.AgentConnectAuthorize(resp, req) - require.NoError(err) + a.srv.h.ServeHTTP(resp, req) assert.Equal(200, resp.Code) - obj := respRaw.(*connectAuthorizeResp) + dec := json.NewDecoder(resp.Body) + obj := &connectAuthorizeResp{} + require.NoError(dec.Decode(obj)) assert.False(obj.Authorized) assert.Contains(obj.Reason, "Matched") } @@ -6928,8 +6967,9 @@ func TestAgentConnectAuthorize_serviceWrite(t *testing.T) { req, _ := http.NewRequest("POST", "/v1/agent/connect/authorize?token="+token, jsonReader(args)) resp := httptest.NewRecorder() - _, err := a.srv.AgentConnectAuthorize(resp, req) - assert.True(acl.IsErrPermissionDenied(err)) + a.srv.h.ServeHTTP(resp, req) + + assert.Equal(http.StatusForbidden, resp.Code) } // Test when no intentions match w/ a default deny policy @@ -6951,11 +6991,12 @@ func TestAgentConnectAuthorize_defaultDeny(t *testing.T) { } req, _ := http.NewRequest("POST", "/v1/agent/connect/authorize?token=root", jsonReader(args)) resp := httptest.NewRecorder() - respRaw, err := a.srv.AgentConnectAuthorize(resp, req) - assert.Nil(err) + a.srv.h.ServeHTTP(resp, req) assert.Equal(200, resp.Code) - obj := respRaw.(*connectAuthorizeResp) + dec := json.NewDecoder(resp.Body) + obj := &connectAuthorizeResp{} + require.NoError(t, dec.Decode(obj)) assert.False(obj.Authorized) assert.Contains(obj.Reason, "Default behavior") } @@ -6986,12 +7027,12 @@ func TestAgentConnectAuthorize_defaultAllow(t *testing.T) { } req, _ := http.NewRequest("POST", "/v1/agent/connect/authorize?token=root", jsonReader(args)) resp := httptest.NewRecorder() - respRaw, err := a.srv.AgentConnectAuthorize(resp, req) - assert.Nil(err) + a.srv.h.ServeHTTP(resp, req) assert.Equal(200, resp.Code) - assert.NotNil(respRaw) - obj := respRaw.(*connectAuthorizeResp) + dec := json.NewDecoder(resp.Body) + obj := &connectAuthorizeResp{} + require.NoError(t, dec.Decode(obj)) assert.True(obj.Authorized) assert.Contains(obj.Reason, "Default behavior") } @@ -7017,6 +7058,7 @@ func TestAgent_Host(t *testing.T) { testrpc.WaitForLeader(t, a.RPC, "dc1") req, _ := http.NewRequest("GET", "/v1/agent/host?token=master", nil) resp := httptest.NewRecorder() + // TODO: AgentHost should write to response so that we can test using ServeHTTP() respRaw, err := a.srv.AgentHost(resp, req) assert.Nil(err) assert.Equal(http.StatusOK, resp.Code) @@ -7048,10 +7090,10 @@ func TestAgent_HostBadACL(t *testing.T) { testrpc.WaitForLeader(t, a.RPC, "dc1") req, _ := http.NewRequest("GET", "/v1/agent/host?token=agent", nil) resp := httptest.NewRecorder() - respRaw, err := a.srv.AgentHost(resp, req) + // TODO: AgentHost should write to response so that we can test using ServeHTTP() + _, err := a.srv.AgentHost(resp, req) assert.EqualError(err, "ACL not found") assert.Equal(http.StatusOK, resp.Code) - assert.Nil(respRaw) } // Thie tests that a proxy with an ExposeConfig is returned as expected. @@ -7088,12 +7130,19 @@ func TestAgent_Services_ExposeConfig(t *testing.T) { a.State.AddService(srv1, "") req, _ := http.NewRequest("GET", "/v1/agent/services", nil) - obj, err := a.srv.AgentServices(httptest.NewRecorder(), req) - require.NoError(t, err) - val := obj.(map[string]*api.AgentService) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) + dec := json.NewDecoder(resp.Body) + val := make(map[string]*api.AgentService) + require.NoError(t, dec.Decode(&val)) require.Len(t, val, 1) actual := val["proxy-id"] require.NotNil(t, actual) require.Equal(t, api.ServiceKindConnectProxy, actual.Kind) + // Proxy.ToAPI() creates an empty Upstream list instead of keeping nil so do the same with actual. + if actual.Proxy.Upstreams == nil { + actual.Proxy.Upstreams = make([]api.Upstream, 0) + } require.Equal(t, srv1.Proxy.ToAPI(), actual.Proxy) }