consul/agent/subscribe/subscribe.go

261 lines
6.7 KiB
Go
Raw Normal View History

package subscribe
import (
"errors"
"fmt"
"github.com/hashicorp/go-uuid"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/proto/pbservice"
"github.com/hashicorp/consul/proto/pbsubscribe"
)
type (
	// Server implements pbsubscribe.StateChangeSubscriptionServer. It accepts
	// SubscribeRequests and streams the matching events back to the caller,
	// either from a subscription opened on Backend or by forwarding the request
	// to another datacenter through Backend.Forward.
	Server struct {
		Backend Backend
		Logger  Logger
	}

	// Logger is the minimal logging surface Server needs.
	Logger interface {
		// IsTrace gates work that only benefits trace output, such as
		// generating a per-stream correlation ID.
		IsTrace() bool
		// Trace logs msg with args given as alternating key/value pairs, as
		// used throughout this package.
		Trace(msg string, args ...interface{})
	}

	// Backend supplies the ACL, forwarding and event-subscription operations
	// used by Server.
	Backend interface {
		// ResolveToken returns the authorizer used to filter the events sent
		// for a request carrying token.
		ResolveToken(token string) (acl.Authorizer, error)
		// Forward is given the target datacenter and a callback that proxies
		// the stream over a gRPC connection. Server.Subscribe returns err as-is
		// whenever handled is true or err is non-nil. Presumably handled is
		// false when dc is the local datacenter; confirm in the implementation.
		Forward(dc string, f func(*grpc.ClientConn) error) (handled bool, err error)
		// Subscribe opens an event subscription for req. Server always calls
		// Unsubscribe on the returned subscription when it is done.
		Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error)
	}
)

// Compile-time check that *Server satisfies the generated gRPC server interface.
var _ pbsubscribe.StateChangeSubscriptionServer = (*Server)(nil)
// Subscribe serves one StateChangeSubscription stream.
//
// If Backend.Forward handles the request (or fails), its error is returned
// directly. Otherwise the request token is resolved once and a subscription is
// opened on Backend. Each batch returned by the subscription is then filtered
// by ACL, and each non-empty batch is sent to the client as a single
// pbsubscribe.Event.
//
// The local path only returns with an error:
//   - codes.Aborted when the server closes the subscription;
//   - otherwise, the error from ResolveToken, Backend.Subscribe, Next or Send.
func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsubscribe.StateChangeSubscription_SubscribeServer) error {
	// The stream ID exists only to correlate trace log lines for one stream,
	// so it is left empty unless trace logging is enabled.
	// TODO(banks) it might be nice one day to replace this with OpenTracing ID
	// if one is set etc. but probably pointless until we support that properly
	// in other places so it's actually propagated properly.
	streamID := ""
	if h.Logger.IsTrace() {
		id, err := uuid.GenerateUUID()
		if err != nil {
			return err
		}
		streamID = id
	}

	// TODO: add fields to logger and pass logger around instead of streamID
	handled, err := h.Backend.Forward(req.Datacenter, h.forwardToDC(req, serverStream, streamID))
	if err != nil || handled {
		return err
	}

	h.Logger.Trace("new subscription",
		"topic", req.Topic.String(),
		"key", req.Key,
		"index", req.Index,
		"stream_id", streamID,
	)
	// Registered before token resolution so a failed auth still logs the close.
	defer h.Logger.Trace("subscription closed", "stream_id", streamID)

	// The token is resolved once for the lifetime of the stream.
	// TODO: handle token expiry gracefully...
	authz, err := h.Backend.ResolveToken(req.Token)
	if err != nil {
		return err
	}

	sub, err := h.Backend.Subscribe(toStreamSubscribeRequest(req))
	if err != nil {
		return err
	}
	defer sub.Unsubscribe()

	ctx := serverStream.Context()
	var (
		sent         uint64 // events sent so far, reported in trace logs
		snapshotDone bool   // true once an end-of-snapshot marker has been seen
	)
	for {
		batch, err := sub.Next(ctx)
		// TODO: test case
		if errors.Is(err, stream.ErrSubscriptionClosed) {
			h.Logger.Trace("subscription reset by server", "stream_id", streamID)
			return status.Error(codes.Aborted, err.Error())
		}
		if err != nil {
			return err
		}

		batch = filterStreamEvents(authz, batch)
		if len(batch) == 0 {
			// Everything in this batch was denied; nothing to send.
			continue
		}

		head := batch[0]
		if head.IsEndOfSnapshot() || head.IsEndOfEmptySnapshot() {
			snapshotDone = true
			h.Logger.Trace("snapshot complete",
				"index", head.Index, "sent", sent, "stream_id", streamID)
		} else if snapshotDone {
			// Per-batch trace lines are only emitted after the snapshot, to
			// avoid one line per snapshot item.
			h.Logger.Trace("sending events",
				"index", head.Index,
				"sent", sent,
				"batch_size", len(batch),
				"stream_id", streamID,
			)
		}

		sent += uint64(len(batch))
		if err := serverStream.Send(newEventFromStreamEvents(req, batch)); err != nil {
			return err
		}
	}
}
// toStreamSubscribeRequest copies Topic, Key, Token and Index from a gRPC
// SubscribeRequest into an equivalent stream.SubscribeRequest. Other fields of
// the gRPC request, such as Datacenter, are not carried over.
// TODO: can be replaced by mog conversion
func toStreamSubscribeRequest(req *pbsubscribe.SubscribeRequest) *stream.SubscribeRequest {
	out := new(stream.SubscribeRequest)
	out.Topic = req.Topic
	out.Key = req.Key
	out.Token = req.Token
	out.Index = req.Index
	return out
}
// forwardToDC returns the callback handed to Backend.Forward for proxying a
// subscription to req.Datacenter.
//
// The callback:
//   - issues the same Subscribe request over conn, using the incoming stream's
//     context (serverStream.Context());
//   - copies every event received from the remote stream to serverStream;
//   - stops when either Recv or Send fails and returns that error unchanged.
func (h *Server) forwardToDC(
	req *pbsubscribe.SubscribeRequest,
	serverStream pbsubscribe.StateChangeSubscription_SubscribeServer,
	streamID string,
) func(conn *grpc.ClientConn) error {
	relay := func(conn *grpc.ClientConn) error {
		h.Logger.Trace("forwarding to another DC",
			"dc", req.Datacenter,
			"topic", req.Topic.String(),
			"key", req.Key,
			"index", req.Index,
			"stream_id", streamID,
		)
		// Wrapped in a closure so the arguments are read when the relay exits.
		defer func() {
			h.Logger.Trace("forwarded stream closed",
				"dc", req.Datacenter,
				"stream_id", streamID,
			)
		}()

		upstream, err := pbsubscribe.NewStateChangeSubscriptionClient(conn).Subscribe(serverStream.Context(), req)
		if err != nil {
			return err
		}

		for {
			ev, recvErr := upstream.Recv()
			if recvErr != nil {
				return recvErr
			}
			if sendErr := serverStream.Send(ev); sendErr != nil {
				return sendErr
			}
		}
	}
	return relay
}
// filterStreamEvents returns the events for which enforceACL yields acl.Allow
// under authz.
//
// A nil authz, or an empty events slice, is returned unchanged. When a single
// event is allowed, the input slice itself is returned. The result is nil when
// every event is denied.
func filterStreamEvents(authz acl.Authorizer, events []stream.Event) []stream.Event {
	// TODO: when is authz nil?
	switch {
	case authz == nil, len(events) == 0:
		return events
	case len(events) == 1:
		// Fast path: a single event is the overwhelmingly common case (every
		// regular update, and every item while a snapshot is being sent), so
		// answer it without allocating a new slice.
		if enforceACL(authz, events[0]) != acl.Allow {
			return nil
		}
		return events
	}

	var allowed []stream.Event
	for i := range events {
		if enforceACL(authz, events[i]) == acl.Allow {
			allowed = append(allowed, events[i])
		}
	}
	return allowed
}
// newEventFromStreamEvents builds the wire event for one batch of stream
// events. The envelope carries the request's Topic and Key and the Index of
// the first event.
//
// More than one event is wrapped in an EventBatch. A single event is encoded
// directly: end-of-snapshot markers become their dedicated payloads, and
// anything else goes through setPayload.
//
// events must be non-empty; an empty slice panics on the index of events[0].
func newEventFromStreamEvents(req *pbsubscribe.SubscribeRequest, events []stream.Event) *pbsubscribe.Event {
	out := &pbsubscribe.Event{
		Topic: req.Topic,
		Key:   req.Key,
		Index: events[0].Index,
	}

	if len(events) > 1 {
		out.Payload = &pbsubscribe.Event_EventBatch{
			EventBatch: &pbsubscribe.EventBatch{Events: batchEventsFromEventSlice(events)},
		}
		return out
	}

	// TODO: refactor so these are only checked once, instead of 3 times.
	only := events[0]
	switch {
	case only.IsEndOfSnapshot():
		out.Payload = &pbsubscribe.Event_EndOfSnapshot{EndOfSnapshot: true}
	case only.IsEndOfEmptySnapshot():
		out.Payload = &pbsubscribe.Event_EndOfEmptySnapshot{EndOfEmptySnapshot: true}
	default:
		setPayload(out, only.Payload)
	}
	return out
}
// setPayload converts a stream event payload into the matching
// pbsubscribe.Event payload and stores it on e.
//
// Only state.EventPayloadCheckServiceNode is supported. Any other payload type,
// including nil, is treated as a programming error and panics.
func setPayload(e *pbsubscribe.Event, payload interface{}) {
	csn, ok := payload.(state.EventPayloadCheckServiceNode)
	if !ok {
		panic(fmt.Sprintf("unexpected payload: %T: %#v", payload, payload))
	}
	e.Payload = &pbsubscribe.Event_ServiceHealth{
		ServiceHealth: &pbsubscribe.ServiceHealthUpdate{
			Op: csn.Op,
			// TODO: this could be cached
			CheckServiceNode: pbservice.NewCheckServiceNodeFromStructs(csn.Value),
		},
	}
}
// batchEventsFromEventSlice converts each stream event into a pbsubscribe.Event
// with its own Key and Index, and a payload set by setPayload. Topic is not set
// on the returned events. The result is never nil; an empty input produces an
// empty slice.
func batchEventsFromEventSlice(events []stream.Event) []*pbsubscribe.Event {
	out := make([]*pbsubscribe.Event, 0, len(events))
	for _, ev := range events {
		item := &pbsubscribe.Event{Key: ev.Key, Index: ev.Index}
		setPayload(item, ev.Payload)
		out = append(out, item)
	}
	return out
}