resource: add missing validation to the `List` and `WatchList` endpoints (#17213)

This commit is contained in:
Dan Upton 2023-05-10 10:38:48 +01:00 committed by GitHub
parent 6c24a66f73
commit 5030101cdb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 87 additions and 0 deletions

View File

@ -15,6 +15,10 @@ import (
)
func (s *Server) List(ctx context.Context, req *pbresource.ListRequest) (*pbresource.ListResponse, error) {
if err := validateListRequest(req); err != nil {
return nil, err
}
// check type
reg, err := s.resolveType(req.Type)
if err != nil {
@ -65,3 +69,16 @@ func (s *Server) List(ctx context.Context, req *pbresource.ListRequest) (*pbreso
}
return &pbresource.ListResponse{Resources: result}, nil
}
func validateListRequest(req *pbresource.ListRequest) error {
var field string
switch {
case req.Type == nil:
field = "type"
case req.Tenancy == nil:
field = "tenancy"
default:
return nil
}
return status.Errorf(codes.InvalidArgument, "%s is required", field)
}

View File

@ -22,6 +22,31 @@ import (
"google.golang.org/grpc/status"
)
func TestList_InputValidation(t *testing.T) {
server := testServer(t)
client := testClient(t, server)
demo.RegisterTypes(server.Registry)
testCases := map[string]func(*pbresource.ListRequest){
"no type": func(req *pbresource.ListRequest) { req.Type = nil },
"no tenancy": func(req *pbresource.ListRequest) { req.Tenancy = nil },
}
for desc, modFn := range testCases {
t.Run(desc, func(t *testing.T) {
req := &pbresource.ListRequest{
Type: demo.TypeV2Album,
Tenancy: demo.TenancyDefault,
}
modFn(req)
_, err := client.List(testContext(t), req)
require.Error(t, err)
require.Equal(t, codes.InvalidArgument.String(), status.Code(err).String())
})
}
}
func TestList_TypeNotFound(t *testing.T) {
server := testServer(t)
client := testClient(t, server)

View File

@ -13,6 +13,10 @@ import (
)
func (s *Server) WatchList(req *pbresource.WatchListRequest, stream pbresource.ResourceService_WatchListServer) error {
if err := validateWatchListRequest(req); err != nil {
return err
}
// check type exists
reg, err := s.resolveType(req.Type)
if err != nil {
@ -70,3 +74,16 @@ func (s *Server) WatchList(req *pbresource.WatchListRequest, stream pbresource.R
}
}
}
func validateWatchListRequest(req *pbresource.WatchListRequest) error {
var field string
switch {
case req.Type == nil:
field = "type"
case req.Tenancy == nil:
field = "tenancy"
default:
return nil
}
return status.Errorf(codes.InvalidArgument, "%s is required", field)
}

View File

@ -22,6 +22,34 @@ import (
"google.golang.org/grpc/status"
)
func TestWatchList_InputValidation(t *testing.T) {
server := testServer(t)
client := testClient(t, server)
demo.RegisterTypes(server.Registry)
testCases := map[string]func(*pbresource.WatchListRequest){
"no type": func(req *pbresource.WatchListRequest) { req.Type = nil },
"no tenancy": func(req *pbresource.WatchListRequest) { req.Tenancy = nil },
}
for desc, modFn := range testCases {
t.Run(desc, func(t *testing.T) {
req := &pbresource.WatchListRequest{
Type: demo.TypeV2Album,
Tenancy: demo.TenancyDefault,
}
modFn(req)
stream, err := client.WatchList(testContext(t), req)
require.NoError(t, err)
_, err = stream.Recv()
require.Error(t, err)
require.Equal(t, codes.InvalidArgument.String(), status.Code(err).String())
})
}
}
func TestWatchList_TypeNotFound(t *testing.T) {
t.Parallel()