Net-6291/fix/watch resources (#19467)

* fix: update watch endpoint to default based on scope

* test: additional test

* refactor: rename list validate function

* refactor: rename validate<Op>Request() -> ensure<Op>RequestValid() for consistency
This commit is contained in:
Poonam Jadhav 2023-11-03 16:03:07 -04:00 committed by GitHub
parent 65592d91a8
commit c3c836edae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 199 additions and 65 deletions

View File

@ -28,7 +28,7 @@ import (
// - Errors with Aborted if the requested Version does not match the stored Version. // - Errors with Aborted if the requested Version does not match the stored Version.
// - Errors with PermissionDenied if ACL check fails // - Errors with PermissionDenied if ACL check fails
func (s *Server) Delete(ctx context.Context, req *pbresource.DeleteRequest) (*pbresource.DeleteResponse, error) { func (s *Server) Delete(ctx context.Context, req *pbresource.DeleteRequest) (*pbresource.DeleteResponse, error) {
reg, err := s.validateDeleteRequest(req) reg, err := s.ensureDeleteRequestValid(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -171,7 +171,7 @@ func (s *Server) maybeCreateTombstone(ctx context.Context, deleteId *pbresource.
} }
} }
func (s *Server) validateDeleteRequest(req *pbresource.DeleteRequest) (*resource.Registration, error) { func (s *Server) ensureDeleteRequestValid(req *pbresource.DeleteRequest) (*resource.Registration, error) {
if req.Id == nil { if req.Id == nil {
return nil, status.Errorf(codes.InvalidArgument, "id is required") return nil, status.Errorf(codes.InvalidArgument, "id is required")
} }

View File

@ -16,7 +16,7 @@ import (
) )
func (s *Server) List(ctx context.Context, req *pbresource.ListRequest) (*pbresource.ListResponse, error) { func (s *Server) List(ctx context.Context, req *pbresource.ListRequest) (*pbresource.ListResponse, error) {
reg, err := s.validateListRequest(req) reg, err := s.ensureListRequestValid(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -81,7 +81,7 @@ func (s *Server) List(ctx context.Context, req *pbresource.ListRequest) (*pbreso
return &pbresource.ListResponse{Resources: result}, nil return &pbresource.ListResponse{Resources: result}, nil
} }
func (s *Server) validateListRequest(req *pbresource.ListRequest) (*resource.Registration, error) { func (s *Server) ensureListRequestValid(req *pbresource.ListRequest) (*resource.Registration, error) {
var field string var field string
switch { switch {
case req.Type == nil: case req.Type == nil:

View File

@ -15,7 +15,7 @@ import (
) )
func (s *Server) ListByOwner(ctx context.Context, req *pbresource.ListByOwnerRequest) (*pbresource.ListByOwnerResponse, error) { func (s *Server) ListByOwner(ctx context.Context, req *pbresource.ListByOwnerRequest) (*pbresource.ListByOwnerResponse, error) {
reg, err := s.validateListByOwnerRequest(req) reg, err := s.ensureListByOwnerRequestValid(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -87,7 +87,7 @@ func (s *Server) ListByOwner(ctx context.Context, req *pbresource.ListByOwnerReq
return &pbresource.ListByOwnerResponse{Resources: result}, nil return &pbresource.ListByOwnerResponse{Resources: result}, nil
} }
func (s *Server) validateListByOwnerRequest(req *pbresource.ListByOwnerRequest) (*resource.Registration, error) { func (s *Server) ensureListByOwnerRequestValid(req *pbresource.ListByOwnerRequest) (*resource.Registration, error) {
if req.Owner == nil { if req.Owner == nil {
return nil, status.Errorf(codes.InvalidArgument, "owner is required") return nil, status.Errorf(codes.InvalidArgument, "owner is required")
} }

View File

@ -18,7 +18,7 @@ import (
func (s *Server) Read(ctx context.Context, req *pbresource.ReadRequest) (*pbresource.ReadResponse, error) { func (s *Server) Read(ctx context.Context, req *pbresource.ReadRequest) (*pbresource.ReadResponse, error) {
// Light first pass validation based on what user passed in and not much more. // Light first pass validation based on what user passed in and not much more.
reg, err := s.validateReadRequest(req) reg, err := s.ensureReadRequestValid(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -87,7 +87,7 @@ func (s *Server) Read(ctx context.Context, req *pbresource.ReadRequest) (*pbreso
return &pbresource.ReadResponse{Resource: resource}, nil return &pbresource.ReadResponse{Resource: resource}, nil
} }
func (s *Server) validateReadRequest(req *pbresource.ReadRequest) (*resource.Registration, error) { func (s *Server) ensureReadRequestValid(req *pbresource.ReadRequest) (*resource.Registration, error) {
if req.Id == nil { if req.Id == nil {
return nil, status.Errorf(codes.InvalidArgument, "id is required") return nil, status.Errorf(codes.InvalidArgument, "id is required")
} }
@ -107,31 +107,9 @@ func (s *Server) validateReadRequest(req *pbresource.ReadRequest) (*resource.Reg
} }
// Check scope // Check scope
if reg.Scope == resource.ScopePartition && req.Id.Tenancy.Namespace != "" { if err = validateScopedTenancy(reg.Scope, req.Id.Type, req.Id.Tenancy); err != nil {
return nil, status.Errorf( return nil, err
codes.InvalidArgument,
"partition scoped resource %s cannot have a namespace. got: %s",
resource.ToGVK(req.Id.Type),
req.Id.Tenancy.Namespace,
)
}
if reg.Scope == resource.ScopeCluster {
if req.Id.Tenancy.Partition != "" {
return nil, status.Errorf(
codes.InvalidArgument,
"cluster scoped resource %s cannot have a partition: %s",
resource.ToGVK(req.Id.Type),
req.Id.Tenancy.Partition,
)
}
if req.Id.Tenancy.Namespace != "" {
return nil, status.Errorf(
codes.InvalidArgument,
"cluster scoped resource %s cannot have a namespace: %s",
resource.ToGVK(req.Id.Type),
req.Id.Tenancy.Namespace,
)
}
} }
return reg, nil return reg, nil
} }

View File

@ -242,6 +242,36 @@ func tenancyExists(reg *resource.Registration, tenancyBridge TenancyBridge, tena
return nil return nil
} }
func validateScopedTenancy(scope resource.Scope, resourceType *pbresource.Type, tenancy *pbresource.Tenancy) error {
if scope == resource.ScopePartition && tenancy.Namespace != "" {
return status.Errorf(
codes.InvalidArgument,
"partition scoped resource %s cannot have a namespace. got: %s",
resource.ToGVK(resourceType),
tenancy.Namespace,
)
}
if scope == resource.ScopeCluster {
if tenancy.Partition != "" {
return status.Errorf(
codes.InvalidArgument,
"cluster scoped resource %s cannot have a partition: %s",
resource.ToGVK(resourceType),
tenancy.Partition,
)
}
if tenancy.Namespace != "" {
return status.Errorf(
codes.InvalidArgument,
"cluster scoped resource %s cannot have a namespace: %s",
resource.ToGVK(resourceType),
tenancy.Namespace,
)
}
}
return nil
}
// tenancyMarkedForDeletion returns a gRPC InvalidArgument when either partition or namespace is marked for deletion. // tenancyMarkedForDeletion returns a gRPC InvalidArgument when either partition or namespace is marked for deletion.
func tenancyMarkedForDeletion(reg *resource.Registration, tenancyBridge TenancyBridge, tenancy *pbresource.Tenancy) error { func tenancyMarkedForDeletion(reg *resource.Registration, tenancyBridge TenancyBridge, tenancy *pbresource.Tenancy) error {
if reg.Scope == resource.ScopePartition || reg.Scope == resource.ScopeNamespace { if reg.Scope == resource.ScopePartition || reg.Scope == resource.ScopeNamespace {

View File

@ -16,7 +16,7 @@ import (
) )
func (s *Server) WatchList(req *pbresource.WatchListRequest, stream pbresource.ResourceService_WatchListServer) error { func (s *Server) WatchList(req *pbresource.WatchListRequest, stream pbresource.ResourceService_WatchListServer) error {
reg, err := s.validateWatchListRequest(req) reg, err := s.ensureWatchListRequestValid(req)
if err != nil { if err != nil {
return err return err
} }
@ -91,17 +91,9 @@ func (s *Server) WatchList(req *pbresource.WatchListRequest, stream pbresource.R
} }
} }
func (s *Server) validateWatchListRequest(req *pbresource.WatchListRequest) (*resource.Registration, error) { func (s *Server) ensureWatchListRequestValid(req *pbresource.WatchListRequest) (*resource.Registration, error) {
var field string if req.Type == nil {
switch { return nil, status.Errorf(codes.InvalidArgument, "type is required")
case req.Type == nil:
field = "type"
case req.Tenancy == nil:
field = "tenancy"
}
if field != "" {
return nil, status.Errorf(codes.InvalidArgument, "%s is required", field)
} }
// Check type exists. // Check type exists.
@ -110,6 +102,11 @@ func (s *Server) validateWatchListRequest(req *pbresource.WatchListRequest) (*re
return nil, err return nil, err
} }
// if no tenancy is passed defaults to wildcard
if req.Tenancy == nil {
req.Tenancy = wildcardTenancyFor(reg.Scope)
}
if err = checkV2Tenancy(s.UseV2Tenancy, req.Type); err != nil { if err = checkV2Tenancy(s.UseV2Tenancy, req.Type); err != nil {
return nil, err return nil, err
} }
@ -118,15 +115,33 @@ func (s *Server) validateWatchListRequest(req *pbresource.WatchListRequest) (*re
return nil, err return nil, err
} }
// Error when partition scoped and namespace not empty. // Check scope
if reg.Scope == resource.ScopePartition && req.Tenancy.Namespace != "" { if err = validateScopedTenancy(reg.Scope, req.Type, req.Tenancy); err != nil {
return nil, status.Errorf( return nil, err
codes.InvalidArgument,
"partition scoped type %s cannot have a namespace. got: %s",
resource.ToGVK(req.Type),
req.Tenancy.Namespace,
)
} }
return reg, nil return reg, nil
} }
func wildcardTenancyFor(scope resource.Scope) *pbresource.Tenancy {
var defaultTenancy *pbresource.Tenancy
switch scope {
case resource.ScopeCluster:
defaultTenancy = &pbresource.Tenancy{
PeerName: storage.Wildcard,
}
case resource.ScopePartition:
defaultTenancy = &pbresource.Tenancy{
Partition: storage.Wildcard,
PeerName: storage.Wildcard,
}
default:
defaultTenancy = &pbresource.Tenancy{
Partition: storage.Wildcard,
PeerName: storage.Wildcard,
Namespace: storage.Wildcard,
}
}
return defaultTenancy
}

View File

@ -40,10 +40,6 @@ func TestWatchList_InputValidation(t *testing.T) {
modFn: func(req *pbresource.WatchListRequest) { req.Type = nil }, modFn: func(req *pbresource.WatchListRequest) { req.Type = nil },
errContains: "type is required", errContains: "type is required",
}, },
"no tenancy": {
modFn: func(req *pbresource.WatchListRequest) { req.Tenancy = nil },
errContains: "tenancy is required",
},
"partition mixed case": { "partition mixed case": {
modFn: func(req *pbresource.WatchListRequest) { req.Tenancy.Partition = "Default" }, modFn: func(req *pbresource.WatchListRequest) { req.Tenancy.Partition = "Default" },
errContains: "tenancy.partition invalid", errContains: "tenancy.partition invalid",
@ -75,6 +71,20 @@ func TestWatchList_InputValidation(t *testing.T) {
}, },
errContains: "cannot have a namespace", errContains: "cannot have a namespace",
}, },
"cluster scope with non-empty partition": {
modFn: func(req *pbresource.WatchListRequest) {
req.Type = demo.TypeV1Executive
req.Tenancy = &pbresource.Tenancy{Partition: "bad"}
},
errContains: "cannot have a partition",
},
"cluster scope with non-empty namespace": {
modFn: func(req *pbresource.WatchListRequest) {
req.Type = demo.TypeV1Executive
req.Tenancy = &pbresource.Tenancy{Namespace: "bad"}
},
errContains: "cannot have a namespace",
},
} }
for desc, tc := range testCases { for desc, tc := range testCases {
t.Run(desc, func(t *testing.T) { t.Run(desc, func(t *testing.T) {
@ -382,3 +392,30 @@ type resourceOrError struct {
rsp *pbresource.WatchEvent rsp *pbresource.WatchEvent
err error err error
} }
func TestWatchList_NoTenancy(t *testing.T) {
t.Parallel()
ctx := context.Background()
server := testServer(t)
client := testClient(t, server)
demo.RegisterTypes(server.Registry)
// Create a watch.
stream, err := client.WatchList(ctx, &pbresource.WatchListRequest{
Type: demo.TypeV1RecordLabel,
})
require.NoError(t, err)
rspCh := handleResourceStream(t, stream)
recordLabel, err := demo.GenerateV1RecordLabel("looney-tunes")
require.NoError(t, err)
// Create and verify upsert event received.
recordLabel, err = server.Backend.WriteCAS(ctx, recordLabel)
require.NoError(t, err)
rsp := mustGetResource(t, rspCh)
require.Equal(t, pbresource.WatchEvent_OPERATION_UPSERT, rsp.Operation)
prototest.AssertDeepEqual(t, recordLabel, rsp.Resource)
}

View File

@ -37,7 +37,7 @@ import (
var errUseWriteStatus = status.Error(codes.InvalidArgument, "resource.status can only be set using the WriteStatus endpoint") var errUseWriteStatus = status.Error(codes.InvalidArgument, "resource.status can only be set using the WriteStatus endpoint")
func (s *Server) Write(ctx context.Context, req *pbresource.WriteRequest) (*pbresource.WriteResponse, error) { func (s *Server) Write(ctx context.Context, req *pbresource.WriteRequest) (*pbresource.WriteResponse, error) {
reg, err := s.validateWriteRequest(req) reg, err := s.ensureWriteRequestValid(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -265,7 +265,7 @@ func (s *Server) retryCAS(ctx context.Context, vsn string, cas func() error) err
return err return err
} }
func (s *Server) validateWriteRequest(req *pbresource.WriteRequest) (*resource.Registration, error) { func (s *Server) ensureWriteRequestValid(req *pbresource.WriteRequest) (*resource.Registration, error) {
var field string var field string
switch { switch {
case req.Resource == nil: case req.Resource == nil:

View File

@ -242,6 +242,86 @@ func TestController_NoReconciler(t *testing.T) {
func() { mgr.Register(ctrl) }) func() { mgr.Register(ctrl) })
} }
func TestController_Watch(t *testing.T) {
t.Parallel()
t.Run("partitioned scoped resources", func(t *testing.T) {
rec := newTestReconciler()
client := svctest.RunResourceService(t, demo.RegisterTypes)
ctrl := controller.
ForType(demo.TypeV1RecordLabel).
WithReconciler(rec)
mgr := controller.NewManager(client, testutil.Logger(t))
mgr.SetRaftLeader(true)
mgr.Register(ctrl)
ctx := testContext(t)
go mgr.Run(ctx)
res, err := demo.GenerateV1RecordLabel("test")
require.NoError(t, err)
rsp, err := client.Write(testContext(t), &pbresource.WriteRequest{Resource: res})
require.NoError(t, err)
req := rec.wait(t)
prototest.AssertDeepEqual(t, rsp.Resource.Id, req.ID)
})
t.Run("cluster scoped resources", func(t *testing.T) {
rec := newTestReconciler()
client := svctest.RunResourceService(t, demo.RegisterTypes)
ctrl := controller.
ForType(demo.TypeV1Executive).
WithReconciler(rec)
mgr := controller.NewManager(client, testutil.Logger(t))
mgr.SetRaftLeader(true)
mgr.Register(ctrl)
go mgr.Run(testContext(t))
exec, err := demo.GenerateV1Executive("test", "CEO")
require.NoError(t, err)
rsp, err := client.Write(testContext(t), &pbresource.WriteRequest{Resource: exec})
require.NoError(t, err)
req := rec.wait(t)
prototest.AssertDeepEqual(t, rsp.Resource.Id, req.ID)
})
t.Run("namespace scoped resources", func(t *testing.T) {
rec := newTestReconciler()
client := svctest.RunResourceService(t, demo.RegisterTypes)
ctrl := controller.
ForType(demo.TypeV2Artist).
WithReconciler(rec)
mgr := controller.NewManager(client, testutil.Logger(t))
mgr.SetRaftLeader(true)
mgr.Register(ctrl)
go mgr.Run(testContext(t))
artist, err := demo.GenerateV2Artist()
require.NoError(t, err)
rsp, err := client.Write(testContext(t), &pbresource.WriteRequest{Resource: artist})
require.NoError(t, err)
req := rec.wait(t)
prototest.AssertDeepEqual(t, rsp.Resource.Id, req.ID)
})
}
func newTestReconciler() *testReconciler { func newTestReconciler() *testReconciler {
return &testReconciler{ return &testReconciler{
calls: make(chan controller.Request), calls: make(chan controller.Request),

View File

@ -14,7 +14,6 @@ import (
"github.com/hashicorp/consul/agent/consul/controller/queue" "github.com/hashicorp/consul/agent/consul/controller/queue"
"github.com/hashicorp/consul/internal/resource" "github.com/hashicorp/consul/internal/resource"
"github.com/hashicorp/consul/internal/storage"
"github.com/hashicorp/consul/proto-public/pbresource" "github.com/hashicorp/consul/proto-public/pbresource"
) )
@ -92,11 +91,6 @@ func runQueue[T queue.ItemType](ctx context.Context, ctrl Controller) queue.Work
func (c *controllerRunner) watch(ctx context.Context, typ *pbresource.Type, add func(*pbresource.Resource)) error { func (c *controllerRunner) watch(ctx context.Context, typ *pbresource.Type, add func(*pbresource.Resource)) error {
wl, err := c.client.WatchList(ctx, &pbresource.WatchListRequest{ wl, err := c.client.WatchList(ctx, &pbresource.WatchListRequest{
Type: typ, Type: typ,
Tenancy: &pbresource.Tenancy{
Partition: storage.Wildcard,
PeerName: storage.Wildcard,
Namespace: storage.Wildcard,
},
}) })
if err != nil { if err != nil {
c.logger.Error("failed to create watch", "error", err) c.logger.Error("failed to create watch", "error", err)