diff --git a/pubsub.go b/pubsub.go index beb2012..6f898a2 100644 --- a/pubsub.go +++ b/pubsub.go @@ -148,6 +148,9 @@ type PubSub struct { // strict mode rejects all unsigned messages prior to validation signPolicy MessageSignaturePolicy + // filter for tracking subscriptions in topics of interest; if nil, then we track all subscriptions + subFilter SubscriptionFilter + ctx context.Context } @@ -900,8 +903,19 @@ func (p *PubSub) notifyLeave(topic string, pid peer.ID) { func (p *PubSub) handleIncomingRPC(rpc *RPC) { p.tracer.RecvRPC(rpc) - for _, subopt := range rpc.GetSubscriptions() { + subs := rpc.GetSubscriptions() + if len(subs) != 0 && p.subFilter != nil { + var err error + subs, err = p.subFilter.FilterIncomingSubscriptions(rpc.from, subs) + if err != nil { + log.Debugf("subscription filter error: %s; ignoring RPC", err) + return + } + } + + for _, subopt := range subs { t := subopt.GetTopicid() + if subopt.GetSubscribe() { tmap, ok := p.topics[t] if !ok { @@ -1073,6 +1087,10 @@ func (p *PubSub) Join(topic string, opts ...TopicOpt) (*Topic, error) { // Returns true if the topic was newly created, false otherwise // Can be removed once pubsub.Publish() and pubsub.Subscribe() are removed func (p *PubSub) tryJoin(topic string, opts ...TopicOpt) (*Topic, bool, error) { + if p.subFilter != nil && !p.subFilter.CanSubscribe(topic) { + return nil, false, fmt.Errorf("topic is not allowed by the subscription filter") + } + t := &Topic{ p: p, topic: topic, diff --git a/subscription_filter.go b/subscription_filter.go new file mode 100644 index 0000000..76e6eaa --- /dev/null +++ b/subscription_filter.go @@ -0,0 +1,149 @@ +package pubsub + +import ( + "errors" + "regexp" + + pb "github.com/libp2p/go-libp2p-pubsub/pb" + + "github.com/libp2p/go-libp2p-core/peer" +) + +// ErrTooManySubscriptions may be returned by a SubscriptionFilter to signal that there are too many +// subscriptions to process. +var ErrTooManySubscriptions = errors.New("too many subscriptions") + +// SubscriptionFilter is a function that tells us whether we are interested in allowing and tracking +// subscriptions for a given topic. +// +// The filter is consulted whenever a subscription notification is received by another peer; if the +// filter returns false, then the notification is ignored. +// +// The filter is also consulted when joining topics; if the filter returns false, then the Join +// operation will result in an error. +type SubscriptionFilter interface { + // CanSubscribe returns true if the topic is of interest and we can subscribe to it + CanSubscribe(topic string) bool + + // FilterIncomingSubscriptions is invoked for all RPCs containing subscription notifications. + // It should filter only the subscriptions of interest and my return an error if (for instance) + // there are too many subscriptions. + FilterIncomingSubscriptions(peer.ID, []*pb.RPC_SubOpts) ([]*pb.RPC_SubOpts, error) +} + +// WithSubscriptionFilter is a pubsub option that specifies a filter for subscriptions +// in topics of interest. +func WithSubscriptionFilter(subFilter SubscriptionFilter) Option { + return func(ps *PubSub) error { + ps.subFilter = subFilter + return nil + } +} + +// NewAllowlistSubscriptionFilter creates a subscription filter that only allows explicitly +// specified topics for local subscriptions and incoming peer subscriptions. +func NewAllowlistSubscriptionFilter(topics ...string) SubscriptionFilter { + allow := make(map[string]struct{}) + for _, topic := range topics { + allow[topic] = struct{}{} + } + + return &allowlistSubscriptionFilter{allow: allow} +} + +type allowlistSubscriptionFilter struct { + allow map[string]struct{} +} + +var _ SubscriptionFilter = (*allowlistSubscriptionFilter)(nil) + +func (f *allowlistSubscriptionFilter) CanSubscribe(topic string) bool { + _, ok := f.allow[topic] + return ok +} + +func (f *allowlistSubscriptionFilter) FilterIncomingSubscriptions(from peer.ID, subs []*pb.RPC_SubOpts) ([]*pb.RPC_SubOpts, error) { + return FilterSubscriptions(subs, f.CanSubscribe), nil +} + +// NewRegexpSubscriptionFilter creates a subscription filter that only allows topics that +// match a regular expression for local subscriptions and incoming peer subscriptions. +// +// Warning: the user should take care to match start/end of string in the supplied regular +// expression, otherwise the filter might match unwanted topics unexpectedly. +func NewRegexpSubscriptionFilter(rx *regexp.Regexp) SubscriptionFilter { + return &rxSubscriptionFilter{allow: rx} +} + +type rxSubscriptionFilter struct { + allow *regexp.Regexp +} + +var _ SubscriptionFilter = (*rxSubscriptionFilter)(nil) + +func (f *rxSubscriptionFilter) CanSubscribe(topic string) bool { + return f.allow.MatchString(topic) +} + +func (f *rxSubscriptionFilter) FilterIncomingSubscriptions(from peer.ID, subs []*pb.RPC_SubOpts) ([]*pb.RPC_SubOpts, error) { + return FilterSubscriptions(subs, f.CanSubscribe), nil +} + +// FilterSubscriptions filters (and deduplicates) a list of subscriptions. +// filter should return true if a topic is of interest. +func FilterSubscriptions(subs []*pb.RPC_SubOpts, filter func(string) bool) []*pb.RPC_SubOpts { + accept := make(map[string]*pb.RPC_SubOpts) + + for _, sub := range subs { + topic := sub.GetTopicid() + + if !filter(topic) { + continue + } + + otherSub, ok := accept[topic] + if ok { + if sub.GetSubscribe() != otherSub.GetSubscribe() { + delete(accept, topic) + } + } else { + accept[topic] = sub + } + } + + if len(accept) == 0 { + return nil + } + + result := make([]*pb.RPC_SubOpts, 0, len(accept)) + for _, sub := range accept { + result = append(result, sub) + } + + return result +} + +// WrapLimitSubscriptionFilter wraps a subscription filter with a hard limit in the number of +// subscriptions allowed in an RPC message. +func WrapLimitSubscriptionFilter(filter SubscriptionFilter, limit int) SubscriptionFilter { + return &limitSubscriptionFilter{filter: filter, limit: limit} +} + +type limitSubscriptionFilter struct { + filter SubscriptionFilter + limit int +} + +var _ SubscriptionFilter = (*limitSubscriptionFilter)(nil) + +func (f *limitSubscriptionFilter) CanSubscribe(topic string) bool { + return f.filter.CanSubscribe(topic) +} + +func (f *limitSubscriptionFilter) FilterIncomingSubscriptions(from peer.ID, subs []*pb.RPC_SubOpts) ([]*pb.RPC_SubOpts, error) { + if len(subs) > f.limit { + return nil, ErrTooManySubscriptions + } + + return f.filter.FilterIncomingSubscriptions(from, subs) +} diff --git a/subscription_filter_test.go b/subscription_filter_test.go new file mode 100644 index 0000000..a241371 --- /dev/null +++ b/subscription_filter_test.go @@ -0,0 +1,210 @@ +package pubsub + +import ( + "context" + "regexp" + "testing" + "time" + + pb "github.com/libp2p/go-libp2p-pubsub/pb" + + "github.com/libp2p/go-libp2p-core/peer" +) + +func TestBasicSubscriptionFilter(t *testing.T) { + peerA := peer.ID("A") + + topic1 := "test1" + topic2 := "test2" + topic3 := "test3" + yes := true + subs := []*pb.RPC_SubOpts{ + &pb.RPC_SubOpts{ + Topicid: &topic1, + Subscribe: &yes, + }, + &pb.RPC_SubOpts{ + Topicid: &topic2, + Subscribe: &yes, + }, + &pb.RPC_SubOpts{ + Topicid: &topic3, + Subscribe: &yes, + }, + } + + filter := NewAllowlistSubscriptionFilter(topic1, topic2) + canSubscribe := filter.CanSubscribe(topic1) + if !canSubscribe { + t.Fatal("expected allowed subscription") + } + canSubscribe = filter.CanSubscribe(topic2) + if !canSubscribe { + t.Fatal("expected allowed subscription") + } + canSubscribe = filter.CanSubscribe(topic3) + if canSubscribe { + t.Fatal("expected disallowed subscription") + } + allowedSubs, err := filter.FilterIncomingSubscriptions(peerA, subs) + if err != nil { + t.Fatal(err) + } + if len(allowedSubs) != 2 { + t.Fatalf("expected 2 allowed subscriptions but got %d", len(allowedSubs)) + } + for _, sub := range allowedSubs { + if sub.GetTopicid() == topic3 { + t.Fatal("unpexted subscription to test3") + } + } + + limitFilter := WrapLimitSubscriptionFilter(filter, 2) + _, err = limitFilter.FilterIncomingSubscriptions(peerA, subs) + if err != ErrTooManySubscriptions { + t.Fatal("expected rejection because of too many subscriptions") + } + + filter = NewRegexpSubscriptionFilter(regexp.MustCompile("test[12]")) + canSubscribe = filter.CanSubscribe(topic1) + if !canSubscribe { + t.Fatal("expected allowed subscription") + } + canSubscribe = filter.CanSubscribe(topic2) + if !canSubscribe { + t.Fatal("expected allowed subscription") + } + canSubscribe = filter.CanSubscribe(topic3) + if canSubscribe { + t.Fatal("expected disallowed subscription") + } + allowedSubs, err = filter.FilterIncomingSubscriptions(peerA, subs) + if err != nil { + t.Fatal(err) + } + if len(allowedSubs) != 2 { + t.Fatalf("expected 2 allowed subscriptions but got %d", len(allowedSubs)) + } + for _, sub := range allowedSubs { + if sub.GetTopicid() == topic3 { + t.Fatal("unexpected subscription") + } + } + + limitFilter = WrapLimitSubscriptionFilter(filter, 2) + _, err = limitFilter.FilterIncomingSubscriptions(peerA, subs) + if err != ErrTooManySubscriptions { + t.Fatal("expected rejection because of too many subscriptions") + } + +} + +func TestSubscriptionFilterDeduplication(t *testing.T) { + peerA := peer.ID("A") + + topic1 := "test1" + topic2 := "test2" + topic3 := "test3" + yes := true + no := false + subs := []*pb.RPC_SubOpts{ + &pb.RPC_SubOpts{ + Topicid: &topic1, + Subscribe: &yes, + }, + &pb.RPC_SubOpts{ + Topicid: &topic1, + Subscribe: &yes, + }, + + &pb.RPC_SubOpts{ + Topicid: &topic2, + Subscribe: &yes, + }, + &pb.RPC_SubOpts{ + Topicid: &topic2, + Subscribe: &no, + }, + &pb.RPC_SubOpts{ + Topicid: &topic3, + Subscribe: &yes, + }, + } + + filter := NewAllowlistSubscriptionFilter(topic1, topic2) + allowedSubs, err := filter.FilterIncomingSubscriptions(peerA, subs) + if err != nil { + t.Fatal(err) + } + if len(allowedSubs) != 1 { + t.Fatalf("expected 2 allowed subscriptions but got %d", len(allowedSubs)) + } + for _, sub := range allowedSubs { + if sub.GetTopicid() == topic3 || sub.GetTopicid() == topic2 { + t.Fatal("unexpected subscription") + } + } +} + +func TestSubscriptionFilterRPC(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getNetHosts(t, ctx, 2) + ps1 := getPubsub(ctx, hosts[0], WithSubscriptionFilter(NewAllowlistSubscriptionFilter("test1", "test2"))) + ps2 := getPubsub(ctx, hosts[1], WithSubscriptionFilter(NewAllowlistSubscriptionFilter("test2", "test3"))) + + _ = mustSubscribe(t, ps1, "test1") + _ = mustSubscribe(t, ps1, "test2") + _ = mustSubscribe(t, ps2, "test2") + _ = mustSubscribe(t, ps2, "test3") + + // check the rejection as well + _, err := ps1.Join("test3") + if err == nil { + t.Fatal("expected subscription error") + } + + connect(t, hosts[0], hosts[1]) + + time.Sleep(time.Second) + + var sub1, sub2, sub3 bool + ready := make(chan struct{}) + + ps1.eval <- func() { + _, sub1 = ps1.topics["test1"][hosts[1].ID()] + _, sub2 = ps1.topics["test2"][hosts[1].ID()] + _, sub3 = ps1.topics["test3"][hosts[1].ID()] + ready <- struct{}{} + } + <-ready + + if sub1 { + t.Fatal("expected no subscription for test1") + } + if !sub2 { + t.Fatal("expected subscription for test2") + } + if sub3 { + t.Fatal("expected no subscription for test1") + } + + ps2.eval <- func() { + _, sub1 = ps2.topics["test1"][hosts[0].ID()] + _, sub2 = ps2.topics["test2"][hosts[0].ID()] + _, sub3 = ps2.topics["test3"][hosts[0].ID()] + ready <- struct{}{} + } + <-ready + + if sub1 { + t.Fatal("expected no subscription for test1") + } + if !sub2 { + t.Fatal("expected subscription for test1") + } + if sub3 { + t.Fatal("expected no subscription for test1") + } +}