subscription filters
This commit is contained in:
parent
d6c20b59fc
commit
89f61abf29
20
pubsub.go
20
pubsub.go
|
@ -148,6 +148,9 @@ type PubSub struct {
|
|||
// strict mode rejects all unsigned messages prior to validation
|
||||
signPolicy MessageSignaturePolicy
|
||||
|
||||
// filter for tracking subscriptions in topics of interest; if nil, then we track all subscriptions
|
||||
subFilter SubscriptionFilter
|
||||
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
|
@ -900,8 +903,19 @@ func (p *PubSub) notifyLeave(topic string, pid peer.ID) {
|
|||
func (p *PubSub) handleIncomingRPC(rpc *RPC) {
|
||||
p.tracer.RecvRPC(rpc)
|
||||
|
||||
for _, subopt := range rpc.GetSubscriptions() {
|
||||
subs := rpc.GetSubscriptions()
|
||||
if len(subs) != 0 && p.subFilter != nil {
|
||||
var err error
|
||||
subs, err = p.subFilter.FilterIncomingSubscriptions(rpc.from, subs)
|
||||
if err != nil {
|
||||
log.Debugf("subscription filter error: %s; ignoring RPC", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, subopt := range subs {
|
||||
t := subopt.GetTopicid()
|
||||
|
||||
if subopt.GetSubscribe() {
|
||||
tmap, ok := p.topics[t]
|
||||
if !ok {
|
||||
|
@ -1073,6 +1087,10 @@ func (p *PubSub) Join(topic string, opts ...TopicOpt) (*Topic, error) {
|
|||
// Returns true if the topic was newly created, false otherwise
|
||||
// Can be removed once pubsub.Publish() and pubsub.Subscribe() are removed
|
||||
func (p *PubSub) tryJoin(topic string, opts ...TopicOpt) (*Topic, bool, error) {
|
||||
if p.subFilter != nil && !p.subFilter.CanSubscribe(topic) {
|
||||
return nil, false, fmt.Errorf("topic is not allowed by the subscription filter")
|
||||
}
|
||||
|
||||
t := &Topic{
|
||||
p: p,
|
||||
topic: topic,
|
||||
|
|
|
@ -0,0 +1,149 @@
|
|||
package pubsub
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"regexp"
|
||||
|
||||
pb "github.com/libp2p/go-libp2p-pubsub/pb"
|
||||
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
)
|
||||
|
||||
// ErrTooManySubscriptions may be returned by a SubscriptionFilter to signal that there are too many
|
||||
// subscriptions to process.
|
||||
var ErrTooManySubscriptions = errors.New("too many subscriptions")
|
||||
|
||||
// SubscriptionFilter is a function that tells us whether we are interested in allowing and tracking
|
||||
// subscriptions for a given topic.
|
||||
//
|
||||
// The filter is consulted whenever a subscription notification is received by another peer; if the
|
||||
// filter returns false, then the notification is ignored.
|
||||
//
|
||||
// The filter is also consulted when joining topics; if the filter returns false, then the Join
|
||||
// operation will result in an error.
|
||||
type SubscriptionFilter interface {
|
||||
// CanSubscribe returns true if the topic is of interest and we can subscribe to it
|
||||
CanSubscribe(topic string) bool
|
||||
|
||||
// FilterIncomingSubscriptions is invoked for all RPCs containing subscription notifications.
|
||||
// It should filter only the subscriptions of interest and my return an error if (for instance)
|
||||
// there are too many subscriptions.
|
||||
FilterIncomingSubscriptions(peer.ID, []*pb.RPC_SubOpts) ([]*pb.RPC_SubOpts, error)
|
||||
}
|
||||
|
||||
// WithSubscriptionFilter is a pubsub option that specifies a filter for subscriptions
|
||||
// in topics of interest.
|
||||
func WithSubscriptionFilter(subFilter SubscriptionFilter) Option {
|
||||
return func(ps *PubSub) error {
|
||||
ps.subFilter = subFilter
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewAllowlistSubscriptionFilter creates a subscription filter that only allows explicitly
|
||||
// specified topics for local subscriptions and incoming peer subscriptions.
|
||||
func NewAllowlistSubscriptionFilter(topics ...string) SubscriptionFilter {
|
||||
allow := make(map[string]struct{})
|
||||
for _, topic := range topics {
|
||||
allow[topic] = struct{}{}
|
||||
}
|
||||
|
||||
return &allowlistSubscriptionFilter{allow: allow}
|
||||
}
|
||||
|
||||
type allowlistSubscriptionFilter struct {
|
||||
allow map[string]struct{}
|
||||
}
|
||||
|
||||
var _ SubscriptionFilter = (*allowlistSubscriptionFilter)(nil)
|
||||
|
||||
func (f *allowlistSubscriptionFilter) CanSubscribe(topic string) bool {
|
||||
_, ok := f.allow[topic]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (f *allowlistSubscriptionFilter) FilterIncomingSubscriptions(from peer.ID, subs []*pb.RPC_SubOpts) ([]*pb.RPC_SubOpts, error) {
|
||||
return FilterSubscriptions(subs, f.CanSubscribe), nil
|
||||
}
|
||||
|
||||
// NewRegexpSubscriptionFilter creates a subscription filter that only allows topics that
|
||||
// match a regular expression for local subscriptions and incoming peer subscriptions.
|
||||
//
|
||||
// Warning: the user should take care to match start/end of string in the supplied regular
|
||||
// expression, otherwise the filter might match unwanted topics unexpectedly.
|
||||
func NewRegexpSubscriptionFilter(rx *regexp.Regexp) SubscriptionFilter {
|
||||
return &rxSubscriptionFilter{allow: rx}
|
||||
}
|
||||
|
||||
type rxSubscriptionFilter struct {
|
||||
allow *regexp.Regexp
|
||||
}
|
||||
|
||||
var _ SubscriptionFilter = (*rxSubscriptionFilter)(nil)
|
||||
|
||||
func (f *rxSubscriptionFilter) CanSubscribe(topic string) bool {
|
||||
return f.allow.MatchString(topic)
|
||||
}
|
||||
|
||||
func (f *rxSubscriptionFilter) FilterIncomingSubscriptions(from peer.ID, subs []*pb.RPC_SubOpts) ([]*pb.RPC_SubOpts, error) {
|
||||
return FilterSubscriptions(subs, f.CanSubscribe), nil
|
||||
}
|
||||
|
||||
// FilterSubscriptions filters (and deduplicates) a list of subscriptions.
|
||||
// filter should return true if a topic is of interest.
|
||||
func FilterSubscriptions(subs []*pb.RPC_SubOpts, filter func(string) bool) []*pb.RPC_SubOpts {
|
||||
accept := make(map[string]*pb.RPC_SubOpts)
|
||||
|
||||
for _, sub := range subs {
|
||||
topic := sub.GetTopicid()
|
||||
|
||||
if !filter(topic) {
|
||||
continue
|
||||
}
|
||||
|
||||
otherSub, ok := accept[topic]
|
||||
if ok {
|
||||
if sub.GetSubscribe() != otherSub.GetSubscribe() {
|
||||
delete(accept, topic)
|
||||
}
|
||||
} else {
|
||||
accept[topic] = sub
|
||||
}
|
||||
}
|
||||
|
||||
if len(accept) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]*pb.RPC_SubOpts, 0, len(accept))
|
||||
for _, sub := range accept {
|
||||
result = append(result, sub)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// WrapLimitSubscriptionFilter wraps a subscription filter with a hard limit in the number of
|
||||
// subscriptions allowed in an RPC message.
|
||||
func WrapLimitSubscriptionFilter(filter SubscriptionFilter, limit int) SubscriptionFilter {
|
||||
return &limitSubscriptionFilter{filter: filter, limit: limit}
|
||||
}
|
||||
|
||||
type limitSubscriptionFilter struct {
|
||||
filter SubscriptionFilter
|
||||
limit int
|
||||
}
|
||||
|
||||
var _ SubscriptionFilter = (*limitSubscriptionFilter)(nil)
|
||||
|
||||
func (f *limitSubscriptionFilter) CanSubscribe(topic string) bool {
|
||||
return f.filter.CanSubscribe(topic)
|
||||
}
|
||||
|
||||
func (f *limitSubscriptionFilter) FilterIncomingSubscriptions(from peer.ID, subs []*pb.RPC_SubOpts) ([]*pb.RPC_SubOpts, error) {
|
||||
if len(subs) > f.limit {
|
||||
return nil, ErrTooManySubscriptions
|
||||
}
|
||||
|
||||
return f.filter.FilterIncomingSubscriptions(from, subs)
|
||||
}
|
|
@ -0,0 +1,210 @@
|
|||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
pb "github.com/libp2p/go-libp2p-pubsub/pb"
|
||||
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
)
|
||||
|
||||
func TestBasicSubscriptionFilter(t *testing.T) {
|
||||
peerA := peer.ID("A")
|
||||
|
||||
topic1 := "test1"
|
||||
topic2 := "test2"
|
||||
topic3 := "test3"
|
||||
yes := true
|
||||
subs := []*pb.RPC_SubOpts{
|
||||
&pb.RPC_SubOpts{
|
||||
Topicid: &topic1,
|
||||
Subscribe: &yes,
|
||||
},
|
||||
&pb.RPC_SubOpts{
|
||||
Topicid: &topic2,
|
||||
Subscribe: &yes,
|
||||
},
|
||||
&pb.RPC_SubOpts{
|
||||
Topicid: &topic3,
|
||||
Subscribe: &yes,
|
||||
},
|
||||
}
|
||||
|
||||
filter := NewAllowlistSubscriptionFilter(topic1, topic2)
|
||||
canSubscribe := filter.CanSubscribe(topic1)
|
||||
if !canSubscribe {
|
||||
t.Fatal("expected allowed subscription")
|
||||
}
|
||||
canSubscribe = filter.CanSubscribe(topic2)
|
||||
if !canSubscribe {
|
||||
t.Fatal("expected allowed subscription")
|
||||
}
|
||||
canSubscribe = filter.CanSubscribe(topic3)
|
||||
if canSubscribe {
|
||||
t.Fatal("expected disallowed subscription")
|
||||
}
|
||||
allowedSubs, err := filter.FilterIncomingSubscriptions(peerA, subs)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(allowedSubs) != 2 {
|
||||
t.Fatalf("expected 2 allowed subscriptions but got %d", len(allowedSubs))
|
||||
}
|
||||
for _, sub := range allowedSubs {
|
||||
if sub.GetTopicid() == topic3 {
|
||||
t.Fatal("unpexted subscription to test3")
|
||||
}
|
||||
}
|
||||
|
||||
limitFilter := WrapLimitSubscriptionFilter(filter, 2)
|
||||
_, err = limitFilter.FilterIncomingSubscriptions(peerA, subs)
|
||||
if err != ErrTooManySubscriptions {
|
||||
t.Fatal("expected rejection because of too many subscriptions")
|
||||
}
|
||||
|
||||
filter = NewRegexpSubscriptionFilter(regexp.MustCompile("test[12]"))
|
||||
canSubscribe = filter.CanSubscribe(topic1)
|
||||
if !canSubscribe {
|
||||
t.Fatal("expected allowed subscription")
|
||||
}
|
||||
canSubscribe = filter.CanSubscribe(topic2)
|
||||
if !canSubscribe {
|
||||
t.Fatal("expected allowed subscription")
|
||||
}
|
||||
canSubscribe = filter.CanSubscribe(topic3)
|
||||
if canSubscribe {
|
||||
t.Fatal("expected disallowed subscription")
|
||||
}
|
||||
allowedSubs, err = filter.FilterIncomingSubscriptions(peerA, subs)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(allowedSubs) != 2 {
|
||||
t.Fatalf("expected 2 allowed subscriptions but got %d", len(allowedSubs))
|
||||
}
|
||||
for _, sub := range allowedSubs {
|
||||
if sub.GetTopicid() == topic3 {
|
||||
t.Fatal("unexpected subscription")
|
||||
}
|
||||
}
|
||||
|
||||
limitFilter = WrapLimitSubscriptionFilter(filter, 2)
|
||||
_, err = limitFilter.FilterIncomingSubscriptions(peerA, subs)
|
||||
if err != ErrTooManySubscriptions {
|
||||
t.Fatal("expected rejection because of too many subscriptions")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestSubscriptionFilterDeduplication(t *testing.T) {
|
||||
peerA := peer.ID("A")
|
||||
|
||||
topic1 := "test1"
|
||||
topic2 := "test2"
|
||||
topic3 := "test3"
|
||||
yes := true
|
||||
no := false
|
||||
subs := []*pb.RPC_SubOpts{
|
||||
&pb.RPC_SubOpts{
|
||||
Topicid: &topic1,
|
||||
Subscribe: &yes,
|
||||
},
|
||||
&pb.RPC_SubOpts{
|
||||
Topicid: &topic1,
|
||||
Subscribe: &yes,
|
||||
},
|
||||
|
||||
&pb.RPC_SubOpts{
|
||||
Topicid: &topic2,
|
||||
Subscribe: &yes,
|
||||
},
|
||||
&pb.RPC_SubOpts{
|
||||
Topicid: &topic2,
|
||||
Subscribe: &no,
|
||||
},
|
||||
&pb.RPC_SubOpts{
|
||||
Topicid: &topic3,
|
||||
Subscribe: &yes,
|
||||
},
|
||||
}
|
||||
|
||||
filter := NewAllowlistSubscriptionFilter(topic1, topic2)
|
||||
allowedSubs, err := filter.FilterIncomingSubscriptions(peerA, subs)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(allowedSubs) != 1 {
|
||||
t.Fatalf("expected 2 allowed subscriptions but got %d", len(allowedSubs))
|
||||
}
|
||||
for _, sub := range allowedSubs {
|
||||
if sub.GetTopicid() == topic3 || sub.GetTopicid() == topic2 {
|
||||
t.Fatal("unexpected subscription")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionFilterRPC(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
hosts := getNetHosts(t, ctx, 2)
|
||||
ps1 := getPubsub(ctx, hosts[0], WithSubscriptionFilter(NewAllowlistSubscriptionFilter("test1", "test2")))
|
||||
ps2 := getPubsub(ctx, hosts[1], WithSubscriptionFilter(NewAllowlistSubscriptionFilter("test2", "test3")))
|
||||
|
||||
_ = mustSubscribe(t, ps1, "test1")
|
||||
_ = mustSubscribe(t, ps1, "test2")
|
||||
_ = mustSubscribe(t, ps2, "test2")
|
||||
_ = mustSubscribe(t, ps2, "test3")
|
||||
|
||||
// check the rejection as well
|
||||
_, err := ps1.Join("test3")
|
||||
if err == nil {
|
||||
t.Fatal("expected subscription error")
|
||||
}
|
||||
|
||||
connect(t, hosts[0], hosts[1])
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
var sub1, sub2, sub3 bool
|
||||
ready := make(chan struct{})
|
||||
|
||||
ps1.eval <- func() {
|
||||
_, sub1 = ps1.topics["test1"][hosts[1].ID()]
|
||||
_, sub2 = ps1.topics["test2"][hosts[1].ID()]
|
||||
_, sub3 = ps1.topics["test3"][hosts[1].ID()]
|
||||
ready <- struct{}{}
|
||||
}
|
||||
<-ready
|
||||
|
||||
if sub1 {
|
||||
t.Fatal("expected no subscription for test1")
|
||||
}
|
||||
if !sub2 {
|
||||
t.Fatal("expected subscription for test2")
|
||||
}
|
||||
if sub3 {
|
||||
t.Fatal("expected no subscription for test1")
|
||||
}
|
||||
|
||||
ps2.eval <- func() {
|
||||
_, sub1 = ps2.topics["test1"][hosts[0].ID()]
|
||||
_, sub2 = ps2.topics["test2"][hosts[0].ID()]
|
||||
_, sub3 = ps2.topics["test3"][hosts[0].ID()]
|
||||
ready <- struct{}{}
|
||||
}
|
||||
<-ready
|
||||
|
||||
if sub1 {
|
||||
t.Fatal("expected no subscription for test1")
|
||||
}
|
||||
if !sub2 {
|
||||
t.Fatal("expected subscription for test1")
|
||||
}
|
||||
if sub3 {
|
||||
t.Fatal("expected no subscription for test1")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue