From 8fb90aacef43eca3d533f6cb1b87a09aa12fd890 Mon Sep 17 00:00:00 2001 From: Mathew Estafanous <56979977+Mathew-Estafanous@users.noreply.github.com> Date: Wed, 3 Nov 2021 11:12:36 -0400 Subject: [PATCH] Convert (some) test endpoints to use ServeHTTP instead of direct calls to handlers. (#11445) --- agent/agent_endpoint_test.go | 75 +++++++++++++++++++++++------------- 1 file changed, 48 insertions(+), 27 deletions(-) diff --git a/agent/agent_endpoint_test.go b/agent/agent_endpoint_test.go index 8f826ade87..7e8f3d7ffa 100644 --- a/agent/agent_endpoint_test.go +++ b/agent/agent_endpoint_test.go @@ -92,12 +92,15 @@ func TestAgent_Services(t *testing.T) { require.NoError(t, a.State.AddService(srv1, "")) req, _ := http.NewRequest("GET", "/v1/agent/services", nil) - obj, err := a.srv.AgentServices(nil, req) - if err != nil { - t.Fatalf("Err: %v", err) - } - val := obj.(map[string]*api.AgentService) - assert.Lenf(t, val, 1, "bad services: %v", obj) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + + require.Equal(t, http.StatusOK, resp.Code) + decoder := json.NewDecoder(resp.Body) + var val map[string]*api.AgentService + err := decoder.Decode(&val) + require.NoError(t, err) + assert.Lenf(t, val, 1, "bad services: %v", val) assert.Equal(t, 5000, val["mysql"].Port) assert.Equal(t, srv1.Meta, val["mysql"].Meta) } @@ -136,15 +139,25 @@ func TestAgent_ServicesFiltered(t *testing.T) { require.NoError(t, a.State.AddService(srv2, "")) req, _ := http.NewRequest("GET", "/v1/agent/services?filter="+url.QueryEscape("foo in Meta"), nil) - obj, err := a.srv.AgentServices(nil, req) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + + require.Equal(t, http.StatusOK, resp.Code) + decoder := json.NewDecoder(resp.Body) + var val map[string]*api.AgentService + err := decoder.Decode(&val) require.NoError(t, err) - val := obj.(map[string]*api.AgentService) require.Len(t, val, 2) req, _ = http.NewRequest("GET", "/v1/agent/services?filter="+url.QueryEscape("kv in Tags"), nil) - obj, err = a.srv.AgentServices(nil, req) + resp = httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + + require.Equal(t, http.StatusOK, resp.Code) + decoder = json.NewDecoder(resp.Body) + val = make(map[string]*api.AgentService) + err = decoder.Decode(&val) require.NoError(t, err) - val = obj.(map[string]*api.AgentService) require.Len(t, val, 1) } @@ -175,9 +188,13 @@ func TestAgent_Services_ExternalConnectProxy(t *testing.T) { a.State.AddService(srv1, "") req, _ := http.NewRequest("GET", "/v1/agent/services", nil) - obj, err := a.srv.AgentServices(nil, req) - assert.Nil(err) - val := obj.(map[string]*api.AgentService) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + decoder := json.NewDecoder(resp.Body) + var val map[string]*api.AgentService + err := decoder.Decode(&val) + require.NoError(t, err) + assert.Len(val, 1) actual := val["db-proxy"] assert.Equal(api.ServiceKindConnectProxy, actual.Kind) @@ -217,9 +234,13 @@ func TestAgent_Services_Sidecar(t *testing.T) { a.State.AddService(srv1, "") req, _ := http.NewRequest("GET", "/v1/agent/services", nil) - obj, err := a.srv.AgentServices(nil, req) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + decoder := json.NewDecoder(resp.Body) + var val map[string]*api.AgentService + err := decoder.Decode(&val) require.NoError(err) - val := obj.(map[string]*api.AgentService) + assert.Len(val, 1) actual := val["db-sidecar-proxy"] require.NotNil(actual) @@ -232,10 +253,8 @@ func TestAgent_Services_Sidecar(t *testing.T) { // but this test serves as a regression test incase we change the endpoint to // return the internal struct later and accidentally expose some "internal" // state. - output, err := json.Marshal(obj) - require.NoError(err) - assert.NotContains(string(output), "LocallyRegisteredAsSidecar") - assert.NotContains(string(output), "locally_registered_as_sidecar") + assert.NotContains(resp.Body.String(), "LocallyRegisteredAsSidecar") + assert.NotContains(resp.Body.String(), "locally_registered_as_sidecar") } // This tests that a mesh gateway service is returned as expected. @@ -5157,8 +5176,7 @@ func TestAgent_TokenTriggersFullSync(t *testing.T) { require.NoError(t, err) resp := httptest.NewRecorder() - _, err = a.srv.AgentToken(resp, req) - require.NoError(t, err) + a.srv.h.ServeHTTP(resp, req) require.Equal(t, http.StatusOK, resp.Code) require.Equal(t, token.SecretID, tt.tokenGetFn(a.tokens)) @@ -5234,6 +5252,7 @@ func TestAgent_Token(t *testing.T) { method: "PUT", url: "nope?token=root", body: body("X"), + code: http.StatusNotFound, expectedErr: `Token "nope" is unknown`, }, { @@ -5241,6 +5260,7 @@ func TestAgent_Token(t *testing.T) { method: "PUT", url: "acl_token?token=root", body: badJSON(), + code: http.StatusBadRequest, expectedErr: `Bad request: Request decode failed: json: cannot unmarshal bool into Go value of type api.AgentToken`, }, { @@ -5397,13 +5417,12 @@ func TestAgent_Token(t *testing.T) { resp := httptest.NewRecorder() req, _ := http.NewRequest(tt.method, url, tt.body) - _, err := a.srv.AgentToken(resp, req) + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, tt.code, resp.Code) if tt.expectedErr != "" { - require.EqualError(t, err, tt.expectedErr) + require.Contains(t, resp.Body.String(), tt.expectedErr) return } - require.NoError(t, err) - require.Equal(t, tt.code, resp.Code) require.Equal(t, tt.effective.user, a.tokens.UserToken()) require.Equal(t, tt.effective.agent, a.tokens.AgentToken()) require.Equal(t, tt.effective.master, a.tokens.AgentMasterToken()) @@ -5432,8 +5451,10 @@ func TestAgent_Token(t *testing.T) { t.Run("permission denied", func(t *testing.T) { resetTokens(tokens{}) req, _ := http.NewRequest("PUT", "/v1/agent/token/acl_token", body("X")) - _, err := a.srv.AgentToken(nil, req) - require.True(t, acl.IsErrPermissionDenied(err)) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + + require.Equal(t, http.StatusForbidden, resp.Code) require.Equal(t, "", a.tokens.UserToken()) }) }