From 1d0f3c4853f8e125c85ee67d58965b59d0246ff9 Mon Sep 17 00:00:00 2001 From: Paul Banks Date: Wed, 25 Mar 2020 16:59:25 +0000 Subject: [PATCH] Server gRPC Subscribe endpoint Includes ACL filter work, and some code-gen changes in agentpb to ensure that we can't accidentally decide to forget to add ACL rules for new types. --- agent/agentpb/common_oss.pb.go | 17 +- agent/agentpb/common_oss.proto | 1 - agent/agentpb/event.go | 79 + agent/agentpb/event_test.go | 114 ++ agent/agentpb/event_types.structgen.go | 23 + agent/agentpb/structgen/structgen.go | 77 +- agent/agentpb/testing_events.go | 236 +++ agent/consul/acl.go | 41 + agent/consul/catalog_endpoint_test.go | 6 +- agent/consul/grpc_service.go | 28 + agent/consul/server.go | 2 +- agent/consul/state/catalog_events_test.go | 114 +- agent/consul/subscribe_grpc_endpoint.go | 191 +++ agent/consul/subscribe_grpc_endpoint_test.go | 1455 +++++++++++++++++ ...test_endpoint.go => test_grpc_endpoint.go} | 0 logging/names.go | 1 + 16 files changed, 2257 insertions(+), 128 deletions(-) create mode 100644 agent/agentpb/event.go create mode 100644 agent/agentpb/event_test.go create mode 100644 agent/agentpb/event_types.structgen.go create mode 100644 agent/agentpb/testing_events.go create mode 100644 agent/consul/grpc_service.go create mode 100644 agent/consul/subscribe_grpc_endpoint.go create mode 100644 agent/consul/subscribe_grpc_endpoint_test.go rename agent/consul/{test_endpoint.go => test_grpc_endpoint.go} (100%) diff --git a/agent/agentpb/common_oss.pb.go b/agent/agentpb/common_oss.pb.go index 888ad06b76..322d40f913 100644 --- a/agent/agentpb/common_oss.pb.go +++ b/agent/agentpb/common_oss.pb.go @@ -6,7 +6,6 @@ package agentpb import ( fmt "fmt" _ "github.com/gogo/protobuf/gogoproto" - _ "github.com/gogo/protobuf/types" proto "github.com/golang/protobuf/proto" io "io" math "math" @@ -66,17 +65,15 @@ func init() { func init() { proto.RegisterFile("common_oss.proto", fileDescriptor_bcf35e841fcc50ea) } var fileDescriptor_bcf35e841fcc50ea = []byte{ - // 147 bytes of a gzipped FileDescriptorProto + // 123 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x48, 0xce, 0xcf, 0xcd, 0xcd, 0xcf, 0x8b, 0xcf, 0x2f, 0x2e, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x4f, 0x4c, - 0x4f, 0xcd, 0x2b, 0x29, 0x48, 0x92, 0x92, 0x4b, 0xcf, 0xcf, 0x4f, 0xcf, 0x49, 0xd5, 0x07, 0x0b, - 0x27, 0x95, 0xa6, 0xe9, 0xa7, 0x94, 0x16, 0x25, 0x96, 0x64, 0xe6, 0xe7, 0x41, 0x14, 0x4a, 0x89, - 0xa4, 0xe7, 0xa7, 0xe7, 0x83, 0x99, 0xfa, 0x20, 0x16, 0x44, 0x54, 0x49, 0x80, 0x8b, 0xcf, 0x35, - 0xaf, 0x24, 0xb5, 0xa8, 0xa0, 0x28, 0xb3, 0x38, 0xd5, 0x37, 0xb5, 0x24, 0xd1, 0x49, 0xe1, 0xc4, - 0x43, 0x39, 0x86, 0x13, 0x8f, 0xe4, 0x18, 0x2f, 0x3c, 0x92, 0x63, 0x7c, 0xf0, 0x48, 0x8e, 0x71, - 0xc2, 0x63, 0x39, 0x86, 0x19, 0x8f, 0xe5, 0x18, 0x2e, 0x3c, 0x96, 0x63, 0xb8, 0xf1, 0x58, 0x8e, - 0x21, 0x89, 0x0d, 0xac, 0xd5, 0x18, 0x10, 0x00, 0x00, 0xff, 0xff, 0xbf, 0xbf, 0x18, 0x0f, 0x8d, - 0x00, 0x00, 0x00, + 0x4f, 0xcd, 0x2b, 0x29, 0x48, 0x92, 0x12, 0x49, 0xcf, 0x4f, 0xcf, 0x07, 0x8b, 0xe9, 0x83, 0x58, + 0x10, 0x69, 0x25, 0x01, 0x2e, 0x3e, 0xd7, 0xbc, 0x92, 0xd4, 0xa2, 0x82, 0xa2, 0xcc, 0xe2, 0x54, + 0xdf, 0xd4, 0x92, 0x44, 0x27, 0x85, 0x13, 0x0f, 0xe5, 0x18, 0x4e, 0x3c, 0x92, 0x63, 0xbc, 0xf0, + 0x48, 0x8e, 0xf1, 0xc1, 0x23, 0x39, 0xc6, 0x09, 0x8f, 0xe5, 0x18, 0x66, 0x3c, 0x96, 0x63, 0xb8, + 0xf0, 0x58, 0x8e, 0xe1, 0xc6, 0x63, 0x39, 0x86, 0x24, 0x36, 0xb0, 0x56, 0x63, 0x40, 0x00, 0x00, + 0x00, 0xff, 0xff, 0x08, 0x05, 0x8c, 0xc6, 
0x6d, 0x00, 0x00, 0x00, } func (m *EnterpriseMeta) Marshal() (dAtA []byte, err error) { diff --git a/agent/agentpb/common_oss.proto b/agent/agentpb/common_oss.proto index 9e04498eb7..533faa685a 100644 --- a/agent/agentpb/common_oss.proto +++ b/agent/agentpb/common_oss.proto @@ -2,7 +2,6 @@ syntax = "proto3"; package agentpb; -import "google/protobuf/duration.proto"; // Go Modules now includes the version in the filepath for packages within GOPATH/pkg/mode // Therefore unless we want to hardcode a version here like // github.com/gogo/protobuf@v1.3.0/gogoproto/gogo.proto then the only other choice is to diff --git a/agent/agentpb/event.go b/agent/agentpb/event.go new file mode 100644 index 0000000000..dde50a8fd7 --- /dev/null +++ b/agent/agentpb/event.go @@ -0,0 +1,79 @@ +package agentpb + +import ( + fmt "fmt" + + "github.com/hashicorp/consul/acl" +) + +// EnforceACL takes an acl.Authorizer and returns the decision for whether the +// event is allowed to be sent to this client or not. +func (e *Event) EnforceACL(authz acl.Authorizer) acl.EnforcementDecision { + switch v := e.Payload.(type) { + // For now these ACL types are just used internally so we don't enforce anything for + // them. To play it safe just always deny until we expose them properly. + case *Event_ACLPolicy: + return acl.Deny + case *Event_ACLRole: + return acl.Deny + case *Event_ACLToken: + return acl.Deny + + // These are protocol messages that are always OK for the subscriber to see as + // they don't expose any information from the data model. + case *Event_ResetStream: + return acl.Allow + case *Event_ResumeStream: + return acl.Allow + case *Event_EndOfSnapshot: + return acl.Allow + // EventBatch is a special case of the above. While it does contain other + // events that might need filtering, we only use it in the transport of other + // events _after_ they've been filtered currently so we don't need to make it + // recursively return all the nested event requirements here. + case *Event_EventBatch: + return acl.Allow + + // Actual Stream events + case *Event_ServiceHealth: + // If it's not populated it's likely a bug so don't send it (or panic on + // nils). This might catch us out if we ever send partial messages but + // hopefully test will show that up early. + if v.ServiceHealth == nil || v.ServiceHealth.CheckServiceNode == nil { + return acl.Deny + } + csn := v.ServiceHealth.CheckServiceNode + + if csn.Node == nil || csn.Service == nil || + csn.Node.Node == "" || csn.Service.Service == "" { + return acl.Deny + } + + if dec := authz.NodeRead(csn.Node.Node, nil); dec != acl.Allow { + return acl.Deny + } + + // TODO(banks): need to actually populate the AuthorizerContext once we add + // Enterprise support for streaming events - they don't have enough data to + // populate it yet. + if dec := authz.ServiceRead(csn.Service.Service, nil); dec != acl.Allow { + return acl.Deny + } + return acl.Allow + + default: + panic(fmt.Sprintf("Event payload type has no ACL requirements defined: %#v", + e.Payload)) + } +} + +// EventBatchEventsFromEventSlice is a helper to convert a slice of event +// objects as used internally in Consul to a slice of pointer's to the same +// events which the generated EventBatch code needs. 
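+// Note that the returned pointers alias the backing array of the input slice,
+// so callers must not mutate or reuse that slice while the EventBatch built
+// from it is still in use.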
+func EventBatchEventsFromEventSlice(events []Event) []*Event { + ret := make([]*Event, len(events)) + for i := range events { + ret[i] = &events[i] + } + return ret +} diff --git a/agent/agentpb/event_test.go b/agent/agentpb/event_test.go new file mode 100644 index 0000000000..225e58bcbf --- /dev/null +++ b/agent/agentpb/event_test.go @@ -0,0 +1,114 @@ +package agentpb + +import ( + "testing" + + "github.com/hashicorp/consul/acl" + "github.com/stretchr/testify/require" +) + +func TestEventEnforceACL(t *testing.T) { + cases := []struct { + Name string + Event Event + ACLRules string + Want acl.EnforcementDecision + }{ + { + Name: "service health reg, blanket allow", + Event: TestEventServiceHealthRegister(t, 1, "web"), + ACLRules: `service_prefix "" { + policy = "read" + } + node_prefix "" { + policy = "read" + }`, + Want: acl.Allow, + }, + { + Name: "service health reg, deny node", + Event: TestEventServiceHealthRegister(t, 1, "web"), + ACLRules: `service_prefix "" { + policy = "read" + }`, + Want: acl.Deny, + }, + { + Name: "service health reg, deny service", + Event: TestEventServiceHealthRegister(t, 1, "web"), + ACLRules: `node_prefix "" { + policy = "read" + }`, + Want: acl.Deny, + }, + + { + Name: "internal ACL token updates denied", + Event: TestEventACLTokenUpdate(t), + ACLRules: `acl = "write"`, + Want: acl.Deny, + }, + { + Name: "internal ACL policy updates denied", + Event: TestEventACLPolicyUpdate(t), + ACLRules: `acl = "write"`, + Want: acl.Deny, + }, + { + Name: "internal ACL role updates denied", + Event: TestEventACLRoleUpdate(t), + ACLRules: `acl = "write"`, + Want: acl.Deny, + }, + + { + Name: "internal EoS allowed", + Event: TestEventEndOfSnapshot(t, Topic_ServiceHealth, 100), + ACLRules: ``, // No access to anything + Want: acl.Allow, + }, + { + Name: "internal Resume allowed", + Event: TestEventResumeStream(t, Topic_ServiceHealth, 100), + ACLRules: ``, // No access to anything + Want: acl.Allow, + }, + { + Name: "internal Reset allowed", + Event: TestEventResetStream(t, Topic_ServiceHealth, 100), + ACLRules: ``, // No access to anything + Want: acl.Allow, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.Name, func(t *testing.T) { + // Create an acl authorizer from the policy + policy, err := acl.NewPolicyFromSource("", 0, tc.ACLRules, acl.SyntaxCurrent, nil, nil) + require.NoError(t, err) + + authz, err := acl.NewPolicyAuthorizerWithDefaults(acl.RootAuthorizer("deny"), + []*acl.Policy{policy}, nil) + require.NoError(t, err) + + got := tc.Event.EnforceACL(authz) + require.Equal(t, tc.Want, got) + }) + } +} + +func TestEventEnforceACLCoversAllTypes(t *testing.T) { + authz := acl.RootAuthorizer("deny") + for _, payload := range allEventTypes { + e := Event{ + Topic: Topic_ServiceHealth, // Just pick any topic for now. + Index: 1234, + Payload: payload, + } + + // We don't actually care about the return type here - that's handled above, + // just that it doesn't panic because of a undefined event type. + e.EnforceACL(authz) + } +} diff --git a/agent/agentpb/event_types.structgen.go b/agent/agentpb/event_types.structgen.go new file mode 100644 index 0000000000..7401f6315c --- /dev/null +++ b/agent/agentpb/event_types.structgen.go @@ -0,0 +1,23 @@ +// Code generated by agentpb/structgen. DO NOT EDIT. + +package agentpb + +// allEventTypes is used internally in tests or places we need an exhaustive +// list of Event Payload types. We use this in tests to ensure that we don't +// miss defining something for a new test type when adding new ones. 
If we ever +// need to machine-genereate a human-readable list of event type strings for +// something we could easily do that here too. +var allEventTypes []isEvent_Payload + +func init() { + allEventTypes = []isEvent_Payload{ + &Event_ACLPolicy{}, + &Event_ACLRole{}, + &Event_ACLToken{}, + &Event_EndOfSnapshot{}, + &Event_EventBatch{}, + &Event_ResetStream{}, + &Event_ResumeStream{}, + &Event_ServiceHealth{}, + } +} diff --git a/agent/agentpb/structgen/structgen.go b/agent/agentpb/structgen/structgen.go index fec06244e1..863072bb22 100644 --- a/agent/agentpb/structgen/structgen.go +++ b/agent/agentpb/structgen/structgen.go @@ -2,6 +2,8 @@ package main import ( "bytes" + "fmt" + "go/ast" "go/format" "go/types" "io" @@ -39,7 +41,7 @@ var ( ) func main() { - protoStructs, err := findProtoGeneratedStructs() + protoStructs, eventStructs, err := findProtoGeneratedStructs() if err != nil { log.Fatalf("failed to find proto generated structs: %s", err) } @@ -78,17 +80,29 @@ func main() { log.Fatalf("failed to write generate tests for %s header: %s", desc.Name, err) } } + //fmt.Println(convertBuf.String()) - // Dump the file somewhere + // Dump the files somewhere err = writeToFile("./agent/agentpb/structs.structgen.go", convertBuf.Bytes()) if err != nil { log.Fatalf("Failed to write output file: %s", err) } err = writeToFile("./agent/agentpb/structs.structgen_test.go", testBuf.Bytes()) if err != nil { - log.Fatalf("Failed to write output file: %s", err) + log.Fatalf("Failed to write test file: %s", err) } + // Build simple file with all defined event types in an array so we can + // write exhaustive test checks over event types. + var eventTypesBuf bytes.Buffer + err = evTypesTpl.Execute(&eventTypesBuf, eventStructs) + if err != nil { + log.Fatalf("Failed to generate event types list: %s", err) + } + err = writeToFile("./agent/agentpb/event_types.structgen.go", eventTypesBuf.Bytes()) + if err != nil { + log.Fatalf("Failed to write event types file: %s", err) + } } func writeToFile(name string, code []byte) error { @@ -132,14 +146,15 @@ func (l structsList) Len() int { return len(l) } func (l structsList) Swap(i, j int) { l[i], l[j] = l[j], l[i] } func (l structsList) Less(i, j int) bool { return l[i].Name < l[j].Name } -func findProtoGeneratedStructs() (structsList, error) { +func findProtoGeneratedStructs() (structsList, structsList, error) { cfg := &packages.Config{Mode: packages.NeedTypes | packages.NeedTypesInfo} pkgs, err := packages.Load(cfg, "github.com/hashicorp/consul/agent/agentpb") if err != nil { - return nil, err + return nil, nil, err } pkg := pkgs[0] ss := make(structsList, 0) + evs := make(structsList, 0) for ident, obj := range pkg.TypesInfo.Defs { // See where this type was defined @@ -153,9 +168,11 @@ func findProtoGeneratedStructs() (structsList, error) { continue } - // Only consider types defined in the structs protobuf mirror file + // Only consider types defined in the structs protobuf mirror file, or the + // stream events. p := pkg.Fset.Position(obj.Pos()) - if !fileMirrorsStructs(filepath.Base(p.Filename)) { + fName := filepath.Base(p.Filename) + if !fileMirrorsStructs(fName) && fName != "subscribe.pb.go" { continue } @@ -169,14 +186,26 @@ func findProtoGeneratedStructs() (structsList, error) { continue } + // Append to list of mirrored structs, unless this is subscribe.pb.go where + // we just need the Event payload types. 
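+	// The oneof wrapper types generated for Event.Payload all share the Event_
+	// prefix (Event_ServiceHealth, Event_EventBatch, ...), which is how we pick
+	// them out below.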
+ collect := func(fName string, id *ast.Ident, t *types.Struct) { + if fName == "subscribe.pb.go" { + if strings.HasPrefix(id.Name, "Event_") { + evs = append(evs, structDesc{id.Name, nil}) + } + } else { + ss = append(ss, structDesc{id.Name, t}) + } + } + // See if it's a struct type switch tt := obj.Type().(type) { case *types.Struct: - ss = append(ss, structDesc{ident.Name, tt}) + collect(fName, ident, tt) case *types.Named: switch st := tt.Underlying().(type) { case *types.Struct: - ss = append(ss, structDesc{ident.Name, st}) + collect(fName, ident, st) default: continue } @@ -187,8 +216,9 @@ func findProtoGeneratedStructs() (structsList, error) { // Sort them to keep the generated file deterministic sort.Sort(ss) + sort.Sort(evs) - return ss, nil + return ss, evs, nil } func shouldIgnoreType(name string) bool { @@ -348,6 +378,13 @@ func genConvert(w io.Writer, name string, s, structsType *types.Struct) error { structsTI := analyzeFieldType(structsType.Field(i)) ti.StructsTypeInfo = &structsTI + if strings.HasSuffix(ti.Type, "invalid type") { + return fmt.Errorf("protobuf field %s.%s has invalid type", name, ti.Name) + } + if strings.HasSuffix(structsTI.Type, "invalid type") { + return fmt.Errorf("structs field %s.%s has invalid type", name, structsTI.Name) + } + buf.Reset() err := toStructsTpl.ExecuteTemplate(&buf, ti.Template, ti) if err != nil { @@ -783,3 +820,23 @@ func fieldTypeInfoForType(t types.Type) fieldTypeInfo { } return ti } + +var evTypesTpl = template.Must(template.New("test").Parse(`// Code generated by agentpb/structgen. DO NOT EDIT. + +package agentpb + +// allEventTypes is used internally in tests or places we need an exhaustive +// list of Event Payload types. We use this in tests to ensure that we don't +// miss defining something for a new test type when adding new ones. If we ever +// need to machine-genereate a human-readable list of event type strings for +// something we could easily do that here too. +var allEventTypes []isEvent_Payload + +func init() { + allEventTypes = []isEvent_Payload{ + {{ range . -}} + &{{ .Name }}{}, + {{ end }} + } +} +`)) diff --git a/agent/agentpb/testing_events.go b/agent/agentpb/testing_events.go new file mode 100644 index 0000000000..ea7b25af06 --- /dev/null +++ b/agent/agentpb/testing_events.go @@ -0,0 +1,236 @@ +package agentpb + +import ( + fmt "fmt" + + "github.com/hashicorp/consul/types" + "github.com/mitchellh/go-testing-interface" +) + +// TestEventEndOfSnapshot returns a valid EndOfSnapshot event on the given topic +// and index. +func TestEventEndOfSnapshot(t testing.T, topic Topic, index uint64) Event { + return Event{ + Topic: topic, + Index: index, + Payload: &Event_EndOfSnapshot{ + EndOfSnapshot: true, + }, + } +} + +// TestEventResetStream returns a valid ResetStream event on the given topic +// and index. +func TestEventResetStream(t testing.T, topic Topic, index uint64) Event { + return Event{ + Topic: topic, + Index: index, + Payload: &Event_ResetStream{ + ResetStream: true, + }, + } +} + +// TestEventResumeStream returns a valid ResumeStream event on the given topic +// and index. +func TestEventResumeStream(t testing.T, topic Topic, index uint64) Event { + return Event{ + Topic: topic, + Index: index, + Payload: &Event_ResumeStream{ + ResumeStream: true, + }, + } +} + +// TestEventBatch returns a valid EventBatch event it assumes service health +// topic, an index of 100 and contains two health registrations. 
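+// It mirrors how the Subscribe endpoint wraps multiple events that share a
+// raft index into a single EventBatch frame (see subscribe_grpc_endpoint.go).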
+func TestEventBatch(t testing.T) Event { + e1 := TestEventServiceHealthRegister(t, 1, "web") + e2 := TestEventServiceHealthRegister(t, 1, "api") + return Event{ + Topic: Topic_ServiceHealth, + Index: 100, + Payload: &Event_EventBatch{ + EventBatch: &EventBatch{ + Events: []*Event{&e1, &e2}, + }, + }, + } +} + +// TestEventACLTokenUpdate returns a valid ACLToken event. +func TestEventACLTokenUpdate(t testing.T) Event { + return Event{ + Topic: Topic_ACLTokens, + Index: 100, + Payload: &Event_ACLToken{ + ACLToken: &ACLTokenUpdate{ + Op: ACLOp_Update, + Token: &ACLTokenIdentifier{ + AccessorID: "adfa4d37-560f-4824-a121-356064a7a2ea", + SecretID: "f58b28f9-42a4-48b2-a08c-eba8ff6560f1", + }, + }, + }, + } +} + +// TestEventACLPolicyUpdate returns a valid ACLPolicy event. +func TestEventACLPolicyUpdate(t testing.T) Event { + return Event{ + Topic: Topic_ACLPolicies, + Index: 100, + Payload: &Event_ACLPolicy{ + ACLPolicy: &ACLPolicyUpdate{ + Op: ACLOp_Update, + PolicyID: "f1df7f3e-6732-45e8-9a3d-ada2a22fa336", + }, + }, + } +} + +// TestEventACLRoleUpdate returns a valid ACLRole event. +func TestEventACLRoleUpdate(t testing.T) Event { + return Event{ + Topic: Topic_ACLRoles, + Index: 100, + Payload: &Event_ACLRole{ + ACLRole: &ACLRoleUpdate{ + Op: ACLOp_Update, + RoleID: "40fee72a-510f-4de7-8c91-e05d42512b9f", + }, + }, + } +} + +// TestEventServiceHealthRegister returns a realistically populated service +// health registration event for tests in other packages. The nodeNum is a +// logical node and is used to create the node name ("node%d") but also change +// the node ID and IP address to make it a little more realistic for cases that +// need that. nodeNum should be less than 64k to make the IP address look +// realistic. Any other changes can be made on the returned event to avoid +// adding too many options to callers. +func TestEventServiceHealthRegister(t testing.T, nodeNum int, svc string) Event { + + node := fmt.Sprintf("node%d", nodeNum) + nodeID := types.NodeID(fmt.Sprintf("11111111-2222-3333-4444-%012d", nodeNum)) + addr := fmt.Sprintf("10.10.%d.%d", nodeNum/256, nodeNum%256) + + return Event{ + Topic: Topic_ServiceHealth, + Key: svc, + Index: 100, + Payload: &Event_ServiceHealth{ + ServiceHealth: &ServiceHealthUpdate{ + Op: CatalogOp_Register, + CheckServiceNode: &CheckServiceNode{ + Node: &Node{ + ID: nodeID, + Node: node, + Address: addr, + Datacenter: "dc1", + RaftIndex: RaftIndex{ + CreateIndex: 100, + ModifyIndex: 100, + }, + }, + Service: &NodeService{ + ID: svc, + Service: svc, + Port: 8080, + Weights: &Weights{ + Passing: 1, + Warning: 1, + }, + // Empty sadness + Proxy: ConnectProxyConfig{ + MeshGateway: &MeshGatewayConfig{}, + Expose: &ExposeConfig{}, + }, + EnterpriseMeta: &EnterpriseMeta{}, + RaftIndex: RaftIndex{ + CreateIndex: 100, + ModifyIndex: 100, + }, + }, + Checks: []*HealthCheck{ + &HealthCheck{ + Node: node, + CheckID: "serf-health", + Name: "serf-health", + Status: "passing", + EnterpriseMeta: &EnterpriseMeta{}, + RaftIndex: RaftIndex{ + CreateIndex: 100, + ModifyIndex: 100, + }, + }, + &HealthCheck{ + Node: node, + CheckID: types.CheckID("service:" + svc), + Name: "service:" + svc, + ServiceID: svc, + ServiceName: svc, + Type: "ttl", + Status: "passing", + EnterpriseMeta: &EnterpriseMeta{}, + RaftIndex: RaftIndex{ + CreateIndex: 100, + ModifyIndex: 100, + }, + }, + }, + }, + }, + }, + } +} + +// TestEventServiceHealthDeregister returns a realistically populated service +// health deregistration event for tests in other packages. 
The nodeNum is a +// logical node and is used to create the node name ("node%d") but also change +// the node ID and IP address to make it a little more realistic for cases that +// need that. nodeNum should be less than 64k to make the IP address look +// realistic. Any other changes can be made on the returned event to avoid +// adding too many options to callers. +func TestEventServiceHealthDeregister(t testing.T, nodeNum int, svc string) Event { + + node := fmt.Sprintf("node%d", nodeNum) + + return Event{ + Topic: Topic_ServiceHealth, + Key: svc, + Index: 100, + Payload: &Event_ServiceHealth{ + ServiceHealth: &ServiceHealthUpdate{ + Op: CatalogOp_Deregister, + CheckServiceNode: &CheckServiceNode{ + Node: &Node{ + Node: node, + }, + Service: &NodeService{ + ID: svc, + Service: svc, + Port: 8080, + Weights: &Weights{ + Passing: 1, + Warning: 1, + }, + // Empty sadness + Proxy: ConnectProxyConfig{ + MeshGateway: &MeshGatewayConfig{}, + Expose: &ExposeConfig{}, + }, + EnterpriseMeta: &EnterpriseMeta{}, + RaftIndex: RaftIndex{ + // The original insertion index since a delete doesn't update this. + CreateIndex: 10, + ModifyIndex: 10, + }, + }, + }, + }, + }, + } +} diff --git a/agent/consul/acl.go b/agent/consul/acl.go index e1e45056a3..a03dd8665d 100644 --- a/agent/consul/acl.go +++ b/agent/consul/acl.go @@ -8,6 +8,7 @@ import ( metrics "github.com/armon/go-metrics" "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/agentpb" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/logging" "github.com/hashicorp/go-hclog" @@ -1720,6 +1721,43 @@ func (f *aclFilter) filterGatewayServices(mappings *structs.GatewayServices) { *mappings = ret } +func (f *aclFilter) allowStreamEvent(event *agentpb.Event) bool { + // Fast path if ACLs are not enabled + if f.authorizer == nil { + return true + } + return event.EnforceACL(f.authorizer) == acl.Allow +} + +func (f *aclFilter) filterStreamEvents(events *[]agentpb.Event) { + // Fast path for the common case of only 1 event since we can avoid slice + // allocation in the hot path of every single update event delivered in vast + // majority of cases with this. Note that this is called _per event/item_ when + // sending snapshots which is a lot worse than being called once on regular + // result. + if len(*events) == 1 { + if f.allowStreamEvent(&(*events)[0]) { + return + } + // Was denied, truncate the input events stream to remove the single event + *events = (*events)[:0] + return + } + + filtered := make([]agentpb.Event, 0, len(*events)) + + for idx := range *events { + // Get pointer to the actual event. We don't use _, event ranging to save to + // confusion of making a local copy, this is more explicit. 
+ event := &(*events)[idx] + if f.allowStreamEvent(event) { + filtered = append(filtered, *event) + } + } + + *events = filtered +} + func (r *ACLResolver) filterACLWithAuthorizer(authorizer acl.Authorizer, subj interface{}) error { if authorizer == nil { return nil @@ -1808,6 +1846,9 @@ func (r *ACLResolver) filterACLWithAuthorizer(authorizer acl.Authorizer, subj in case *structs.GatewayServices: filt.filterGatewayServices(v) + case *[]agentpb.Event: + filt.filterStreamEvents(v) + default: panic(fmt.Errorf("Unhandled type passed to ACL filter: %T %#v", subj, subj)) } diff --git a/agent/consul/catalog_endpoint_test.go b/agent/consul/catalog_endpoint_test.go index 55e39067f9..d131b70d29 100644 --- a/agent/consul/catalog_endpoint_test.go +++ b/agent/consul/catalog_endpoint_test.go @@ -2558,13 +2558,17 @@ func TestCatalog_Register_FailedCase1(t *testing.T) { } func testACLFilterServer(t *testing.T) (dir, token string, srv *Server, codec rpc.ClientCodec) { - dir, srv = testServerWithConfig(t, func(c *Config) { + return testACLFilterServerWithConfigFn(t, func(c *Config) { c.ACLDatacenter = "dc1" c.ACLsEnabled = true c.ACLMasterToken = "root" c.ACLDefaultPolicy = "deny" c.ACLEnforceVersion8 = false }) +} + +func testACLFilterServerWithConfigFn(t *testing.T, fn func(*Config)) (dir, token string, srv *Server, codec rpc.ClientCodec) { + dir, srv = testServerWithConfig(t, fn) codec = rpcClient(t, srv) testrpc.WaitForLeader(t, srv.RPC, "dc1") diff --git a/agent/consul/grpc_service.go b/agent/consul/grpc_service.go new file mode 100644 index 0000000000..48021f974a --- /dev/null +++ b/agent/consul/grpc_service.go @@ -0,0 +1,28 @@ +package consul + +import "github.com/hashicorp/consul/logging" + +// GRPCService is the implementation of the gRPC Consul service defined in +// agentpb/consul.proto. Each RPC is implemented in a separate *_grpc_endpoint +// files as methods on this object. +type GRPCService struct { + srv *Server + + // gRPC needs each RPC in the service definition attached to a single object + // as a method to implement the interface. We want to use a separate named + // logger for each endpit to match net/rpc usage but also would be nice to be + // able to just use the standard s.logger for calls rather than seperately + // named loggers for each RPC method. So each RPC method is actually defined + // on a separate object with a `logger` field and then those are all ebedded + // here to make this object implement the full interface. 
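+	// For example, the Subscribe RPC is implemented on GRPCSubscribeHandler in
+	// subscribe_grpc_endpoint.go and becomes a method of GRPCService via the
+	// embedding below.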
+ GRPCSubscribeHandler +} + +func NewGRPCService(s *Server) *GRPCService { + return &GRPCService{ + GRPCSubscribeHandler: GRPCSubscribeHandler{ + srv: s, + logger: s.loggers.Named(logging.Subscribe), + }, + } +} diff --git a/agent/consul/server.go b/agent/consul/server.go index c541b853f4..687712523d 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -907,7 +907,7 @@ func (s *Server) setupGRPC() error { grpc.StatsHandler(grpcStatsHandler), grpc.StreamInterceptor(GRPCCountingStreamInterceptor), ) - //stream.RegisterConsulServer(srv, &ConsulGRPCAdapter{Health{s}}) + agentpb.RegisterConsulServer(srv, NewGRPCService(s)) if s.config.GRPCTestServerEnabled { agentpb.RegisterTestServer(srv, &GRPCTest{srv: s}) } diff --git a/agent/consul/state/catalog_events_test.go b/agent/consul/state/catalog_events_test.go index 3706e6afad..32e1d5698b 100644 --- a/agent/consul/state/catalog_events_test.go +++ b/agent/consul/state/catalog_events_test.go @@ -62,74 +62,14 @@ func testServiceRegistration(t *testing.T, svc string, opts ...regOption) *struc } func testServiceHealthEvent(t *testing.T, svc string, opts ...eventOption) agentpb.Event { - e := agentpb.Event{ - Topic: agentpb.Topic_ServiceHealth, - Key: svc, - Index: 100, - Payload: &agentpb.Event_ServiceHealth{ - ServiceHealth: &agentpb.ServiceHealthUpdate{ - Op: agentpb.CatalogOp_Register, - CheckServiceNode: &agentpb.CheckServiceNode{ - Node: &agentpb.Node{ - ID: "11111111-2222-3333-4444-555555555555", - Node: "node1", - Address: "10.10.10.10", - Datacenter: "dc1", - RaftIndex: agentpb.RaftIndex{ - CreateIndex: 100, - ModifyIndex: 100, - }, - }, - Service: &agentpb.NodeService{ - ID: svc, - Service: svc, - Port: 8080, - Weights: &agentpb.Weights{ - Passing: 1, - Warning: 1, - }, - // Empty sadness - Proxy: agentpb.ConnectProxyConfig{ - MeshGateway: &agentpb.MeshGatewayConfig{}, - Expose: &agentpb.ExposeConfig{}, - }, - EnterpriseMeta: &agentpb.EnterpriseMeta{}, - RaftIndex: agentpb.RaftIndex{ - CreateIndex: 100, - ModifyIndex: 100, - }, - }, - Checks: []*agentpb.HealthCheck{ - &agentpb.HealthCheck{ - Node: "node1", - CheckID: "serf-health", - Name: "serf-health", - Status: "passing", - EnterpriseMeta: &agentpb.EnterpriseMeta{}, - RaftIndex: agentpb.RaftIndex{ - CreateIndex: 100, - ModifyIndex: 100, - }, - }, - &agentpb.HealthCheck{ - Node: "node1", - CheckID: types.CheckID("service:" + svc), - Name: "service:" + svc, - ServiceID: svc, - ServiceName: svc, - Type: "ttl", - Status: "passing", - EnterpriseMeta: &agentpb.EnterpriseMeta{}, - RaftIndex: agentpb.RaftIndex{ - CreateIndex: 100, - ModifyIndex: 100, - }, - }, - }, - }, - }, - }, - } + e := agentpb.TestEventServiceHealthRegister(t, 1, svc) + + // Normalize a few things that are different in the generic event which was + // based on original code here but made more general. This means we don't have + // to change all the test loads... 
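+	// TestEventServiceHealthRegister derives the node ID and address from its
+	// nodeNum argument (11111111-2222-3333-4444-000000000001 and 10.10.0.1 for
+	// node 1); the original fixtures here hard-coded the values below, so
+	// restore them to keep the existing assertions passing.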
+ csn := e.GetServiceHealth().CheckServiceNode + csn.Node.ID = "11111111-2222-3333-4444-555555555555" + csn.Node.Address = "10.10.10.10" for _, opt := range opts { err := opt(&e) @@ -139,41 +79,7 @@ func testServiceHealthEvent(t *testing.T, svc string, opts ...eventOption) agent } func testServiceHealthDeregistrationEvent(t *testing.T, svc string, opts ...eventOption) agentpb.Event { - e := agentpb.Event{ - Topic: agentpb.Topic_ServiceHealth, - Key: svc, - Index: 100, - Payload: &agentpb.Event_ServiceHealth{ - ServiceHealth: &agentpb.ServiceHealthUpdate{ - Op: agentpb.CatalogOp_Deregister, - CheckServiceNode: &agentpb.CheckServiceNode{ - Node: &agentpb.Node{ - Node: "node1", - }, - Service: &agentpb.NodeService{ - ID: svc, - Service: svc, - Port: 8080, - Weights: &agentpb.Weights{ - Passing: 1, - Warning: 1, - }, - // Empty sadness - Proxy: agentpb.ConnectProxyConfig{ - MeshGateway: &agentpb.MeshGatewayConfig{}, - Expose: &agentpb.ExposeConfig{}, - }, - EnterpriseMeta: &agentpb.EnterpriseMeta{}, - RaftIndex: agentpb.RaftIndex{ - // The original insertion index since a delete doesn't update this. - CreateIndex: 10, - ModifyIndex: 10, - }, - }, - }, - }, - }, - } + e := agentpb.TestEventServiceHealthDeregister(t, 1, svc) for _, opt := range opts { err := opt(&e) require.NoError(t, err) @@ -1469,8 +1375,6 @@ func requireEventsInCorrectPartialOrder(t *testing.T, want, got []agentpb.Event, gotParts[k] = append(gotParts[k], e) } - //q.Q(wantParts, gotParts) - for k, want := range wantParts { require.Equal(t, want, gotParts[k], "got incorrect events for partition: %s", k) } diff --git a/agent/consul/subscribe_grpc_endpoint.go b/agent/consul/subscribe_grpc_endpoint.go new file mode 100644 index 0000000000..cd90b5c0d6 --- /dev/null +++ b/agent/consul/subscribe_grpc_endpoint.go @@ -0,0 +1,191 @@ +package consul + +import ( + "github.com/hashicorp/consul/agent/agentpb" + "github.com/hashicorp/consul/agent/consul/stream" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-uuid" +) + +// GRPCSubscribeHandler is the type that implements the gRPC Subscribe RPC +// method. It wraps a Subscribe-scoped logger and will be embedded in +// GRPCService to implement the full service. +type GRPCSubscribeHandler struct { + srv *Server + logger hclog.Logger +} + +// Subscribe opens a long-lived gRPC stream which sends an initial snapshot +// of state for the requested topic, then only sends updates. +func (h *GRPCSubscribeHandler) Subscribe(req *agentpb.SubscribeRequest, serverStream agentpb.Consul_SubscribeServer) error { + + // streamID is just used for message correlation in trace logs and not + // populated normally. + var streamID string + var err error + + if h.logger.IsTrace() { + // TODO(banks) it might be nice one day to replace this with OpenTracing ID + // if one is set etc. but probably pointless until we support that properly + // in other places so it's actually propagated properly. For now this just + // makes lifetime of a stream more traceable in our regular server logs for + // debugging/dev. + streamID, err = uuid.GenerateUUID() + if err != nil { + return err + } + } + + // Forward the request to a remote DC if applicable. 
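+	// The local server proxies the remote stream itself (see forwardAndProxy
+	// below) rather than redirecting, so clients only need gRPC connectivity to
+	// servers in their own datacenter.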
+ if req.Datacenter != "" && req.Datacenter != h.srv.config.Datacenter { + return h.forwardAndProxy(req, serverStream, streamID) + } + + h.srv.logger.Trace("new subscription", + "topic", req.Topic.String(), + "key", req.Key, + "index", req.Index, + "stream_id", streamID, + ) + + var sentCount uint64 + defer func() { + h.srv.logger.Trace("subscription closed", + "stream_id", streamID, + ) + }() + + // Resolve the token and create the ACL filter. + // TODO: handle token expiry gracefully... + authz, err := h.srv.ResolveToken(req.Token) + if err != nil { + return err + } + aclFilter := newACLFilter(authz, h.srv.logger, h.srv.config.ACLEnforceVersion8) + + state := h.srv.fsm.State() + + // Register a subscription on this topic/key with the FSM. + sub, err := state.Subscribe(serverStream.Context(), req) + if err != nil { + return err + } + defer state.Unsubscribe(req) + + // Deliver the events + for { + events, err := sub.Next() + if err == stream.ErrSubscriptionReload { + event := agentpb.Event{ + Payload: &agentpb.Event_ResetStream{ResetStream: true}, + } + if err := serverStream.Send(&event); err != nil { + return err + } + h.srv.logger.Trace("subscription reloaded", + "stream_id", streamID, + ) + return nil + } + if err != nil { + return err + } + + aclFilter.filterStreamEvents(&events) + + snapshotDone := false + if len(events) == 1 { + if events[0].GetEndOfSnapshot() { + snapshotDone = true + h.srv.logger.Trace("snapshot complete", + "index", events[0].Index, + "sent", sentCount, + "stream_id", streamID, + ) + } else if events[0].GetResumeStream() { + snapshotDone = true + h.srv.logger.Trace("resuming stream", + "index", events[0].Index, + "sent", sentCount, + "stream_id", streamID, + ) + } else if snapshotDone { + // Count this event too in the normal case as "sent" the above cases + // only show the number of events sent _before_ the snapshot ended. + h.srv.logger.Trace("sending events", + "index", events[0].Index, + "sent", sentCount, + "batch_size", 1, + "stream_id", streamID, + ) + } + sentCount++ + if err := serverStream.Send(&events[0]); err != nil { + return err + } + } else if len(events) > 1 { + e := &agentpb.Event{ + Topic: req.Topic, + Key: req.Key, + Index: events[0].Index, + Payload: &agentpb.Event_EventBatch{ + EventBatch: &agentpb.EventBatch{ + Events: agentpb.EventBatchEventsFromEventSlice(events), + }, + }, + } + sentCount += uint64(len(events)) + h.srv.logger.Trace("sending events", + "index", events[0].Index, + "sent", sentCount, + "batch_size", len(events), + "stream_id", streamID, + ) + if err := serverStream.Send(e); err != nil { + return err + } + } + } +} + +func (h *GRPCSubscribeHandler) forwardAndProxy(req *agentpb.SubscribeRequest, + serverStream agentpb.Consul_SubscribeServer, streamID string) error { + + conn, err := h.srv.grpcClient.GRPCConn(req.Datacenter) + if err != nil { + return err + } + + h.logger.Trace("forwarding to another DC", + "dc", req.Datacenter, + "topic", req.Topic.String(), + "key", req.Key, + "index", req.Index, + "stream_id", streamID, + ) + + defer func() { + h.logger.Trace("forwarded stream closed", + "dc", req.Datacenter, + "stream_id", streamID, + ) + }() + + // Open a Subscribe call to the remote DC. + client := agentpb.NewConsulClient(conn) + streamHandle, err := client.Subscribe(serverStream.Context(), req) + if err != nil { + return err + } + + // Relay the events back to the client. 
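+	// Any receive or send error tears down the relay and ends both streams;
+	// clients are expected to resubscribe, typically passing the last index
+	// they observed.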
+ for { + event, err := streamHandle.Recv() + if err != nil { + return err + } + if err := serverStream.Send(event); err != nil { + return err + } + } +} diff --git a/agent/consul/subscribe_grpc_endpoint_test.go b/agent/consul/subscribe_grpc_endpoint_test.go new file mode 100644 index 0000000000..6b10b2adb9 --- /dev/null +++ b/agent/consul/subscribe_grpc_endpoint_test.go @@ -0,0 +1,1455 @@ +package consul + +import ( + "context" + "encoding/json" + "fmt" + "io" + "os" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/agentpb" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/consul/types" + msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" +) + +// requireEqualProtos is a helper that runs arrays or structures containing +// proto buf messages through JSON encoding before comparing/diffing them. This +// is necessary because require.Equal doesn't compare them equal and generates +// really unhelpful output in this case for some reason. +func requireEqualProtos(t *testing.T, want, got interface{}) { + t.Helper() + gotJSON, err := json.Marshal(got) + require.NoError(t, err) + expectJSON, err := json.Marshal(want) + require.NoError(t, err) + require.JSONEq(t, string(expectJSON), string(gotJSON)) +} + +func TestStreaming_Subscribe(t *testing.T) { + t.Parallel() + + require := require.New(t) + dir1, server := testServerWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.Bootstrap = true + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir1) + defer server.Shutdown() + codec := rpcClient(t, server) + defer codec.Close() + + dir2, client := testClientWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.NodeName = uniqueNodeName(t.Name()) + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir2) + defer client.Shutdown() + + // Try to join + testrpc.WaitForLeader(t, server.RPC, "dc1") + joinLAN(t, client, server) + testrpc.WaitForTestAgent(t, client.RPC, "dc1") + + // Register a dummy node with a service we don't care about, to make sure + // we don't see updates for it. + { + req := &structs.RegisterRequest{ + Node: "other", + Address: "2.3.4.5", + Datacenter: "dc1", + Service: &structs.NodeService{ + ID: "api1", + Service: "api", + Address: "2.3.4.5", + Port: 9000, + }, + } + var out struct{} + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + } + + // Register a dummy node with our service on it. + { + req := &structs.RegisterRequest{ + Node: "node1", + Address: "3.4.5.6", + Datacenter: "dc1", + Service: &structs.NodeService{ + ID: "redis1", + Service: "redis", + Address: "3.4.5.6", + Port: 8080, + }, + } + var out struct{} + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + } + + // Register a test node to be updated later. + req := &structs.RegisterRequest{ + Node: "node2", + Address: "1.2.3.4", + Datacenter: "dc1", + Service: &structs.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + }, + } + var out struct{} + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + + // Start a Subscribe call to our streaming endpoint. 
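+	// We expect a snapshot containing the two matching "redis" registrations,
+	// then an EndOfSnapshot marker, and after that only live updates.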
+ conn, err := client.GRPCConn() + require.NoError(err) + + streamClient := agentpb.NewConsulClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + streamHandle, err := streamClient.Subscribe(ctx, &agentpb.SubscribeRequest{ + Topic: agentpb.Topic_ServiceHealth, + Key: "redis", + }) + require.NoError(err) + + // Start a goroutine to read updates off the agentpb. + eventCh := make(chan *agentpb.Event, 0) + go testSendEvents(t, eventCh, streamHandle) + + var snapshotEvents []*agentpb.Event + for i := 0; i < 3; i++ { + select { + case event := <-eventCh: + snapshotEvents = append(snapshotEvents, event) + case <-time.After(3 * time.Second): + t.Fatalf("did not receive events past %d", len(snapshotEvents)) + } + } + + expected := []*agentpb.Event{ + { + Topic: agentpb.Topic_ServiceHealth, + Key: "redis", + Payload: &agentpb.Event_ServiceHealth{ + ServiceHealth: &agentpb.ServiceHealthUpdate{ + Op: agentpb.CatalogOp_Register, + CheckServiceNode: &agentpb.CheckServiceNode{ + Node: &agentpb.Node{ + Node: "node1", + Datacenter: "dc1", + Address: "3.4.5.6", + }, + Service: &agentpb.NodeService{ + ID: "redis1", + Service: "redis", + Address: "3.4.5.6", + Port: 8080, + Weights: &agentpb.Weights{Passing: 1, Warning: 1}, + // Sad empty state + Proxy: agentpb.ConnectProxyConfig{ + MeshGateway: &agentpb.MeshGatewayConfig{}, + Expose: &agentpb.ExposeConfig{}, + }, + EnterpriseMeta: &agentpb.EnterpriseMeta{}, + }, + }, + }, + }, + }, + { + Topic: agentpb.Topic_ServiceHealth, + Key: "redis", + Payload: &agentpb.Event_ServiceHealth{ + ServiceHealth: &agentpb.ServiceHealthUpdate{ + Op: agentpb.CatalogOp_Register, + CheckServiceNode: &agentpb.CheckServiceNode{ + Node: &agentpb.Node{ + Node: "node2", + Datacenter: "dc1", + Address: "1.2.3.4", + }, + Service: &agentpb.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + Weights: &agentpb.Weights{Passing: 1, Warning: 1}, + // Sad empty state + Proxy: agentpb.ConnectProxyConfig{ + MeshGateway: &agentpb.MeshGatewayConfig{}, + Expose: &agentpb.ExposeConfig{}, + }, + EnterpriseMeta: &agentpb.EnterpriseMeta{}, + }, + }, + }, + }, + }, + { + Topic: agentpb.Topic_ServiceHealth, + Key: "redis", + Payload: &agentpb.Event_EndOfSnapshot{EndOfSnapshot: true}, + }, + } + + require.Len(snapshotEvents, 3) + for i := 0; i < 2; i++ { + // Fix up the index + expected[i].Index = snapshotEvents[i].Index + node := expected[i].GetServiceHealth().CheckServiceNode + node.Node.RaftIndex = snapshotEvents[i].GetServiceHealth().CheckServiceNode.Node.RaftIndex + node.Service.RaftIndex = snapshotEvents[i].GetServiceHealth().CheckServiceNode.Service.RaftIndex + } + // Fix index on snapshot event + expected[2].Index = snapshotEvents[2].Index + + requireEqualProtos(t, expected, snapshotEvents) + + // Update the registration by adding a check. + req.Check = &structs.HealthCheck{ + Node: "node2", + CheckID: types.CheckID("check1"), + ServiceID: "redis1", + ServiceName: "redis", + Name: "check 1", + } + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + + // Make sure we get the event for the diff. 
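+	// The update arrives as a single ServiceHealthUpdate carrying the full
+	// CheckServiceNode (node, service and the new check), not just the change.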
+ select { + case event := <-eventCh: + expected := &agentpb.Event{ + Topic: agentpb.Topic_ServiceHealth, + Key: "redis", + Payload: &agentpb.Event_ServiceHealth{ + ServiceHealth: &agentpb.ServiceHealthUpdate{ + Op: agentpb.CatalogOp_Register, + CheckServiceNode: &agentpb.CheckServiceNode{ + Node: &agentpb.Node{ + Node: "node2", + Datacenter: "dc1", + Address: "1.2.3.4", + RaftIndex: agentpb.RaftIndex{CreateIndex: 13, ModifyIndex: 13}, + }, + Service: &agentpb.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + RaftIndex: agentpb.RaftIndex{CreateIndex: 13, ModifyIndex: 13}, + Weights: &agentpb.Weights{Passing: 1, Warning: 1}, + // Sad empty state + Proxy: agentpb.ConnectProxyConfig{ + MeshGateway: &agentpb.MeshGatewayConfig{}, + Expose: &agentpb.ExposeConfig{}, + }, + EnterpriseMeta: &agentpb.EnterpriseMeta{}, + }, + Checks: []*agentpb.HealthCheck{ + { + CheckID: "check1", + Name: "check 1", + Node: "node2", + Status: "critical", + ServiceID: "redis1", + ServiceName: "redis", + RaftIndex: agentpb.RaftIndex{CreateIndex: 14, ModifyIndex: 14}, + EnterpriseMeta: &agentpb.EnterpriseMeta{}, + }, + }, + }, + }, + }, + } + // Fix up the index + expected.Index = event.Index + node := expected.GetServiceHealth().CheckServiceNode + node.Node.RaftIndex = event.GetServiceHealth().CheckServiceNode.Node.RaftIndex + node.Service.RaftIndex = event.GetServiceHealth().CheckServiceNode.Service.RaftIndex + node.Checks[0].RaftIndex = event.GetServiceHealth().CheckServiceNode.Checks[0].RaftIndex + requireEqualProtos(t, expected, event) + case <-time.After(3 * time.Second): + t.Fatal("never got event") + } + + // Wait and make sure there aren't any more events coming. + select { + case event := <-eventCh: + t.Fatalf("got another event: %v", event) + case <-time.After(500 * time.Millisecond): + } +} + +func TestStreaming_Subscribe_MultiDC(t *testing.T) { + t.Parallel() + + require := require.New(t) + dir1, server1 := testServerWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.Bootstrap = true + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir1) + defer server1.Shutdown() + + dir2, server2 := testServerWithConfig(t, func(c *Config) { + c.Datacenter = "dc2" + c.Bootstrap = true + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir2) + defer server2.Shutdown() + codec := rpcClient(t, server2) + defer codec.Close() + + dir3, client := testClientWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.NodeName = uniqueNodeName(t.Name()) + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir3) + defer client.Shutdown() + + // Join the servers via WAN + joinWAN(t, server2, server1) + testrpc.WaitForLeader(t, server1.RPC, "dc1") + testrpc.WaitForLeader(t, server2.RPC, "dc2") + + joinLAN(t, client, server1) + testrpc.WaitForTestAgent(t, client.RPC, "dc1") + + // Register a dummy node in dc2 with a service we don't care about, + // to make sure we don't see updates for it. + { + req := &structs.RegisterRequest{ + Node: "other", + Address: "2.3.4.5", + Datacenter: "dc2", + Service: &structs.NodeService{ + ID: "api1", + Service: "api", + Address: "2.3.4.5", + Port: 9000, + }, + } + var out struct{} + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + } + + // Register a dummy node with our service on it, again in dc2. 
+ { + req := &structs.RegisterRequest{ + Node: "node1", + Address: "3.4.5.6", + Datacenter: "dc2", + Service: &structs.NodeService{ + ID: "redis1", + Service: "redis", + Address: "3.4.5.6", + Port: 8080, + }, + } + var out struct{} + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + } + + // Register a test node in dc2 to be updated later. + req := &structs.RegisterRequest{ + Node: "node2", + Address: "1.2.3.4", + Datacenter: "dc2", + Service: &structs.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + }, + } + var out struct{} + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + + // Start a cross-DC Subscribe call to our streaming endpoint, specifying dc2. + conn, err := client.GRPCConn() + require.NoError(err) + + streamClient := agentpb.NewConsulClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + streamHandle, err := streamClient.Subscribe(ctx, &agentpb.SubscribeRequest{ + Topic: agentpb.Topic_ServiceHealth, + Key: "redis", + Datacenter: "dc2", + }) + require.NoError(err) + + // Start a goroutine to read updates off the agentpb. + eventCh := make(chan *agentpb.Event, 0) + go testSendEvents(t, eventCh, streamHandle) + + var snapshotEvents []*agentpb.Event + for i := 0; i < 3; i++ { + select { + case event := <-eventCh: + snapshotEvents = append(snapshotEvents, event) + case <-time.After(3 * time.Second): + t.Fatalf("did not receive events past %d", len(snapshotEvents)) + } + } + + expected := []*agentpb.Event{ + { + Topic: agentpb.Topic_ServiceHealth, + Key: "redis", + Payload: &agentpb.Event_ServiceHealth{ + ServiceHealth: &agentpb.ServiceHealthUpdate{ + Op: agentpb.CatalogOp_Register, + CheckServiceNode: &agentpb.CheckServiceNode{ + Node: &agentpb.Node{ + Node: "node1", + Datacenter: "dc2", + Address: "3.4.5.6", + }, + Service: &agentpb.NodeService{ + ID: "redis1", + Service: "redis", + Address: "3.4.5.6", + Port: 8080, + Weights: &agentpb.Weights{Passing: 1, Warning: 1}, + // Sad empty state + Proxy: agentpb.ConnectProxyConfig{ + MeshGateway: &agentpb.MeshGatewayConfig{}, + Expose: &agentpb.ExposeConfig{}, + }, + EnterpriseMeta: &agentpb.EnterpriseMeta{}, + }, + }, + }, + }, + }, + { + Topic: agentpb.Topic_ServiceHealth, + Key: "redis", + Payload: &agentpb.Event_ServiceHealth{ + ServiceHealth: &agentpb.ServiceHealthUpdate{ + Op: agentpb.CatalogOp_Register, + CheckServiceNode: &agentpb.CheckServiceNode{ + Node: &agentpb.Node{ + Node: "node2", + Datacenter: "dc2", + Address: "1.2.3.4", + }, + Service: &agentpb.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + Weights: &agentpb.Weights{Passing: 1, Warning: 1}, + // Sad empty state + Proxy: agentpb.ConnectProxyConfig{ + MeshGateway: &agentpb.MeshGatewayConfig{}, + Expose: &agentpb.ExposeConfig{}, + }, + EnterpriseMeta: &agentpb.EnterpriseMeta{}, + }, + }, + }, + }, + }, + { + Topic: agentpb.Topic_ServiceHealth, + Key: "redis", + Payload: &agentpb.Event_EndOfSnapshot{EndOfSnapshot: true}, + }, + } + + require.Len(snapshotEvents, 3) + for i := 0; i < 2; i++ { + // Fix up the index + expected[i].Index = snapshotEvents[i].Index + node := expected[i].GetServiceHealth().CheckServiceNode + node.Node.RaftIndex = snapshotEvents[i].GetServiceHealth().CheckServiceNode.Node.RaftIndex + node.Service.RaftIndex = snapshotEvents[i].GetServiceHealth().CheckServiceNode.Service.RaftIndex + } + expected[2].Index = snapshotEvents[2].Index + requireEqualProtos(t, 
expected, snapshotEvents) + + // Update the registration by adding a check. + req.Check = &structs.HealthCheck{ + Node: "node2", + CheckID: types.CheckID("check1"), + ServiceID: "redis1", + ServiceName: "redis", + Name: "check 1", + } + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + + // Make sure we get the event for the diff. + select { + case event := <-eventCh: + expected := &agentpb.Event{ + Topic: agentpb.Topic_ServiceHealth, + Key: "redis", + Payload: &agentpb.Event_ServiceHealth{ + ServiceHealth: &agentpb.ServiceHealthUpdate{ + Op: agentpb.CatalogOp_Register, + CheckServiceNode: &agentpb.CheckServiceNode{ + Node: &agentpb.Node{ + Node: "node2", + Datacenter: "dc2", + Address: "1.2.3.4", + RaftIndex: agentpb.RaftIndex{CreateIndex: 13, ModifyIndex: 13}, + }, + Service: &agentpb.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + RaftIndex: agentpb.RaftIndex{CreateIndex: 13, ModifyIndex: 13}, + Weights: &agentpb.Weights{Passing: 1, Warning: 1}, + // Sad empty state + Proxy: agentpb.ConnectProxyConfig{ + MeshGateway: &agentpb.MeshGatewayConfig{}, + Expose: &agentpb.ExposeConfig{}, + }, + EnterpriseMeta: &agentpb.EnterpriseMeta{}, + }, + Checks: []*agentpb.HealthCheck{ + { + CheckID: "check1", + Name: "check 1", + Node: "node2", + Status: "critical", + ServiceID: "redis1", + ServiceName: "redis", + RaftIndex: agentpb.RaftIndex{CreateIndex: 14, ModifyIndex: 14}, + EnterpriseMeta: &agentpb.EnterpriseMeta{}, + }, + }, + }, + }, + }, + } + // Fix up the index + expected.Index = event.Index + node := expected.GetServiceHealth().CheckServiceNode + node.Node.RaftIndex = event.GetServiceHealth().CheckServiceNode.Node.RaftIndex + node.Service.RaftIndex = event.GetServiceHealth().CheckServiceNode.Service.RaftIndex + node.Checks[0].RaftIndex = event.GetServiceHealth().CheckServiceNode.Checks[0].RaftIndex + requireEqualProtos(t, expected, event) + case <-time.After(3 * time.Second): + t.Fatal("never got event") + } + + // Wait and make sure there aren't any more events coming. + select { + case event := <-eventCh: + t.Fatalf("got another event: %v", event) + case <-time.After(500 * time.Millisecond): + } +} + +func TestStreaming_Subscribe_SkipSnapshot(t *testing.T) { + t.Parallel() + + require := require.New(t) + dir1, server := testServerWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.Bootstrap = true + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir1) + defer server.Shutdown() + codec := rpcClient(t, server) + defer codec.Close() + + dir2, client := testClientWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.NodeName = uniqueNodeName(t.Name()) + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir2) + defer client.Shutdown() + + // Try to join + testrpc.WaitForLeader(t, server.RPC, "dc1") + joinLAN(t, client, server) + testrpc.WaitForTestAgent(t, client.RPC, "dc1") + + // Register a dummy node with our service on it. + { + req := &structs.RegisterRequest{ + Node: "node1", + Address: "3.4.5.6", + Datacenter: "dc1", + Service: &structs.NodeService{ + ID: "redis1", + Service: "redis", + Address: "3.4.5.6", + Port: 8080, + }, + } + var out struct{} + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + } + + // Start a Subscribe call to our streaming endpoint. 
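+	// Subscribe twice: the first call consumes the snapshot and records the
+	// index it was sent at; the second call passes that index and should
+	// receive a ResumeStream marker instead of a fresh snapshot.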
+ conn, err := client.GRPCConn() + require.NoError(err) + + streamClient := agentpb.NewConsulClient(conn) + + var index uint64 + { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + streamHandle, err := streamClient.Subscribe(ctx, &agentpb.SubscribeRequest{Topic: agentpb.Topic_ServiceHealth, Key: "redis"}) + require.NoError(err) + + // Start a goroutine to read updates off the agentpb. + eventCh := make(chan *agentpb.Event, 0) + go testSendEvents(t, eventCh, streamHandle) + + var snapshotEvents []*agentpb.Event + for i := 0; i < 2; i++ { + select { + case event := <-eventCh: + snapshotEvents = append(snapshotEvents, event) + case <-time.After(3 * time.Second): + t.Fatalf("did not receive events past %d", len(snapshotEvents)) + } + } + + // Save the index from the event + index = snapshotEvents[0].Index + } + + // Start another Subscribe call passing the index from the last event. + { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + streamHandle, err := streamClient.Subscribe(ctx, &agentpb.SubscribeRequest{ + Topic: agentpb.Topic_ServiceHealth, + Key: "redis", + Index: index, + }) + require.NoError(err) + + // Start a goroutine to read updates off the agentpb. + eventCh := make(chan *agentpb.Event, 0) + go testSendEvents(t, eventCh, streamHandle) + + // We should get no snapshot and the first event should be "resume stream" + select { + case event := <-eventCh: + require.True(event.GetResumeStream()) + case <-time.After(500 * time.Millisecond): + t.Fatalf("never got event") + } + + // Wait and make sure there aren't any events coming. The server shouldn't send + // a snapshot and we haven't made any updates to the catalog that would trigger + // more events. + select { + case event := <-eventCh: + t.Fatalf("got another event: %v", event) + case <-time.After(500 * time.Millisecond): + } + } +} + +func TestStreaming_Subscribe_FilterACL(t *testing.T) { + t.Parallel() + + require := require.New(t) + dir, _, server, codec := testACLFilterServerWithConfigFn(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + c.ACLDefaultPolicy = "deny" + c.ACLEnforceVersion8 = true + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir) + defer server.Shutdown() + defer codec.Close() + + dir2, client := testClientWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.NodeName = uniqueNodeName(t.Name()) + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir2) + defer client.Shutdown() + + // Try to join + testrpc.WaitForLeader(t, server.RPC, "dc1") + joinLAN(t, client, server) + testrpc.WaitForTestAgent(t, client.RPC, "dc1", testrpc.WithToken("root")) + + // Create a policy for the test token. + policyReq := structs.ACLPolicySetRequest{ + Datacenter: "dc1", + Policy: structs.ACLPolicy{ + Description: "foobar", + Name: "baz", + Rules: fmt.Sprintf(` + service "foo" { + policy = "write" + } + node "%s" { + policy = "write" + } + `, server.config.NodeName), + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLPolicy{} + require.NoError(msgpackrpc.CallWithCodec(codec, "ACL.PolicySet", &policyReq, &resp)) + + // Create a new token that only has access to one node. 
+ var token structs.ACLToken + arg := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Policies: []structs.ACLTokenPolicyLink{ + structs.ACLTokenPolicyLink{ + ID: resp.ID, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + require.NoError(msgpackrpc.CallWithCodec(codec, "ACL.TokenSet", &arg, &token)) + auth, err := server.ResolveToken(token.SecretID) + require.NoError(err) + require.Equal(auth.NodeRead("denied", nil), acl.Deny) + + // Register another instance of service foo on a fake node the token doesn't have access to. + regArg := structs.RegisterRequest{ + Datacenter: "dc1", + Node: "denied", + Address: "127.0.0.1", + Service: &structs.NodeService{ + ID: "foo", + Service: "foo", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", ®Arg, nil)) + + // Set up the gRPC client. + conn, err := client.GRPCConn() + require.NoError(err) + streamClient := agentpb.NewConsulClient(conn) + + // Start a Subscribe call to our streaming endpoint for the service we have access to. + { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + streamHandle, err := streamClient.Subscribe(ctx, &agentpb.SubscribeRequest{ + Topic: agentpb.Topic_ServiceHealth, + Key: "foo", + Token: token.SecretID, + }) + require.NoError(err) + + // Start a goroutine to read updates off the agentpb. + eventCh := make(chan *agentpb.Event, 0) + go testSendEvents(t, eventCh, streamHandle) + + // Read events off the agentpb. We should not see any events for the filtered node. + var snapshotEvents []*agentpb.Event + for i := 0; i < 2; i++ { + select { + case event := <-eventCh: + snapshotEvents = append(snapshotEvents, event) + case <-time.After(5 * time.Second): + t.Fatalf("did not receive events past %d", len(snapshotEvents)) + } + } + require.Len(snapshotEvents, 2) + require.Equal("foo", snapshotEvents[0].GetServiceHealth().CheckServiceNode.Service.Service) + require.Equal(server.config.NodeName, snapshotEvents[0].GetServiceHealth().CheckServiceNode.Node.Node) + require.True(snapshotEvents[1].GetEndOfSnapshot()) + + // Update the service with a new port to trigger a new event. + regArg := structs.RegisterRequest{ + Datacenter: "dc1", + Node: server.config.NodeName, + Address: "127.0.0.1", + Service: &structs.NodeService{ + ID: "foo", + Service: "foo", + Port: 1234, + }, + Check: &structs.HealthCheck{ + CheckID: "service:foo", + Name: "service:foo", + ServiceID: "foo", + Status: api.HealthPassing, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", ®Arg, nil)) + + select { + case event := <-eventCh: + service := event.GetServiceHealth().CheckServiceNode.Service + require.Equal("foo", service.Service) + require.Equal(1234, service.Port) + case <-time.After(5 * time.Second): + t.Fatalf("did not receive events past %d", len(snapshotEvents)) + } + + // Now update the service on the denied node and make sure we don't see an event. 
+ regArg = structs.RegisterRequest{
+ Datacenter: "dc1",
+ Node: "denied",
+ Address: "127.0.0.1",
+ Service: &structs.NodeService{
+ ID: "foo",
+ Service: "foo",
+ Port: 2345,
+ },
+ Check: &structs.HealthCheck{
+ CheckID: "service:foo",
+ Name: "service:foo",
+ ServiceID: "foo",
+ Status: api.HealthPassing,
+ },
+ WriteRequest: structs.WriteRequest{Token: "root"},
+ }
+ require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &regArg, nil))
+
+ select {
+ case event := <-eventCh:
+ t.Fatalf("should not have received event: %v", event)
+ case <-time.After(500 * time.Millisecond):
+ }
+ }
+
+ // Start another subscribe call for bar, which the token shouldn't have access to.
+ {
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ streamHandle, err := streamClient.Subscribe(ctx, &agentpb.SubscribeRequest{
+ Topic: agentpb.Topic_ServiceHealth,
+ Key: "bar",
+ Token: token.SecretID,
+ })
+ require.NoError(err)
+
+ // Start a goroutine to read updates off the agentpb.
+ eventCh := make(chan *agentpb.Event, 0)
+ go testSendEvents(t, eventCh, streamHandle)
+
+ select {
+ case event := <-eventCh:
+ require.True(event.GetEndOfSnapshot())
+ case <-time.After(3 * time.Second):
+ t.Fatal("did not receive event")
+ }
+
+ // Update the service and make sure we don't get a new event.
+ regArg := structs.RegisterRequest{
+ Datacenter: "dc1",
+ Node: server.config.NodeName,
+ Address: "127.0.0.1",
+ Service: &structs.NodeService{
+ ID: "bar",
+ Service: "bar",
+ Port: 2345,
+ },
+ Check: &structs.HealthCheck{
+ CheckID: "service:bar",
+ Name: "service:bar",
+ ServiceID: "bar",
+ },
+ WriteRequest: structs.WriteRequest{Token: "root"},
+ }
+ require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &regArg, nil))
+
+ select {
+ case event := <-eventCh:
+ t.Fatalf("should not have received event: %v", event)
+ case <-time.After(500 * time.Millisecond):
+ }
+ }
+}
+
+func TestStreaming_Subscribe_ACLUpdate(t *testing.T) {
+ t.Parallel()
+
+ require := require.New(t)
+ dir, _, server, codec := testACLFilterServerWithConfigFn(t, func(c *Config) {
+ c.ACLDatacenter = "dc1"
+ c.ACLsEnabled = true
+ c.ACLMasterToken = "root"
+ c.ACLDefaultPolicy = "deny"
+ c.ACLEnforceVersion8 = true
+ c.GRPCEnabled = true
+ })
+ defer os.RemoveAll(dir)
+ defer server.Shutdown()
+ defer codec.Close()
+
+ dir2, client := testClientWithConfig(t, func(c *Config) {
+ c.Datacenter = "dc1"
+ c.NodeName = uniqueNodeName(t.Name())
+ c.GRPCEnabled = true
+ })
+ defer os.RemoveAll(dir2)
+ defer client.Shutdown()
+
+ // Try to join
+ testrpc.WaitForLeader(t, server.RPC, "dc1")
+ joinLAN(t, client, server)
+ testrpc.WaitForTestAgent(t, client.RPC, "dc1", testrpc.WithToken("root"))
+
+ // Create a new token/policy that only has access to one node.
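+ // Later in this test the token itself is updated, which should cause the
+ // server to end the subscription with a reset event.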
+ var token structs.ACLToken + + policy, err := upsertTestPolicyWithRules(codec, "root", "dc1", fmt.Sprintf(` + service "foo" { + policy = "write" + } + node "%s" { + policy = "write" + } + `, server.config.NodeName)) + require.NoError(err) + + arg := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "Service/node token", + Policies: []structs.ACLTokenPolicyLink{ + structs.ACLTokenPolicyLink{ + ID: policy.ID, + }, + }, + Local: false, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + require.NoError(msgpackrpc.CallWithCodec(codec, "ACL.TokenSet", &arg, &token)) + auth, err := server.ResolveToken(token.SecretID) + require.NoError(err) + require.Equal(auth.NodeRead("denied", nil), acl.Deny) + + // Set up the gRPC client. + conn, err := client.GRPCConn() + require.NoError(err) + streamClient := agentpb.NewConsulClient(conn) + + // Start a Subscribe call to our streaming endpoint for the service we have access to. + { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + streamHandle, err := streamClient.Subscribe(ctx, &agentpb.SubscribeRequest{ + Topic: agentpb.Topic_ServiceHealth, + Key: "foo", + Token: token.SecretID, + }) + require.NoError(err) + + // Start a goroutine to read updates off the agentpb. + eventCh := make(chan *agentpb.Event, 0) + go testSendEvents(t, eventCh, streamHandle) + + // Read events off the agentpb. + var snapshotEvents []*agentpb.Event + for i := 0; i < 2; i++ { + select { + case event := <-eventCh: + snapshotEvents = append(snapshotEvents, event) + case <-time.After(5 * time.Second): + t.Fatalf("did not receive events past %d", len(snapshotEvents)) + } + } + require.Len(snapshotEvents, 2) + require.Equal("foo", snapshotEvents[0].GetServiceHealth().CheckServiceNode.Service.Service) + require.Equal(server.config.NodeName, snapshotEvents[0].GetServiceHealth().CheckServiceNode.Node.Node) + require.True(snapshotEvents[1].GetEndOfSnapshot()) + + // Update a different token and make sure we don't see an event. + arg2 := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "Ignored token", + Policies: []structs.ACLTokenPolicyLink{ + structs.ACLTokenPolicyLink{ + ID: policy.ID, + }, + }, + Local: false, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + var ignoredToken structs.ACLToken + require.NoError(msgpackrpc.CallWithCodec(codec, "ACL.TokenSet", &arg2, &ignoredToken)) + + select { + case event := <-eventCh: + t.Fatalf("should not have received event: %v", event) + case <-time.After(500 * time.Millisecond): + } + + // Update our token to trigger a refresh event. + token.Policies = []structs.ACLTokenPolicyLink{} + arg := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: token, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + require.NoError(msgpackrpc.CallWithCodec(codec, "ACL.TokenSet", &arg, &token)) + + select { + case event := <-eventCh: + require.True(event.GetResetStream()) + // 500 ms was not enough in CI apparently... + case <-time.After(2 * time.Second): + t.Fatalf("did not receive reload event") + } + } +} + +// testSendEvents receives agentpb.Events from a given handle and sends them to the provided +// channel. This is meant to be run in a separate goroutine from the main test. 
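+// Errors other than EOF and context cancellation are only logged here, since
+// t.Fatal must not be called from a goroutine other than the one running the test.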
+func testSendEvents(t *testing.T, ch chan *agentpb.Event, handle agentpb.Consul_SubscribeClient) { + for { + event, err := handle.Recv() + if err == io.EOF { + break + } + if err != nil { + if strings.Contains(err.Error(), "context deadline exceeded") || + strings.Contains(err.Error(), "context canceled") { + break + } + t.Log(err) + } + ch <- event + } +} + +func TestStreaming_TLSEnabled(t *testing.T) { + t.Parallel() + + require := require.New(t) + dir1, conf1 := testServerConfig(t) + conf1.VerifyIncoming = true + conf1.VerifyOutgoing = true + conf1.GRPCEnabled = true + configureTLS(conf1) + server, err := newServer(conf1) + if err != nil { + t.Fatalf("err: %v", err) + } + defer os.RemoveAll(dir1) + defer server.Shutdown() + + dir2, conf2 := testClientConfig(t) + conf2.VerifyOutgoing = true + conf2.GRPCEnabled = true + configureTLS(conf2) + client, err := NewClient(conf2) + if err != nil { + t.Fatalf("err: %v", err) + } + defer os.RemoveAll(dir2) + defer client.Shutdown() + + // Try to join + testrpc.WaitForLeader(t, server.RPC, "dc1") + joinLAN(t, client, server) + testrpc.WaitForTestAgent(t, client.RPC, "dc1") + + // Register a dummy node with our service on it. + { + req := &structs.RegisterRequest{ + Node: "node1", + Address: "3.4.5.6", + Datacenter: "dc1", + Service: &structs.NodeService{ + ID: "redis1", + Service: "redis", + Address: "3.4.5.6", + Port: 8080, + }, + } + var out struct{} + require.NoError(server.RPC("Catalog.Register", &req, &out)) + } + + // Start a Subscribe call to our streaming endpoint from the client. + { + conn, err := client.GRPCConn() + require.NoError(err) + + streamClient := agentpb.NewConsulClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + streamHandle, err := streamClient.Subscribe(ctx, &agentpb.SubscribeRequest{Topic: agentpb.Topic_ServiceHealth, Key: "redis"}) + require.NoError(err) + + // Start a goroutine to read updates off the agentpb. + eventCh := make(chan *agentpb.Event, 0) + go testSendEvents(t, eventCh, streamHandle) + + var snapshotEvents []*agentpb.Event + for i := 0; i < 2; i++ { + select { + case event := <-eventCh: + snapshotEvents = append(snapshotEvents, event) + case <-time.After(3 * time.Second): + t.Fatalf("did not receive events past %d", len(snapshotEvents)) + } + } + + // Make sure the snapshot events come back with no issues. + require.Len(snapshotEvents, 2) + } + + // Start a Subscribe call to our streaming endpoint from the server's loopback client. + { + conn, err := server.GRPCConn() + require.NoError(err) + + retryFailedConn(t, conn) + + streamClient := agentpb.NewConsulClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + streamHandle, err := streamClient.Subscribe(ctx, &agentpb.SubscribeRequest{Topic: agentpb.Topic_ServiceHealth, Key: "redis"}) + require.NoError(err) + + // Start a goroutine to read updates off the agentpb. + eventCh := make(chan *agentpb.Event, 0) + go testSendEvents(t, eventCh, streamHandle) + + var snapshotEvents []*agentpb.Event + for i := 0; i < 2; i++ { + select { + case event := <-eventCh: + snapshotEvents = append(snapshotEvents, event) + case <-time.After(3 * time.Second): + t.Fatalf("did not receive events past %d", len(snapshotEvents)) + } + } + + // Make sure the snapshot events come back with no issues. + require.Len(snapshotEvents, 2) + } +} + +func TestStreaming_TLSReload(t *testing.T) { + t.Parallel() + + // Set up a server with initially bad certificates. 
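+ // The certificate configured below is not signed by the CA the clients trust,
+ // so Subscribe calls are expected to fail until the server reloads a valid
+ // certificate later in the test.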
+ require := require.New(t) + dir1, conf1 := testServerConfig(t) + conf1.VerifyIncoming = true + conf1.VerifyOutgoing = true + conf1.CAFile = "../../test/ca/root.cer" + conf1.CertFile = "../../test/key/ssl-cert-snakeoil.pem" + conf1.KeyFile = "../../test/key/ssl-cert-snakeoil.key" + conf1.GRPCEnabled = true + + server, err := newServer(conf1) + if err != nil { + t.Fatalf("err: %v", err) + } + defer os.RemoveAll(dir1) + defer server.Shutdown() + + // Set up a client with valid certs and verify_outgoing = true + dir2, conf2 := testClientConfig(t) + conf2.VerifyOutgoing = true + conf2.GRPCEnabled = true + configureTLS(conf2) + client, err := NewClient(conf2) + if err != nil { + t.Fatalf("err: %v", err) + } + defer os.RemoveAll(dir2) + defer client.Shutdown() + + testrpc.WaitForLeader(t, server.RPC, "dc1") + + // Subscribe calls should fail initially + joinLAN(t, client, server) + conn, err := client.GRPCConn() + require.NoError(err) + { + streamClient := agentpb.NewConsulClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + _, err = streamClient.Subscribe(ctx, &agentpb.SubscribeRequest{Topic: agentpb.Topic_ServiceHealth, Key: "redis"}) + require.Error(err, "tls: bad certificate") + } + + // Reload the server with valid certs + newConf := server.config.ToTLSUtilConfig() + newConf.CertFile = "../../test/key/ourdomain.cer" + newConf.KeyFile = "../../test/key/ourdomain.key" + server.tlsConfigurator.Update(newConf) + + // Try the subscribe call again + { + retryFailedConn(t, conn) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + streamClient := agentpb.NewConsulClient(conn) + _, err = streamClient.Subscribe(ctx, &agentpb.SubscribeRequest{Topic: agentpb.Topic_ServiceHealth, Key: "redis"}) + require.NoError(err) + } +} + +// retryFailedConn forces the ClientConn to reset its backoff timer and retry the connection, +// to simulate the client eventually retrying after the initial failure. This is used both to simulate +// retrying after an expected failure as well as to avoid flakiness when running many tests in parallel. +func retryFailedConn(t *testing.T, conn *grpc.ClientConn) { + state := conn.GetState() + if state.String() != "TRANSIENT_FAILURE" { + return + } + + // If the connection has failed, retry and wait for a state change. + conn.ResetConnectBackoff() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + require.True(t, conn.WaitForStateChange(ctx, state)) +} + +func TestStreaming_DeliversAllMessages(t *testing.T) { + // This is a fuzz/probabilistic test to try to provoke streaming into dropping + // messages. There is a bug in the initial implementation that should make + // this fail. While we can't be certain a pass means it's correct, it is + // useful for finding bugs in our concurrency design. + + // The issue is that when updates are coming in fast such that updates occur + // in between us making the snapshot and beginning the stream updates, we + // shouldn't miss anything. + + // To test this, we will run a background goroutine that will write updates as + // fast as possible while we then try to stream the results and ensure that we + // see every change. We'll make the updates monotonically increasing so we can + // easily tell if we missed one. 
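+ // A single canary instance ("redis-canary") carries the increasing port; each
+ // subscriber tracks the last port it saw and reports an error if any value is
+ // skipped.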
+ + require := require.New(t) + dir1, server := testServerWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.Bootstrap = true + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir1) + defer server.Shutdown() + codec := rpcClient(t, server) + defer codec.Close() + + dir2, client := testClientWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.NodeName = uniqueNodeName(t.Name()) + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir2) + defer client.Shutdown() + + // Try to join + testrpc.WaitForLeader(t, server.RPC, "dc1") + joinLAN(t, client, server) + testrpc.WaitForTestAgent(t, client.RPC, "dc1") + + // Register a whole bunch of service instances so that the initial snapshot on + // subscribe is big enough to take a bit of time to load giving more + // opportunity for missed updates if there is a bug. + for i := 0; i < 1000; i++ { + req := &structs.RegisterRequest{ + Node: fmt.Sprintf("node-redis-%03d", i), + Address: "3.4.5.6", + Datacenter: "dc1", + Service: &structs.NodeService{ + ID: fmt.Sprintf("redis-%03d", i), + Service: "redis", + Port: 11211, + }, + } + var out struct{} + require.NoError(server.RPC("Catalog.Register", &req, &out)) + } + + // Start background writer + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go func() { + // Update the registration with a monotonically increasing port as fast as + // we can. + req := &structs.RegisterRequest{ + Node: "node1", + Address: "3.4.5.6", + Datacenter: "dc1", + Service: &structs.NodeService{ + ID: "redis-canary", + Service: "redis", + Port: 0, + }, + } + for { + if ctx.Err() != nil { + return + } + var out struct{} + require.NoError(server.RPC("Catalog.Register", &req, &out)) + req.Service.Port++ + if req.Service.Port > 100 { + return + } + time.Sleep(1 * time.Millisecond) + } + }() + + // Now start a whole bunch of streamers in parallel to maximise chance of + // catching a race. + conn, err := client.GRPCConn() + require.NoError(err) + + streamClient := agentpb.NewConsulClient(conn) + + n := 5 + var wg sync.WaitGroup + var updateCount uint64 + // Buffered error chan so that workers can exit and terminate wg without + // blocking on send. We collect errors this way since t isn't thread safe. + errCh := make(chan error, n) + for i := 0; i < n; i++ { + wg.Add(1) + go verifyMonotonicStreamUpdates(ctx, t, streamClient, &wg, i, &updateCount, errCh) + } + + // Wait until all subscribers have verified the first bunch of updates all got + // delivered. + wg.Wait() + + close(errCh) + + // Require that none of them errored. Since we closed the chan above this loop + // should terminate immediately if no errors were buffered. + for err := range errCh { + require.NoError(err) + } + + // Sanity check that at least some non-snapshot messages were delivered. We + // can't know exactly how many because it's timing dependent based on when + // each subscribers snapshot occurs. 
+ require.True(atomic.LoadUint64(&updateCount) > 0,
+ "at least some of the subscribers should have received non-snapshot updates")
+}
+
+type testLogger interface {
+ Logf(format string, args ...interface{})
+}
+
+func verifyMonotonicStreamUpdates(ctx context.Context, logger testLogger, client agentpb.ConsulClient, wg *sync.WaitGroup, i int, updateCount *uint64, errCh chan<- error) {
+ defer wg.Done()
+ streamHandle, err := client.Subscribe(ctx, &agentpb.SubscribeRequest{Topic: agentpb.Topic_ServiceHealth, Key: "redis"})
+ if err != nil {
+ if strings.Contains(err.Error(), "context deadline exceeded") ||
+ strings.Contains(err.Error(), "context canceled") {
+ logger.Logf("subscriber %05d: context cancelled before loop", i)
+ return
+ }
+ errCh <- err
+ return
+ }
+
+ snapshotDone := false
+ expectPort := 0
+ for {
+ event, err := streamHandle.Recv()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ if strings.Contains(err.Error(), "context deadline exceeded") ||
+ strings.Contains(err.Error(), "context canceled") {
+ break
+ }
+ errCh <- err
+ return
+ }
+
+ // Ignore snapshot message
+ if event.GetEndOfSnapshot() || event.GetResumeStream() {
+ snapshotDone = true
+ logger.Logf("subscriber %05d: snapshot done, expect next port to be %d", i, expectPort)
+ } else if snapshotDone {
+ // Verify we get all updates in order
+ svc, err := svcOrErr(event)
+ if err != nil {
+ errCh <- err
+ return
+ }
+ if expectPort != svc.Port {
+ errCh <- fmt.Errorf("subscriber %05d: missed %d update(s)!", i, svc.Port-expectPort)
+ return
+ }
+ atomic.AddUint64(updateCount, 1)
+ logger.Logf("subscriber %05d: got event with correct port=%d", i, expectPort)
+ expectPort++
+ } else {
+ // This is a snapshot update. Check if it's an update for the canary
+ // instance that got applied before our snapshot was sent (likely)
+ svc, err := svcOrErr(event)
+ if err != nil {
+ errCh <- err
+ return
+ }
+ if svc.ID == "redis-canary" {
+ // Update the expected port we see in the next update to be one more
+ // than the port in the snapshot.
+ expectPort = svc.Port + 1
+ logger.Logf("subscriber %05d: saw canary in snapshot with port %d", i, svc.Port)
+ }
+ }
+ if expectPort > 100 {
+ return
+ }
+ }
+}
+
+func svcOrErr(event *agentpb.Event) (*agentpb.NodeService, error) {
+ health := event.GetServiceHealth()
+ if health == nil {
+ return nil, fmt.Errorf("not a health event: %#v", event)
+ }
+ csn := health.CheckServiceNode
+ if csn == nil {
+ return nil, fmt.Errorf("nil CSN: %#v", event)
+ }
+ if csn.Service == nil {
+ return nil, fmt.Errorf("nil service: %#v", event)
+ }
+ return csn.Service, nil
+}
diff --git a/agent/consul/test_endpoint.go b/agent/consul/test_grpc_endpoint.go
similarity index 100%
rename from agent/consul/test_endpoint.go
rename to agent/consul/test_grpc_endpoint.go
diff --git a/logging/names.go b/logging/names.go
index 8dc62aa8e4..172876630f 100644
--- a/logging/names.go
+++ b/logging/names.go
@@ -45,6 +45,7 @@ const (
 Session string = "session"
 Sentinel string = "sentinel"
 Snapshot string = "snapshot"
+ Subscribe string = "subscribe"
 TLSUtil string = "tlsutil"
 Transaction string = "txn"
 WAN string = "wan"