From 647bb9865583b0557f4a6851ccf7ae3910cb94c4 Mon Sep 17 00:00:00 2001 From: keks Date: Wed, 8 Nov 2017 20:00:52 +0100 Subject: [PATCH] optionally allow caller to validate messages --- floodsub.go | 113 +++++++++++++++++++++++++++++++++++++++-------- floodsub_test.go | 43 +++++++++++++++++- subscription.go | 1 + 3 files changed, 137 insertions(+), 20 deletions(-) diff --git a/floodsub.go b/floodsub.go index 52a93b9..8bfa79c 100644 --- a/floodsub.go +++ b/floodsub.go @@ -54,6 +54,9 @@ type PubSub struct { // topics tracks which topics each of our peers are subscribed to topics map[string]map[peer.ID]struct{} + // sendMsg handles messages that have been validated + sendMsg chan sendReq + peers map[peer.ID]chan *RPC seenMessages *timecache.TimeCache @@ -91,6 +94,7 @@ func NewFloodSub(ctx context.Context, h host.Host) *PubSub { getPeers: make(chan *listPeerReq), addSub: make(chan *addSubReq), getTopics: make(chan *topicReq), + sendMsg: make(chan sendReq), myTopics: make(map[string]map[*Subscription]struct{}), topics: make(map[string]map[peer.ID]struct{}), peers: make(map[peer.ID]chan *RPC), @@ -176,7 +180,19 @@ func (p *PubSub) processLoop(ctx context.Context) { continue } case msg := <-p.publish: - p.maybePublishMessage(p.host.ID(), msg.Message) + subs := p.getSubscriptions(msg) // call before goroutine! + go func() { + if p.validate(subs, msg) { + p.sendMsg <- sendReq{ + from: p.host.ID(), + msg: msg, + } + + } + }() + case req := <-p.sendMsg: + p.maybePublishMessage(req.from, req.msg.Message) + case <-ctx.Done(): log.Info("pubsub processloop shutting down") return @@ -210,24 +226,22 @@ func (p *PubSub) handleRemoveSubscription(sub *Subscription) { // subscribes to the topic. // Only called from processLoop. func (p *PubSub) handleAddSubscription(req *addSubReq) { - subs := p.myTopics[req.topic] + sub := req.sub + subs := p.myTopics[sub.topic] // announce we want this topic if len(subs) == 0 { - p.announce(req.topic, true) + p.announce(sub.topic, true) } // make new if not there if subs == nil { - p.myTopics[req.topic] = make(map[*Subscription]struct{}) - subs = p.myTopics[req.topic] + p.myTopics[sub.topic] = make(map[*Subscription]struct{}) + subs = p.myTopics[sub.topic] } - sub := &Subscription{ - ch: make(chan *Message, 32), - topic: req.topic, - cancelCh: p.cancelCh, - } + sub.ch = make(chan *Message, 32) + sub.cancelCh = p.cancelCh p.myTopics[sub.topic][sub] = struct{}{} @@ -314,7 +328,15 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error { continue } - p.maybePublishMessage(rpc.from, pmsg) + subs := p.getSubscriptions(&Message{pmsg}) // call before goroutine! + go func() { + if p.validate(subs, &Message{pmsg}) { + p.sendMsg <- sendReq{ + from: rpc.from, + msg: &*Message{pmsg}, + } + } + }() } return nil } @@ -324,6 +346,17 @@ func msgID(pmsg *pb.Message) string { return string(pmsg.GetFrom()) + string(pmsg.GetSeqno()) } +// validate is called in a goroutine and calls the validate functions of all subs with msg as parameter. +func (p *PubSub) validate(subs []*Subscription, msg *Message) bool { + for _, sub := range subs { + if sub.validate != nil && !sub.validate(msg) { + return false + } + } + + return true +} + func (p *PubSub) maybePublishMessage(from peer.ID, pmsg *pb.Message) { id := msgID(pmsg) if p.seenMessage(id) { @@ -375,20 +408,47 @@ func (p *PubSub) publishMessage(from peer.ID, msg *pb.Message) error { return nil } +// getSubscriptions returns all subscriptions the would receive the given message. +func (p *PubSub) getSubscriptions(msg *Message) []*Subscription { + var subs []*Subscription + + for _, topic := range msg.GetTopicIDs() { + tSubs, ok := p.myTopics[topic] + if !ok { + continue + } + + for sub := range tSubs { + subs = append(subs, sub) + } + } + + return subs +} + type addSubReq struct { - topic string - resp chan *Subscription + sub *Subscription + resp chan *Subscription +} + +// WithValidator is an option that can be supplied to Subscribe. The argument is a function that returns whether or not a given message should be propagated further. +func WithValidator(validate func(*Message) bool) func(*Subscription) error { + return func(sub *Subscription) error { + sub.validate = validate + return nil + } + } // Subscribe returns a new Subscription for the given topic -func (p *PubSub) Subscribe(topic string) (*Subscription, error) { +func (p *PubSub) Subscribe(topic string, opts ...func(*Subscription) error) (*Subscription, error) { td := pb.TopicDescriptor{Name: &topic} - return p.SubscribeByTopicDescriptor(&td) + return p.SubscribeByTopicDescriptor(&td, opts...) } // SubscribeByTopicDescriptor lets you subscribe a topic using a pb.TopicDescriptor -func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor) (*Subscription, error) { +func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...func(*Subscription) error) (*Subscription, error) { if td.GetAuth().GetMode() != pb.TopicDescriptor_AuthOpts_NONE { return nil, fmt.Errorf("auth mode not yet supported") } @@ -397,10 +457,21 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor) (*Subscripti return nil, fmt.Errorf("encryption mode not yet supported") } + sub := &Subscription{ + topic: td.GetName(), + } + + for _, opt := range opts { + err := opt(sub) + if err != nil { + return nil, err + } + } + out := make(chan *Subscription, 1) p.addSub <- &addSubReq{ - topic: td.GetName(), - resp: out, + sub: sub, + resp: out, } return <-out, nil @@ -439,6 +510,12 @@ type listPeerReq struct { topic string } +// sendReq is a request to call maybePublishMessage. It is issued after the subscription verification is done. +type sendReq struct { + from peer.ID + msg *Message +} + // ListPeers returns a list of peers we are connected to. func (p *PubSub) ListPeers(topic string) []peer.ID { out := make(chan []peer.ID) diff --git a/floodsub_test.go b/floodsub_test.go index 51ee0b8..d520680 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -323,14 +323,53 @@ func TestOneToOne(t *testing.T) { connect(t, hosts[0], hosts[1]) - ch, err := psubs[1].Subscribe("foobar") + sub, err := psubs[1].Subscribe("foobar") if err != nil { t.Fatal(err) } time.Sleep(time.Millisecond * 50) - checkMessageRouting(t, "foobar", psubs, []*Subscription{ch}) + checkMessageRouting(t, "foobar", psubs, []*Subscription{sub}) +} + +func TestValidate(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getNetHosts(t, ctx, 2) + psubs := getPubsubs(ctx, hosts) + + connect(t, hosts[0], hosts[1]) + topic := "foobar" + + sub, err := psubs[1].Subscribe(topic, WithValidator(func(msg *Message) bool { + return !bytes.Contains(msg.Data, []byte("illegal")) + })) + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 50) + + data := make([]byte, 16) + rand.Read(data) + + data = append(data, []byte("illegal")...) + + for _, p := range psubs { + err := p.Publish(topic, data) + if err != nil { + t.Fatal(err) + } + + select { + case msg := <-sub.ch: + t.Log(msg) + t.Fatal("expected message validation to filter out the message") + case <-time.After(333 * time.Millisecond): + } + } } func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) { diff --git a/subscription.go b/subscription.go index d6e930c..d7a7b35 100644 --- a/subscription.go +++ b/subscription.go @@ -9,6 +9,7 @@ type Subscription struct { ch chan *Message cancelCh chan<- *Subscription err error + validate func(*Message) bool } func (sub *Subscription) Topic() string {