diff --git a/agent/consul/state/catalog.go b/agent/consul/state/catalog.go index c26d10f603..eefdb82e31 100644 --- a/agent/consul/state/catalog.go +++ b/agent/consul/state/catalog.go @@ -202,6 +202,8 @@ func (s *Store) EnsureNode(idx uint64, node *structs.Node) error { func ensureNoNodeWithSimilarNameTxn(tx ReadTxn, node *structs.Node, allowClashWithoutID bool) error { // Retrieve all of the nodes + // TODO(partitions): since the node UUID field is not partitioned, do we have to do something additional here? + enodes, err := tx.Get(tableNodes, indexID+"_prefix", node.GetEnterpriseMeta()) if err != nil { return fmt.Errorf("Cannot lookup all nodes: %s", err) diff --git a/agent/consul/state/catalog_events.go b/agent/consul/state/catalog_events.go index 97be4638bc..c0c2fecde4 100644 --- a/agent/consul/state/catalog_events.go +++ b/agent/consul/state/catalog_events.go @@ -25,14 +25,15 @@ type EventPayloadCheckServiceNode struct { // when the change event is for a sidecar or gateway. overrideKey string overrideNamespace string + overridePartition string } func (e EventPayloadCheckServiceNode) HasReadPermission(authz acl.Authorizer) bool { return e.Value.CanRead(authz) == acl.Allow } -func (e EventPayloadCheckServiceNode) MatchesKey(key, namespace string) bool { - if key == "" && namespace == "" { +func (e EventPayloadCheckServiceNode) MatchesKey(key, namespace, partition string) bool { + if key == "" && namespace == "" && partition == "" { return true } @@ -48,8 +49,14 @@ func (e EventPayloadCheckServiceNode) MatchesKey(key, namespace string) bool { if e.overrideNamespace != "" { ns = e.overrideNamespace } + ap := e.Value.Service.EnterpriseMeta.PartitionOrDefault() + if e.overridePartition != "" { + ap = e.overridePartition + } + return (key == "" || strings.EqualFold(key, name)) && - (namespace == "" || strings.EqualFold(namespace, ns)) + (namespace == "" || strings.EqualFold(namespace, ns)) && + (partition == "" || strings.EqualFold(partition, ap)) } // serviceHealthSnapshot returns a stream.SnapshotFunc that provides a snapshot @@ -60,7 +67,7 @@ func serviceHealthSnapshot(db ReadDB, topic stream.Topic) stream.SnapshotFunc { defer tx.Abort() connect := topic == topicServiceHealthConnect - entMeta := structs.NewEnterpriseMetaInDefaultPartition(req.Namespace) + entMeta := structs.NewEnterpriseMetaWithPartition(req.Partition, req.Namespace) idx, nodes, err := checkServiceNodesTxn(tx, nil, req.Key, connect, &entMeta) if err != nil { return 0, err @@ -123,6 +130,11 @@ type serviceChange struct { change memdb.Change } +type nodeTuple struct { + Node string + Partition string +} + var serviceChangeIndirect = serviceChange{changeType: changeIndirect} // ServiceHealthEventsFromChanges returns all the service and Connect health @@ -130,13 +142,13 @@ var serviceChangeIndirect = serviceChange{changeType: changeIndirect} func ServiceHealthEventsFromChanges(tx ReadTxn, changes Changes) ([]stream.Event, error) { var events []stream.Event - var nodeChanges map[string]changeType + var nodeChanges map[nodeTuple]changeType var serviceChanges map[nodeServiceTuple]serviceChange var termGatewayChanges map[structs.ServiceName]map[structs.ServiceName]serviceChange - markNode := func(node string, typ changeType) { + markNode := func(node nodeTuple, typ changeType) { if nodeChanges == nil { - nodeChanges = make(map[string]changeType) + nodeChanges = make(map[nodeTuple]changeType) } // If the caller has an actual node mutation ensure we store it even if the // node is already marked. If the caller is just marking the node dirty @@ -161,14 +173,15 @@ func ServiceHealthEventsFromChanges(tx ReadTxn, changes Changes) ([]stream.Event for _, change := range changes.Changes { switch change.Table { - case "nodes": + case tableNodes: // Node changed in some way, if it's not a delete, we'll need to // re-deliver CheckServiceNode results for all services on that node but // we mark it anyway because if it _is_ a delete then we need to know that // later to avoid trying to deliver events when node level checks mark the // node as "changed". n := changeObject(change).(*structs.Node) - markNode(n.Node, changeTypeFromChange(change)) + tuple := newNodeTupleFromNode(n) + markNode(tuple, changeTypeFromChange(change)) case tableServices: sn := changeObject(change).(*structs.ServiceNode) @@ -187,7 +200,8 @@ func ServiceHealthEventsFromChanges(tx ReadTxn, changes Changes) ([]stream.Event after := change.After.(*structs.HealthCheck) if after.ServiceID == "" || before.ServiceID == "" { // check before and/or after is node-scoped - markNode(after.Node, changeIndirect) + nt := newNodeTupleFromHealthCheck(after) + markNode(nt, changeIndirect) } else { // Check changed which means we just need to emit for the linked // service. @@ -206,7 +220,8 @@ func ServiceHealthEventsFromChanges(tx ReadTxn, changes Changes) ([]stream.Event obj := changeObject(change).(*structs.HealthCheck) if obj.ServiceID == "" { // Node level check - markNode(obj.Node, changeIndirect) + nt := newNodeTupleFromHealthCheck(obj) + markNode(nt, changeIndirect) } else { markService(newNodeServiceTupleFromServiceHealthCheck(obj), serviceChangeIndirect) } @@ -250,7 +265,8 @@ func ServiceHealthEventsFromChanges(tx ReadTxn, changes Changes) ([]stream.Event continue } // Rebuild events for all services on this node - es, err := newServiceHealthEventsForNode(tx, changes.Index, node) + es, err := newServiceHealthEventsForNode(tx, changes.Index, node.Node, + structs.WildcardEnterpriseMetaInPartition(node.Partition)) if err != nil { return nil, err } @@ -286,7 +302,7 @@ func ServiceHealthEventsFromChanges(tx ReadTxn, changes Changes) ([]stream.Event } } - if _, ok := nodeChanges[tuple.Node]; ok { + if _, ok := nodeChanges[tuple.nodeTuple()]; ok { // We already rebuilt events for everything on this node, no need to send // a duplicate. continue @@ -303,7 +319,10 @@ func ServiceHealthEventsFromChanges(tx ReadTxn, changes Changes) ([]stream.Event for serviceName, gsChange := range serviceChanges { gs := changeObject(gsChange.change).(*structs.GatewayService) - q := Query{Value: gs.Gateway.Name, EnterpriseMeta: gatewayName.EnterpriseMeta} + q := Query{ + Value: gs.Gateway.Name, + EnterpriseMeta: gatewayName.EnterpriseMeta, + } _, nodes, err := serviceNodesTxn(tx, nil, indexService, q) if err != nil { return nil, err @@ -320,6 +339,9 @@ func ServiceHealthEventsFromChanges(tx ReadTxn, changes Changes) ([]stream.Event if gatewayName.EnterpriseMeta.NamespaceOrDefault() != serviceName.EnterpriseMeta.NamespaceOrDefault() { payload.overrideNamespace = serviceName.EnterpriseMeta.NamespaceOrDefault() } + if gatewayName.EnterpriseMeta.PartitionOrDefault() != serviceName.EnterpriseMeta.PartitionOrDefault() { + payload.overridePartition = serviceName.EnterpriseMeta.PartitionOrDefault() + } e.Payload = payload events = append(events, e) @@ -344,6 +366,9 @@ func ServiceHealthEventsFromChanges(tx ReadTxn, changes Changes) ([]stream.Event if gatewayName.EnterpriseMeta.NamespaceOrDefault() != serviceName.EnterpriseMeta.NamespaceOrDefault() { payload.overrideNamespace = serviceName.EnterpriseMeta.NamespaceOrDefault() } + if gatewayName.EnterpriseMeta.PartitionOrDefault() != serviceName.EnterpriseMeta.PartitionOrDefault() { + payload.overridePartition = serviceName.EnterpriseMeta.PartitionOrDefault() + } e.Payload = payload events = append(events, e) @@ -480,6 +505,9 @@ func copyEventForService(event stream.Event, service structs.ServiceName) stream if payload.Value.Service.EnterpriseMeta.NamespaceOrDefault() != service.EnterpriseMeta.NamespaceOrDefault() { payload.overrideNamespace = service.EnterpriseMeta.NamespaceOrDefault() } + if payload.Value.Service.EnterpriseMeta.PartitionOrDefault() != service.EnterpriseMeta.PartitionOrDefault() { + payload.overridePartition = service.EnterpriseMeta.PartitionOrDefault() + } event.Payload = payload return event @@ -497,13 +525,16 @@ func getPayloadCheckServiceNode(payload stream.Payload) *structs.CheckServiceNod // given node. This mirrors some of the the logic in the oddly-named // parseCheckServiceNodes but is more efficient since we know they are all on // the same node. -func newServiceHealthEventsForNode(tx ReadTxn, idx uint64, node string) ([]stream.Event, error) { - services, err := tx.Get(tableServices, indexNode, Query{Value: node}) +func newServiceHealthEventsForNode(tx ReadTxn, idx uint64, node string, entMeta *structs.EnterpriseMeta) ([]stream.Event, error) { + services, err := tx.Get(tableServices, indexNode, Query{ + Value: node, + EnterpriseMeta: *entMeta, + }) if err != nil { return nil, err } - n, checksFunc, err := getNodeAndChecks(tx, node) + n, checksFunc, err := getNodeAndChecks(tx, node, entMeta) if err != nil { return nil, err } @@ -521,9 +552,12 @@ func newServiceHealthEventsForNode(tx ReadTxn, idx uint64, node string) ([]strea // getNodeAndNodeChecks returns a the node structure and a function that returns // the full list of checks for a specific service on that node. -func getNodeAndChecks(tx ReadTxn, node string) (*structs.Node, serviceChecksFunc, error) { +func getNodeAndChecks(tx ReadTxn, node string, entMeta *structs.EnterpriseMeta) (*structs.Node, serviceChecksFunc, error) { // Fetch the node - nodeRaw, err := tx.First(tableNodes, indexID, Query{Value: node}) + nodeRaw, err := tx.First(tableNodes, indexID, Query{ + Value: node, + EnterpriseMeta: *entMeta, + }) if err != nil { return nil, nil, err } @@ -532,7 +566,10 @@ func getNodeAndChecks(tx ReadTxn, node string) (*structs.Node, serviceChecksFunc } n := nodeRaw.(*structs.Node) - iter, err := tx.Get(tableChecks, indexNode, Query{Value: node}) + iter, err := tx.Get(tableChecks, indexNode, Query{ + Value: node, + EnterpriseMeta: *entMeta, + }) if err != nil { return nil, nil, err } @@ -566,12 +603,16 @@ func getNodeAndChecks(tx ReadTxn, node string) (*structs.Node, serviceChecksFunc type serviceChecksFunc func(serviceID string) structs.HealthChecks func newServiceHealthEventForService(tx ReadTxn, idx uint64, tuple nodeServiceTuple) (stream.Event, error) { - n, checksFunc, err := getNodeAndChecks(tx, tuple.Node) + n, checksFunc, err := getNodeAndChecks(tx, tuple.Node, &tuple.EntMeta) if err != nil { return stream.Event{}, err } - svc, err := tx.Get(tableServices, indexID, NodeServiceQuery{EnterpriseMeta: tuple.EntMeta, Node: tuple.Node, Service: tuple.ServiceID}) + svc, err := tx.Get(tableServices, indexID, NodeServiceQuery{ + EnterpriseMeta: tuple.EntMeta, + Node: tuple.Node, + Service: tuple.ServiceID, + }) if err != nil { return stream.Event{}, err } @@ -615,9 +656,14 @@ func newServiceHealthEventDeregister(idx uint64, sn *structs.ServiceNode) stream // This is also important because if the service was deleted as part of a // whole node deregistering then the node record won't actually exist now // anyway and we'd have to plumb it through from the changeset above. + + entMeta := sn.EnterpriseMeta + entMeta.Normalize() + csn := &structs.CheckServiceNode{ Node: &structs.Node{ - Node: sn.Node, + Node: sn.Node, + Partition: entMeta.PartitionOrEmpty(), }, Service: sn.ToNodeService(), } diff --git a/agent/consul/state/catalog_events_oss.go b/agent/consul/state/catalog_events_oss.go new file mode 100644 index 0000000000..cf5231dc9a --- /dev/null +++ b/agent/consul/state/catalog_events_oss.go @@ -0,0 +1,23 @@ +// +build !consulent + +package state + +import "github.com/hashicorp/consul/agent/structs" + +func (nst nodeServiceTuple) nodeTuple() nodeTuple { + return nodeTuple{Node: nst.Node, Partition: ""} +} + +func newNodeTupleFromNode(node *structs.Node) nodeTuple { + return nodeTuple{ + Node: node.Node, + Partition: "", + } +} + +func newNodeTupleFromHealthCheck(hc *structs.HealthCheck) nodeTuple { + return nodeTuple{ + Node: hc.Node, + Partition: "", + } +} diff --git a/agent/consul/state/catalog_events_test.go b/agent/consul/state/catalog_events_test.go index 558f63e427..277dec11c3 100644 --- a/agent/consul/state/catalog_events_test.go +++ b/agent/consul/state/catalog_events_test.go @@ -1605,9 +1605,9 @@ func (tc eventsTestCase) run(t *testing.T) { assertDeepEqual(t, tc.WantEvents, got, cmpPartialOrderEvents, cmpopts.EquateEmpty()) } -func runCase(t *testing.T, name string, fn func(t *testing.T)) { +func runCase(t *testing.T, name string, fn func(t *testing.T)) bool { t.Helper() - t.Run(name, func(t *testing.T) { + return t.Run(name, func(t *testing.T) { t.Helper() t.Log("case:", name) fn(t) @@ -1680,7 +1680,11 @@ var cmpPartialOrderEvents = cmp.Options{ if payload.overrideNamespace != "" { ns = payload.overrideNamespace } - return fmt.Sprintf("%s/%s/%s/%s", e.Topic, csn.Node.Node, ns, name) + ap := csn.Service.EnterpriseMeta.PartitionOrDefault() + if payload.overridePartition != "" { + ap = payload.overridePartition + } + return fmt.Sprintf("%s/%s/%s/%s/%s", e.Topic, ap, csn.Node.Node, ns, name) } return key(i) < key(j) }), @@ -2172,6 +2176,7 @@ func newTestEventServiceHealthRegister(index uint64, nodeNum int, svc string) st Node: node, Address: addr, Datacenter: "dc1", + Partition: structs.NodeEnterpriseMetaInDefaultPartition().PartitionOrEmpty(), RaftIndex: structs.RaftIndex{ CreateIndex: index, ModifyIndex: index, @@ -2238,7 +2243,8 @@ func newTestEventServiceHealthDeregister(index uint64, nodeNum int, svc string) Op: pbsubscribe.CatalogOp_Deregister, Value: &structs.CheckServiceNode{ Node: &structs.Node{ - Node: fmt.Sprintf("node%d", nodeNum), + Node: fmt.Sprintf("node%d", nodeNum), + Partition: structs.NodeEnterpriseMetaInDefaultPartition().PartitionOrEmpty(), }, Service: &structs.NodeService{ ID: svc, @@ -2270,6 +2276,7 @@ func TestEventPayloadCheckServiceNode_FilterByKey(t *testing.T) { payload EventPayloadCheckServiceNode key string namespace string + partition string // TODO(partitions): create test cases for this being set expected bool } @@ -2278,7 +2285,7 @@ func TestEventPayloadCheckServiceNode_FilterByKey(t *testing.T) { t.Skip("cant test namespace matching without namespace support") } - require.Equal(t, tc.expected, tc.payload.MatchesKey(tc.key, tc.namespace)) + require.Equal(t, tc.expected, tc.payload.MatchesKey(tc.key, tc.namespace, tc.partition)) } var testCases = []testCase{ diff --git a/agent/consul/state/catalog_test.go b/agent/consul/state/catalog_test.go index d307ef56ba..c4b09aeae3 100644 --- a/agent/consul/state/catalog_test.go +++ b/agent/consul/state/catalog_test.go @@ -28,6 +28,7 @@ func makeRandomNodeID(t *testing.T) types.NodeID { func TestStateStore_GetNodeID(t *testing.T) { s := testStateStore(t) + _, out, err := s.GetNodeID(types.NodeID("wrongId")) if err == nil || out != nil || !strings.Contains(err.Error(), "node lookup by ID failed, wrong UUID") { t.Fatalf("want an error, nil value, err:=%q ; out:=%q", err.Error(), out) @@ -53,30 +54,53 @@ func TestStateStore_GetNodeID(t *testing.T) { Node: "node1", Address: "1.2.3.4", } - if err := s.EnsureRegistration(1, req); err != nil { - t.Fatalf("err: %s", err) - } + require.NoError(t, s.EnsureRegistration(1, req)) _, out, err = s.GetNodeID(nodeID) - if err != nil { - t.Fatalf("got err %s want nil", err) - } + require.NoError(t, err) if out == nil || out.ID != nodeID { t.Fatalf("out should not be nil and contain nodeId, but was:=%#v", out) } + // Case insensitive lookup should work as well _, out, err = s.GetNodeID(types.NodeID("00a916bC-a357-4a19-b886-59419fceeAAA")) - if err != nil { - t.Fatalf("got err %s want nil", err) - } + require.NoError(t, err) if out == nil || out.ID != nodeID { t.Fatalf("out should not be nil and contain nodeId, but was:=%#v", out) } } +func TestStateStore_GetNode(t *testing.T) { + s := testStateStore(t) + + // initially does not exist + idx, out, err := s.GetNode("node1", nil) + require.NoError(t, err) + require.Nil(t, out) + require.Equal(t, uint64(0), idx) + + // Create it + testRegisterNode(t, s, 1, "node1") + + // now exists + idx, out, err = s.GetNode("node1", nil) + require.NoError(t, err) + require.NotNil(t, out) + require.Equal(t, uint64(1), idx) + require.Equal(t, "node1", out.Node) + + // Case insensitive lookup should work as well + idx, out, err = s.GetNode("NoDe1", nil) + require.NoError(t, err) + require.NotNil(t, out) + require.Equal(t, uint64(1), idx) + require.Equal(t, "node1", out.Node) +} + func TestStateStore_ensureNoNodeWithSimilarNameTxn(t *testing.T) { t.Parallel() s := testStateStore(t) + nodeID := makeRandomNodeID(t) req := &structs.RegisterRequest{ ID: nodeID, @@ -90,9 +114,7 @@ func TestStateStore_ensureNoNodeWithSimilarNameTxn(t *testing.T) { Status: api.HealthPassing, }, } - if err := s.EnsureRegistration(1, req); err != nil { - t.Fatalf("err: %s", err) - } + require.NoError(t, s.EnsureRegistration(1, req)) req = &structs.RegisterRequest{ ID: types.NodeID(""), Node: "node2", @@ -103,31 +125,29 @@ func TestStateStore_ensureNoNodeWithSimilarNameTxn(t *testing.T) { Status: api.HealthPassing, }, } - if err := s.EnsureRegistration(2, req); err != nil { - t.Fatalf("err: %s", err) - } + require.NoError(t, s.EnsureRegistration(2, req)) + tx := s.db.WriteTxnRestore() defer tx.Abort() + node := &structs.Node{ ID: makeRandomNodeID(t), Node: "NOdE1", // Name is similar but case is different Address: "2.3.4.5", } + // Lets conflict with node1 (has an ID) - if err := ensureNoNodeWithSimilarNameTxn(tx, node, false); err == nil { - t.Fatalf("Should return an error since another name with similar name exists") - } - if err := ensureNoNodeWithSimilarNameTxn(tx, node, true); err == nil { - t.Fatalf("Should return an error since another name with similar name exists") - } + require.Error(t, ensureNoNodeWithSimilarNameTxn(tx, node, false), + "Should return an error since another name with similar name exists") + require.Error(t, ensureNoNodeWithSimilarNameTxn(tx, node, true), + "Should return an error since another name with similar name exists") + // Lets conflict with node without ID node.Node = "NoDe2" - if err := ensureNoNodeWithSimilarNameTxn(tx, node, false); err == nil { - t.Fatalf("Should return an error since another name with similar name exists") - } - if err := ensureNoNodeWithSimilarNameTxn(tx, node, true); err != nil { - t.Fatalf("Should not clash with another similar node name without ID, err:=%q", err) - } + require.Error(t, ensureNoNodeWithSimilarNameTxn(tx, node, false), + "Should return an error since another name with similar name exists") + require.NoError(t, ensureNoNodeWithSimilarNameTxn(tx, node, true), + "Should not clash with another similar node name without ID") // Set node1's Serf health to failing and replace it. newNode := &structs.Node{ @@ -135,17 +155,15 @@ func TestStateStore_ensureNoNodeWithSimilarNameTxn(t *testing.T) { Node: "node1", Address: "2.3.4.5", } - if err := ensureNoNodeWithSimilarNameTxn(tx, newNode, false); err == nil { - t.Fatalf("Should return an error since the previous node is still healthy") - } - s.ensureCheckTxn(tx, 5, false, &structs.HealthCheck{ + require.Error(t, ensureNoNodeWithSimilarNameTxn(tx, newNode, false), + "Should return an error since the previous node is still healthy") + + require.NoError(t, s.ensureCheckTxn(tx, 5, false, &structs.HealthCheck{ Node: "node1", CheckID: structs.SerfCheckID, Status: api.HealthCritical, - }) - if err := ensureNoNodeWithSimilarNameTxn(tx, newNode, false); err != nil { - t.Fatal(err) - } + })) + require.NoError(t, ensureNoNodeWithSimilarNameTxn(tx, newNode, false)) } func TestStateStore_EnsureRegistration(t *testing.T) { diff --git a/agent/consul/state/state_store_test.go b/agent/consul/state/state_store_test.go index 8e7da5bcab..68e2e08fea 100644 --- a/agent/consul/state/state_store_test.go +++ b/agent/consul/state/state_store_test.go @@ -94,7 +94,10 @@ func testRegisterNodeOpts(t *testing.T, s *Store, idx uint64, nodeID string, opt tx := s.db.Txn(false) defer tx.Abort() - n, err := tx.First(tableNodes, indexID, Query{Value: nodeID}) + n, err := tx.First(tableNodes, indexID, Query{ + Value: nodeID, + EnterpriseMeta: *node.GetEnterpriseMeta(), + }) if err != nil { t.Fatalf("err: %s", err) } diff --git a/agent/consul/state/store_integration_test.go b/agent/consul/state/store_integration_test.go index e31c4158f0..dc6cee8690 100644 --- a/agent/consul/state/store_integration_test.go +++ b/agent/consul/state/store_integration_test.go @@ -422,7 +422,19 @@ type nodePayload struct { node *structs.ServiceNode } -func (p nodePayload) MatchesKey(key, _ string) bool { +func (p nodePayload) MatchesKey(key, _, partition string) bool { + if key == "" && partition == "" { + return true + } + + if p.node == nil { + return false + } + + if structs.PartitionOrDefault(partition) != p.node.PartitionOrDefault() { + return false + } + return p.key == key } diff --git a/agent/consul/stream/event.go b/agent/consul/stream/event.go index 74df46b5e1..285710543e 100644 --- a/agent/consul/stream/event.go +++ b/agent/consul/stream/event.go @@ -26,12 +26,13 @@ type Event struct { // should not modify the state of the payload if the Event is being submitted to // EventPublisher.Publish. type Payload interface { - // MatchesKey must return true if the Payload should be included in a subscription - // requested with the key and namespace. - // Generally this means that the payload matches the key and namespace or - // the payload is a special framing event that should be returned to every - // subscription. - MatchesKey(key, namespace string) bool + // MatchesKey must return true if the Payload should be included in a + // subscription requested with the key, namespace, and partition. + // + // Generally this means that the payload matches the key, namespace, and + // partition or the payload is a special framing event that should be + // returned to every subscription. + MatchesKey(key, namespace, partition string) bool // HasReadPermission uses the acl.Authorizer to determine if the items in the // Payload are visible to the request. It returns true if the payload is @@ -80,10 +81,11 @@ func (p *PayloadEvents) filter(f func(Event) bool) bool { return true } -// MatchesKey filters the PayloadEvents to those which match the key and namespace. -func (p *PayloadEvents) MatchesKey(key, namespace string) bool { +// MatchesKey filters the PayloadEvents to those which match the key, +// namespace, and partition. +func (p *PayloadEvents) MatchesKey(key, namespace, partition string) bool { return p.filter(func(event Event) bool { - return event.Payload.MatchesKey(key, namespace) + return event.Payload.MatchesKey(key, namespace, partition) }) } @@ -115,7 +117,7 @@ func (e Event) IsNewSnapshotToFollow() bool { type framingEvent struct{} -func (framingEvent) MatchesKey(string, string) bool { +func (framingEvent) MatchesKey(string, string, string) bool { return true } @@ -135,7 +137,7 @@ type closeSubscriptionPayload struct { tokensSecretIDs []string } -func (closeSubscriptionPayload) MatchesKey(string, string) bool { +func (closeSubscriptionPayload) MatchesKey(string, string, string) bool { return false } diff --git a/agent/consul/stream/event_publisher.go b/agent/consul/stream/event_publisher.go index bfa3858b96..163fa81094 100644 --- a/agent/consul/stream/event_publisher.go +++ b/agent/consul/stream/event_publisher.go @@ -291,5 +291,5 @@ func (e *EventPublisher) setCachedSnapshotLocked(req *SubscribeRequest, snap *ev } func snapCacheKey(req *SubscribeRequest) string { - return fmt.Sprintf(req.Namespace + "/" + req.Key) + return req.Partition + "/" + req.Namespace + "/" + req.Key } diff --git a/agent/consul/stream/event_publisher_test.go b/agent/consul/stream/event_publisher_test.go index 2967ef8d3d..af7fc3c288 100644 --- a/agent/consul/stream/event_publisher_test.go +++ b/agent/consul/stream/event_publisher_test.go @@ -70,7 +70,7 @@ type simplePayload struct { noReadPerm bool } -func (p simplePayload) MatchesKey(key, _ string) bool { +func (p simplePayload) MatchesKey(key, _, _ string) bool { if key == "" { return true } diff --git a/agent/consul/stream/event_test.go b/agent/consul/stream/event_test.go index 8b36ee8d15..a3187017ca 100644 --- a/agent/consul/stream/event_test.go +++ b/agent/consul/stream/event_test.go @@ -35,7 +35,7 @@ func TestPayloadEvents_FilterByKey(t *testing.T) { events = append(events, tc.events...) pe := &PayloadEvents{Items: events} - ok := pe.MatchesKey(tc.req.Key, tc.req.Namespace) + ok := pe.MatchesKey(tc.req.Key, tc.req.Namespace, tc.req.Partition) require.Equal(t, tc.expectEvent, ok) if !tc.expectEvent { return @@ -133,6 +133,7 @@ func TestPayloadEvents_FilterByKey(t *testing.T) { } } +// TODO(partitions) func newNSEvent(key, namespace string) Event { return Event{Index: 22, Payload: nsPayload{key: key, namespace: namespace}} } @@ -141,11 +142,14 @@ type nsPayload struct { framingEvent key string namespace string + partition string value string } -func (p nsPayload) MatchesKey(key, namespace string) bool { - return (key == "" || key == p.key) && (namespace == "" || namespace == p.namespace) +func (p nsPayload) MatchesKey(key, namespace, partition string) bool { + return (key == "" || key == p.key) && + (namespace == "" || namespace == p.namespace) && + (partition == "" || partition == p.partition) } func TestPayloadEvents_HasReadPermission(t *testing.T) { diff --git a/agent/consul/stream/subscription.go b/agent/consul/stream/subscription.go index 03069ea931..9f47cd2ee3 100644 --- a/agent/consul/stream/subscription.go +++ b/agent/consul/stream/subscription.go @@ -62,6 +62,9 @@ type SubscribeRequest struct { // Namespace used to filter events in the topic. Only events matching the // namespace will be returned by the subscription. Namespace string + // Partition used to filter events in the topic. Only events matching the + // partition will be returned by the subscription. + Partition string // TODO(partitions): make this work // Token that was used to authenticate the request. If any ACL policy // changes impact the token the subscription will be forcefully closed. Token string @@ -102,7 +105,7 @@ func (s *Subscription) Next(ctx context.Context) (Event, error) { continue } event := newEventFromBatch(s.req, next.Events) - if !event.Payload.MatchesKey(s.req.Key, s.req.Namespace) { + if !event.Payload.MatchesKey(s.req.Key, s.req.Namespace, s.req.Partition) { continue } return event, nil