diff --git a/agent/grpc-external/services/resource/delete.go b/agent/grpc-external/services/resource/delete.go index a7e439c33a..da3e0e3c39 100644 --- a/agent/grpc-external/services/resource/delete.go +++ b/agent/grpc-external/services/resource/delete.go @@ -21,6 +21,10 @@ import ( // // TODO(spatel): Move docs to the proto file func (s *Server) Delete(ctx context.Context, req *pbresource.DeleteRequest) (*pbresource.DeleteResponse, error) { + if err := validateDeleteRequest(req); err != nil { + return nil, err + } + reg, err := s.resolveType(req.Id.Type) if err != nil { return nil, err @@ -72,3 +76,14 @@ func (s *Server) Delete(ctx context.Context, req *pbresource.DeleteRequest) (*pb return nil, status.Errorf(codes.Internal, "failed delete: %v", err) } } + +func validateDeleteRequest(req *pbresource.DeleteRequest) error { + if req.Id == nil { + return status.Errorf(codes.InvalidArgument, "id is required") + } + + if err := validateId(req.Id, "id"); err != nil { + return err + } + return nil +} diff --git a/agent/grpc-external/services/resource/delete_test.go b/agent/grpc-external/services/resource/delete_test.go index 9b68de93a6..c6b93a9e0d 100644 --- a/agent/grpc-external/services/resource/delete_test.go +++ b/agent/grpc-external/services/resource/delete_test.go @@ -15,6 +15,46 @@ import ( "github.com/hashicorp/consul/proto-public/pbresource" ) +func TestDelete_InputValidation(t *testing.T) { + server := testServer(t) + client := testClient(t, server) + + demo.Register(server.Registry) + + testCases := map[string]func(*pbresource.DeleteRequest){ + "no id": func(req *pbresource.DeleteRequest) { req.Id = nil }, + "no type": func(req *pbresource.DeleteRequest) { req.Id.Type = nil }, + "no tenancy": func(req *pbresource.DeleteRequest) { req.Id.Tenancy = nil }, + "no name": func(req *pbresource.DeleteRequest) { req.Id.Name = "" }, + // clone necessary to not pollute DefaultTenancy + "tenancy partition wildcard": func(req *pbresource.DeleteRequest) { + req.Id.Tenancy = clone(req.Id.Tenancy) + req.Id.Tenancy.Partition = storage.Wildcard + }, + "tenancy namespace wildcard": func(req *pbresource.DeleteRequest) { + req.Id.Tenancy = clone(req.Id.Tenancy) + req.Id.Tenancy.Namespace = storage.Wildcard + }, + "tenancy peername wildcard": func(req *pbresource.DeleteRequest) { + req.Id.Tenancy = clone(req.Id.Tenancy) + req.Id.Tenancy.PeerName = storage.Wildcard + }, + } + for desc, modFn := range testCases { + t.Run(desc, func(t *testing.T) { + res, err := demo.GenerateV2Artist() + require.NoError(t, err) + + req := &pbresource.DeleteRequest{Id: res.Id, Version: ""} + modFn(req) + + _, err = client.Delete(testContext(t), req) + require.Error(t, err) + require.Equal(t, codes.InvalidArgument.String(), status.Code(err).String()) + }) + } +} + func TestDelete_TypeNotRegistered(t *testing.T) { t.Parallel() diff --git a/agent/grpc-external/services/resource/read.go b/agent/grpc-external/services/resource/read.go index 2e8988970c..c75779183c 100644 --- a/agent/grpc-external/services/resource/read.go +++ b/agent/grpc-external/services/resource/read.go @@ -16,6 +16,10 @@ import ( ) func (s *Server) Read(ctx context.Context, req *pbresource.ReadRequest) (*pbresource.ReadResponse, error) { + if err := validateReadRequest(req); err != nil { + return nil, err + } + // check type exists reg, err := s.resolveType(req.Id.Type) if err != nil { @@ -48,3 +52,14 @@ func (s *Server) Read(ctx context.Context, req *pbresource.ReadRequest) (*pbreso return nil, status.Errorf(codes.Internal, "failed read: %v", err) } } + +func validateReadRequest(req *pbresource.ReadRequest) error { + if req.Id == nil { + return status.Errorf(codes.InvalidArgument, "id is required") + } + + if err := validateId(req.Id, "id"); err != nil { + return err + } + return nil +} diff --git a/agent/grpc-external/services/resource/read_test.go b/agent/grpc-external/services/resource/read_test.go index 86ebd59a4f..6f7f80a090 100644 --- a/agent/grpc-external/services/resource/read_test.go +++ b/agent/grpc-external/services/resource/read_test.go @@ -21,6 +21,46 @@ import ( "github.com/hashicorp/consul/proto/private/prototest" ) +func TestRead_InputValidation(t *testing.T) { + server := testServer(t) + client := testClient(t, server) + + demo.Register(server.Registry) + + testCases := map[string]func(*pbresource.ReadRequest){ + "no id": func(req *pbresource.ReadRequest) { req.Id = nil }, + "no type": func(req *pbresource.ReadRequest) { req.Id.Type = nil }, + "no tenancy": func(req *pbresource.ReadRequest) { req.Id.Tenancy = nil }, + "no name": func(req *pbresource.ReadRequest) { req.Id.Name = "" }, + // clone necessary to not pollute DefaultTenancy + "tenancy partition wildcard": func(req *pbresource.ReadRequest) { + req.Id.Tenancy = clone(req.Id.Tenancy) + req.Id.Tenancy.Partition = storage.Wildcard + }, + "tenancy namespace wildcard": func(req *pbresource.ReadRequest) { + req.Id.Tenancy = clone(req.Id.Tenancy) + req.Id.Tenancy.Namespace = storage.Wildcard + }, + "tenancy peername wildcard": func(req *pbresource.ReadRequest) { + req.Id.Tenancy = clone(req.Id.Tenancy) + req.Id.Tenancy.PeerName = storage.Wildcard + }, + } + for desc, modFn := range testCases { + t.Run(desc, func(t *testing.T) { + res, err := demo.GenerateV2Artist() + require.NoError(t, err) + + req := &pbresource.ReadRequest{Id: res.Id} + modFn(req) + + _, err = client.Read(testContext(t), req) + require.Error(t, err) + require.Equal(t, codes.InvalidArgument.String(), status.Code(err).String()) + }) + } +} + func TestRead_TypeNotFound(t *testing.T) { server := NewServer(Config{Registry: resource.NewRegistry()}) client := testClient(t, server) diff --git a/agent/grpc-external/services/resource/server.go b/agent/grpc-external/services/resource/server.go index 921cb62b17..46bbbef17e 100644 --- a/agent/grpc-external/services/resource/server.go +++ b/agent/grpc-external/services/resource/server.go @@ -116,4 +116,34 @@ func isGRPCStatusError(err error) bool { return ok } +func validateId(id *pbresource.ID, errorPrefix string) error { + var field string + switch { + case id.Type == nil: + field = "type" + case id.Tenancy == nil: + field = "tenancy" + case id.Name == "": + field = "name" + } + + if field != "" { + return status.Errorf(codes.InvalidArgument, "%s.%s is required", errorPrefix, field) + } + + switch { + case id.Tenancy.Namespace == storage.Wildcard: + field = "tenancy.namespace" + case id.Tenancy.Partition == storage.Wildcard: + field = "tenancy.partition" + case id.Tenancy.PeerName == storage.Wildcard: + field = "tenancy.peername" + } + + if field != "" { + return status.Errorf(codes.InvalidArgument, "%s.%s cannot be a wildcard", errorPrefix, field) + } + return nil +} + func clone[T proto.Message](v T) T { return proto.Clone(v).(T) } diff --git a/agent/grpc-external/services/resource/write.go b/agent/grpc-external/services/resource/write.go index 65455adf4d..cccc5097a1 100644 --- a/agent/grpc-external/services/resource/write.go +++ b/agent/grpc-external/services/resource/write.go @@ -226,18 +226,22 @@ func validateWriteRequest(req *pbresource.WriteRequest) error { field = "resource" case req.Resource.Id == nil: field = "resource.id" - case req.Resource.Id.Type == nil: - field = "resource.id.type" - case req.Resource.Id.Tenancy == nil: - field = "resource.id.tenancy" - case req.Resource.Id.Name == "": - field = "resource.id.name" case req.Resource.Data == nil: field = "resource.data" } - if field == "" { - return nil + if field != "" { + return status.Errorf(codes.InvalidArgument, "%s is required", field) } - return status.Errorf(codes.InvalidArgument, "%s is required", field) + + if err := validateId(req.Resource.Id, "resource.id"); err != nil { + return err + } + + if req.Resource.Owner != nil { + if err := validateId(req.Resource.Owner, "resource.owner"); err != nil { + return err + } + } + return nil } diff --git a/agent/grpc-external/services/resource/write_test.go b/agent/grpc-external/services/resource/write_test.go index d1d561be9b..c1d00a7283 100644 --- a/agent/grpc-external/services/resource/write_test.go +++ b/agent/grpc-external/services/resource/write_test.go @@ -32,6 +32,19 @@ func TestWrite_InputValidation(t *testing.T) { "no tenancy": func(req *pbresource.WriteRequest) { req.Resource.Id.Tenancy = nil }, "no name": func(req *pbresource.WriteRequest) { req.Resource.Id.Name = "" }, "no data": func(req *pbresource.WriteRequest) { req.Resource.Data = nil }, + // clone necessary to not pollute DefaultTenancy + "tenancy partition wildcard": func(req *pbresource.WriteRequest) { + req.Resource.Id.Tenancy = clone(req.Resource.Id.Tenancy) + req.Resource.Id.Tenancy.Partition = storage.Wildcard + }, + "tenancy namespace wildcard": func(req *pbresource.WriteRequest) { + req.Resource.Id.Tenancy = clone(req.Resource.Id.Tenancy) + req.Resource.Id.Tenancy.Namespace = storage.Wildcard + }, + "tenancy peername wildcard": func(req *pbresource.WriteRequest) { + req.Resource.Id.Tenancy = clone(req.Resource.Id.Tenancy) + req.Resource.Id.Tenancy.PeerName = storage.Wildcard + }, "wrong data type": func(req *pbresource.WriteRequest) { var err error req.Resource.Data, err = anypb.New(&pbdemov2.Album{}) @@ -59,6 +72,71 @@ func TestWrite_InputValidation(t *testing.T) { } } +func TestWrite_OwnerValidation(t *testing.T) { + server := testServer(t) + client := testClient(t, server) + + demo.Register(server.Registry) + + type testCase struct { + modReqFn func(req *pbresource.WriteRequest) + errorContains string + } + testCases := map[string]testCase{ + "no owner type": { + modReqFn: func(req *pbresource.WriteRequest) { req.Resource.Owner.Type = nil }, + errorContains: "resource.owner.type", + }, + "no owner tenancy": { + modReqFn: func(req *pbresource.WriteRequest) { req.Resource.Owner.Tenancy = nil }, + errorContains: "resource.owner.tenancy", + }, + "no owner name": { + modReqFn: func(req *pbresource.WriteRequest) { req.Resource.Owner.Name = "" }, + errorContains: "resource.owner.name", + }, + // clone necessary to not pollute DefaultTenancy + "owner tenancy partition wildcard": { + modReqFn: func(req *pbresource.WriteRequest) { + req.Resource.Owner.Tenancy = clone(req.Resource.Owner.Tenancy) + req.Resource.Owner.Tenancy.Partition = storage.Wildcard + }, + errorContains: "resource.owner.tenancy.partition", + }, + "owner tenancy namespace wildcard": { + modReqFn: func(req *pbresource.WriteRequest) { + req.Resource.Owner.Tenancy = clone(req.Resource.Owner.Tenancy) + req.Resource.Owner.Tenancy.Namespace = storage.Wildcard + }, + errorContains: "resource.owner.tenancy.namespace", + }, + "owner tenancy peername wildcard": { + modReqFn: func(req *pbresource.WriteRequest) { + req.Resource.Owner.Tenancy = clone(req.Resource.Owner.Tenancy) + req.Resource.Owner.Tenancy.PeerName = storage.Wildcard + }, + errorContains: "resource.owner.tenancy.peername", + }, + } + for desc, tc := range testCases { + t.Run(desc, func(t *testing.T) { + artist, err := demo.GenerateV2Artist() + require.NoError(t, err) + + album, err := demo.GenerateV2Album(artist.Id) + require.NoError(t, err) + + albumReq := &pbresource.WriteRequest{Resource: album} + tc.modReqFn(albumReq) + + _, err = client.Write(testContext(t), albumReq) + require.Error(t, err) + require.Equal(t, codes.InvalidArgument.String(), status.Code(err).String()) + require.ErrorContains(t, err, tc.errorContains) + }) + } +} + func TestWrite_TypeNotFound(t *testing.T) { server := testServer(t) client := testClient(t, server) diff --git a/internal/resource/demo/demo.go b/internal/resource/demo/demo.go index 2c9fe4b433..5c2e9307f3 100644 --- a/internal/resource/demo/demo.go +++ b/internal/resource/demo/demo.go @@ -289,3 +289,5 @@ var ( "Standing by the stage looking cool", } ) + +func clone[T proto.Message](v T) T { return proto.Clone(v).(T) }