fix: write endpoint errors out gracefully (#18743)

This commit is contained in:
Poonam Jadhav 2023-09-12 09:22:15 -04:00 committed by GitHub
parent 697836b19a
commit 264166fcc0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 2 deletions

View File

@ -82,14 +82,18 @@ func (h *resourceHandler) handleWrite(w http.ResponseWriter, r *http.Request, ct
var req writeRequest var req writeRequest
// convert req body to writeRequest // convert req body to writeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
h.logger.Error("Failed to decode request body", "error", err)
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("Request body didn't follow schema.")) w.Write([]byte("Request body format is invalid"))
return
} }
// convert data struct to proto message // convert data struct to proto message
data := h.reg.Proto.ProtoReflect().New().Interface() data := h.reg.Proto.ProtoReflect().New().Interface()
if err := protojson.Unmarshal(req.Data, data); err != nil { if err := protojson.Unmarshal(req.Data, data); err != nil {
h.logger.Error("Failed to unmarshal to proto message", "error", err)
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("Request body didn't follow schema.")) w.Write([]byte("Request body didn't follow the resource schema"))
return
} }
// proto message to any // proto message to any
anyProtoMsg, err := anypb.New(data) anyProtoMsg, err := anypb.New(data)

View File

@ -6,6 +6,7 @@ package http
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
@ -41,6 +42,7 @@ func TestResourceHandler_InputValidation(t *testing.T) {
request *http.Request request *http.Request
response *httptest.ResponseRecorder response *httptest.ResponseRecorder
expectedResponseCode int expectedResponseCode int
expectedErrorMessage string
} }
client := svctest.RunResourceService(t, demo.RegisterTypes) client := svctest.RunResourceService(t, demo.RegisterTypes)
resourceHandler := resourceHandler{ resourceHandler := resourceHandler{
@ -70,6 +72,7 @@ func TestResourceHandler_InputValidation(t *testing.T) {
`)), `)),
response: httptest.NewRecorder(), response: httptest.NewRecorder(),
expectedResponseCode: http.StatusBadRequest, expectedResponseCode: http.StatusBadRequest,
expectedErrorMessage: "rpc error: code = InvalidArgument desc = resource.id.name is required",
}, },
{ {
description: "wrong schema", description: "wrong schema",
@ -86,12 +89,21 @@ func TestResourceHandler_InputValidation(t *testing.T) {
`)), `)),
response: httptest.NewRecorder(), response: httptest.NewRecorder(),
expectedResponseCode: http.StatusBadRequest, expectedResponseCode: http.StatusBadRequest,
expectedErrorMessage: "Request body didn't follow the resource schema",
},
{
description: "invalid request body",
request: httptest.NewRequest("PUT", "/keith-urban?partition=default&peer_name=local&namespace=default", strings.NewReader("bad-input")),
response: httptest.NewRecorder(),
expectedResponseCode: http.StatusBadRequest,
expectedErrorMessage: "Request body format is invalid",
}, },
{ {
description: "no id", description: "no id",
request: httptest.NewRequest("DELETE", "/?partition=default&peer_name=local&namespace=default", strings.NewReader("")), request: httptest.NewRequest("DELETE", "/?partition=default&peer_name=local&namespace=default", strings.NewReader("")),
response: httptest.NewRecorder(), response: httptest.NewRecorder(),
expectedResponseCode: http.StatusBadRequest, expectedResponseCode: http.StatusBadRequest,
expectedErrorMessage: "rpc error: code = InvalidArgument desc = id.name is required",
}, },
} }
@ -99,7 +111,15 @@ func TestResourceHandler_InputValidation(t *testing.T) {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
resourceHandler.ServeHTTP(tc.response, tc.request) resourceHandler.ServeHTTP(tc.response, tc.request)
response := tc.response.Result()
defer response.Body.Close()
b, err := io.ReadAll(tc.response.Body)
require.NoError(t, err)
require.Equal(t, tc.expectedResponseCode, tc.response.Result().StatusCode) require.Equal(t, tc.expectedResponseCode, tc.response.Result().StatusCode)
require.Equal(t, tc.expectedErrorMessage, string(b))
}) })
} }
} }