package peering import ( "context" "crypto/sha256" "encoding/hex" "fmt" "strings" "github.com/golang/protobuf/proto" "github.com/hashicorp/go-hclog" "github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/proto/pbservice" ) // subscriptionState is a collection of working state tied to a peerID subscription. type subscriptionState struct { // peerName is immutable and is the LOCAL name for the peering peerName string // partition is immutable partition string // plain data exportList *structs.ExportedServiceList watchedServices map[structs.ServiceName]context.CancelFunc watchedProxyServices map[structs.ServiceName]context.CancelFunc // TODO(peering): remove connectServices map[structs.ServiceName]string // value:protocol // eventVersions is a duplicate event suppression system keyed by the "id" // not the "correlationID" eventVersions map[string]string meshGateway *pbservice.IndexedCheckServiceNodes // updateCh is an internal implementation detail for the machinery of the // manager. updateCh chan<- cache.UpdateEvent // publicUpdateCh is the channel the manager uses to pass data back to the // caller. publicUpdateCh chan<- cache.UpdateEvent } func newSubscriptionState(peerName, partition string) *subscriptionState { return &subscriptionState{ peerName: peerName, partition: partition, watchedServices: make(map[structs.ServiceName]context.CancelFunc), watchedProxyServices: make(map[structs.ServiceName]context.CancelFunc), connectServices: make(map[structs.ServiceName]string), eventVersions: make(map[string]string), } } func (s *subscriptionState) sendPendingEvents( ctx context.Context, logger hclog.Logger, pending *pendingPayload, ) { for _, pendingEvt := range pending.Events { cID := pendingEvt.CorrelationID newVersion := pendingEvt.Version oldVersion, ok := s.eventVersions[pendingEvt.ID] if ok && newVersion == oldVersion { logger.Trace("skipping send of duplicate public event", "correlationID", cID) continue } logger.Trace("sending public event", "correlationID", cID) s.eventVersions[pendingEvt.ID] = newVersion evt := cache.UpdateEvent{ CorrelationID: cID, Result: pendingEvt.Result, } select { case s.publicUpdateCh <- evt: case <-ctx.Done(): } } } func (s *subscriptionState) cleanupEventVersions(logger hclog.Logger) { for id := range s.eventVersions { keep := false switch { case id == meshGatewayPayloadID: keep = true case strings.HasPrefix(id, servicePayloadIDPrefix): name := strings.TrimPrefix(id, servicePayloadIDPrefix) sn := structs.ServiceNameFromString(name) if _, ok := s.watchedServices[sn]; ok { keep = true } case strings.HasPrefix(id, proxyServicePayloadIDPrefix): name := strings.TrimPrefix(id, proxyServicePayloadIDPrefix) sn := structs.ServiceNameFromString(name) if _, ok := s.watchedProxyServices[sn]; ok { keep = true } case strings.HasPrefix(id, discoveryChainPayloadIDPrefix): name := strings.TrimPrefix(id, discoveryChainPayloadIDPrefix) sn := structs.ServiceNameFromString(name) if _, ok := s.connectServices[sn]; ok { keep = true } } if !keep { logger.Trace("cleaning up unreferenced event id version", "id", id) delete(s.eventVersions, id) } } } type pendingPayload struct { Events []pendingEvent } type pendingEvent struct { ID string CorrelationID string Result proto.Message Version string } const ( meshGatewayPayloadID = "mesh-gateway" servicePayloadIDPrefix = "service:" proxyServicePayloadIDPrefix = "proxy-service:" // TODO(peering): remove discoveryChainPayloadIDPrefix = "chain:" ) func (p *pendingPayload) Add(id string, correlationID string, raw interface{}) error { result, ok := raw.(proto.Message) if !ok { return fmt.Errorf("invalid type for %q event: %T", correlationID, raw) } version, err := hashProtobuf(result) if err != nil { return fmt.Errorf("error hashing %q event: %w", correlationID, err) } p.Events = append(p.Events, pendingEvent{ ID: id, CorrelationID: correlationID, Result: result, Version: version, }) return nil } func hashProtobuf(res proto.Message) (string, error) { h := sha256.New() buffer := proto.NewBuffer(nil) buffer.SetDeterministic(true) err := buffer.Marshal(res) if err != nil { return "", err } h.Write(buffer.Bytes()) buffer.Reset() return hex.EncodeToString(h.Sum(nil)), nil }