From 02877cda7199f6733a6c1331c99207764f4e5342 Mon Sep 17 00:00:00 2001 From: keks Date: Wed, 22 Nov 2017 18:40:45 +0100 Subject: [PATCH] complete validator functions - make validators time out after 100ms - add context param to validator functions - add type Validator func(context.Context, *Message) bool - drop message if more than 10 messages are already being validated --- floodsub.go | 95 +++++++++++++++++++++++------------ floodsub_test.go | 128 ++++++++++++++++++++++++++++++++++++++++++++++- subscription.go | 2 +- 3 files changed, 190 insertions(+), 35 deletions(-) diff --git a/floodsub.go b/floodsub.go index 3eec281..802393d 100644 --- a/floodsub.go +++ b/floodsub.go @@ -17,7 +17,11 @@ import ( timecache "github.com/whyrusleeping/timecache" ) -const ID = protocol.ID("/floodsub/1.0.0") +const ( + ID = protocol.ID("/floodsub/1.0.0") + maxConcurrency = 10 + validateTimeoutMillis = 100 +) var log = logging.Logger("floodsub") @@ -57,6 +61,9 @@ type PubSub struct { // sendMsg handles messages that have been validated sendMsg chan sendReq + // throttleValidate bounds the number of goroutines concurrently validating messages + throttleValidate chan struct{} + peers map[peer.ID]chan *RPC seenMessages *timecache.TimeCache @@ -84,22 +91,23 @@ type RPC struct { // NewFloodSub returns a new FloodSub management object func NewFloodSub(ctx context.Context, h host.Host) *PubSub { ps := &PubSub{ - host: h, - ctx: ctx, - incoming: make(chan *RPC, 32), - publish: make(chan *Message), - newPeers: make(chan inet.Stream), - peerDead: make(chan peer.ID), - cancelCh: make(chan *Subscription), - getPeers: make(chan *listPeerReq), - addSub: make(chan *addSubReq), - getTopics: make(chan *topicReq), - sendMsg: make(chan sendReq), - myTopics: make(map[string]map[*Subscription]struct{}), - topics: make(map[string]map[peer.ID]struct{}), - peers: make(map[peer.ID]chan *RPC), - seenMessages: timecache.NewTimeCache(time.Second * 30), - counter: uint64(time.Now().UnixNano()), + host: h, + ctx: ctx, + incoming: make(chan *RPC, 32), + publish: make(chan *Message), + newPeers: make(chan inet.Stream), + peerDead: make(chan peer.ID), + cancelCh: make(chan *Subscription), + getPeers: make(chan *listPeerReq), + addSub: make(chan *addSubReq), + getTopics: make(chan *topicReq), + sendMsg: make(chan sendReq), + myTopics: make(map[string]map[*Subscription]struct{}), + topics: make(map[string]map[peer.ID]struct{}), + peers: make(map[peer.ID]chan *RPC), + seenMessages: timecache.NewTimeCache(time.Second * 30), + counter: uint64(time.Now().UnixNano()), + throttleValidate: make(chan struct{}, maxConcurrency), } h.SetStreamHandler(ID, ps.handleNewStream) @@ -181,14 +189,23 @@ func (p *PubSub) processLoop(ctx context.Context) { } case msg := <-p.publish: subs := p.getSubscriptions(msg) // call before goroutine! - go func() { - if p.validate(subs, msg) { - p.sendMsg <- sendReq{ - from: p.host.ID(), - msg: msg, + + select { + case p.throttleValidate <- struct{}{}: + go func() { + defer func() { <-p.throttleValidate }() + + if p.validate(subs, msg) { + p.sendMsg <- sendReq{ + from: p.host.ID(), + msg: msg, + } + } - } - }() + }() + default: + log.Warning("could not acquire validator; dropping message") + } case req := <-p.sendMsg: p.maybePublishMessage(req.from, req.msg.Message) @@ -328,14 +345,22 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error { } subs := p.getSubscriptions(&Message{pmsg}) // call before goroutine! - go func(pmsg *pb.Message) { - if p.validate(subs, &Message{pmsg}) { - p.sendMsg <- sendReq{ - from: rpc.from, - msg: &Message{pmsg}, + + select { + case p.throttleValidate <- struct{}{}: + go func(pmsg *pb.Message) { + defer func() { <-p.throttleValidate }() + + if p.validate(subs, &Message{pmsg}) { + p.sendMsg <- sendReq{ + from: rpc.from, + msg: &Message{pmsg}, + } } - } - }(pmsg) + }(pmsg) + default: + log.Warning("could not acquire validator; dropping message") + } } return nil } @@ -348,7 +373,10 @@ func msgID(pmsg *pb.Message) string { // validate is called in a goroutine and calls the validate functions of all subs with msg as parameter. func (p *PubSub) validate(subs []*Subscription, msg *Message) bool { for _, sub := range subs { - if sub.validate != nil && !sub.validate(msg) { + ctx, cancel := context.WithTimeout(p.ctx, validateTimeoutMillis*time.Millisecond) + defer cancel() + + if sub.validate != nil && !sub.validate(ctx, msg) { log.Debugf("validator for topic %s returned false", sub.topic) return false } @@ -432,9 +460,10 @@ type addSubReq struct { } type SubOpt func(*Subscription) error +type Validator func(context.Context, *Message) bool // WithValidator is an option that can be supplied to Subscribe. The argument is a function that returns whether or not a given message should be propagated further. -func WithValidator(validate func(*Message) bool) func(*Subscription) error { +func WithValidator(validate Validator) func(*Subscription) error { return func(sub *Subscription) error { sub.validate = validate return nil diff --git a/floodsub_test.go b/floodsub_test.go index 6393f93..fb1b2b9 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -6,6 +6,7 @@ import ( "fmt" "math/rand" "sort" + "sync" "testing" "time" @@ -343,7 +344,7 @@ func TestValidate(t *testing.T) { connect(t, hosts[0], hosts[1]) topic := "foobar" - sub, err := psubs[1].Subscribe(topic, WithValidator(func(msg *Message) bool { + sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool { return !bytes.Contains(msg.Data, []byte("illegal")) })) if err != nil { @@ -384,6 +385,131 @@ func TestValidate(t *testing.T) { } } +func TestValidateCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getNetHosts(t, ctx, 2) + psubs := getPubsubs(ctx, hosts) + + connect(t, hosts[0], hosts[1]) + topic := "foobar" + + sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool { + <-ctx.Done() + return true + })) + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 50) + + testmsg := []byte("this is a legal message") + validates := true + + p := psubs[0] + + err = p.Publish(topic, testmsg) + if err != nil { + t.Fatal(err) + } + + select { + case msg := <-sub.ch: + if !validates { + t.Log(msg) + t.Error("expected message validation to filter out the message") + } + case <-time.After(333 * time.Millisecond): + if validates { + t.Error("expected message validation to accept the message") + } + } +} + +func TestValidateOverload(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getNetHosts(t, ctx, 2) + psubs := getPubsubs(ctx, hosts) + + connect(t, hosts[0], hosts[1]) + topic := "foobar" + + block := make(chan struct{}) + + sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool { + _, _ = <-block + return true + })) + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 50) + + msgs := []struct { + msg []byte + validates bool + }{ + {msg: []byte("this is a legal message"), validates: true}, + {msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true}, + {msg: []byte("there also is nothing controversial about this message"), validates: true}, + {msg: []byte("also fine"), validates: true}, + {msg: []byte("still, all good"), validates: true}, + {msg: []byte("this is getting boring"), validates: true}, + {msg: []byte("foo"), validates: true}, + {msg: []byte("foobar"), validates: true}, + {msg: []byte("foofoo"), validates: true}, + {msg: []byte("barfoo"), validates: true}, + {msg: []byte("barbar"), validates: false}, + } + + if len(msgs) != maxConcurrency+1 { + t.Fatalf("expected number of messages sent to be maxConcurrency+1. Got %d, expected %d", len(msgs), maxConcurrency+1) + } + + p := psubs[0] + + var wg sync.WaitGroup + wg.Add(1) + go func() { + for _, tc := range msgs { + select { + case msg := <-sub.ch: + if !tc.validates { + t.Log(msg) + t.Error("expected message validation to drop the message because all validator goroutines are taken") + } + case <-time.After(333 * time.Millisecond): + if tc.validates { + t.Error("expected message validation to accept the message") + } + } + } + wg.Done() + }() + + for i, tc := range msgs { + err := p.Publish(topic, tc.msg) + if err != nil { + t.Fatal(err) + } + + // wait a bit to let pubsub's internal state machine start validating the message + time.Sleep(10 * time.Millisecond) + + // unblock validator goroutines after we sent one too many + if i == len(msgs)-1 { + close(block) + } + } + + wg.Wait() +} + func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) { peers := ps.ListPeers("") set := make(map[peer.ID]struct{}) diff --git a/subscription.go b/subscription.go index d7a7b35..a3f97dd 100644 --- a/subscription.go +++ b/subscription.go @@ -9,7 +9,7 @@ type Subscription struct { ch chan *Message cancelCh chan<- *Subscription err error - validate func(*Message) bool + validate Validator } func (sub *Subscription) Topic() string {