From 647bb9865583b0557f4a6851ccf7ae3910cb94c4 Mon Sep 17 00:00:00 2001 From: keks Date: Wed, 8 Nov 2017 20:00:52 +0100 Subject: [PATCH 01/27] optionally allow caller to validate messages --- floodsub.go | 113 +++++++++++++++++++++++++++++++++++++++-------- floodsub_test.go | 43 +++++++++++++++++- subscription.go | 1 + 3 files changed, 137 insertions(+), 20 deletions(-) diff --git a/floodsub.go b/floodsub.go index 52a93b9..8bfa79c 100644 --- a/floodsub.go +++ b/floodsub.go @@ -54,6 +54,9 @@ type PubSub struct { // topics tracks which topics each of our peers are subscribed to topics map[string]map[peer.ID]struct{} + // sendMsg handles messages that have been validated + sendMsg chan sendReq + peers map[peer.ID]chan *RPC seenMessages *timecache.TimeCache @@ -91,6 +94,7 @@ func NewFloodSub(ctx context.Context, h host.Host) *PubSub { getPeers: make(chan *listPeerReq), addSub: make(chan *addSubReq), getTopics: make(chan *topicReq), + sendMsg: make(chan sendReq), myTopics: make(map[string]map[*Subscription]struct{}), topics: make(map[string]map[peer.ID]struct{}), peers: make(map[peer.ID]chan *RPC), @@ -176,7 +180,19 @@ func (p *PubSub) processLoop(ctx context.Context) { continue } case msg := <-p.publish: - p.maybePublishMessage(p.host.ID(), msg.Message) + subs := p.getSubscriptions(msg) // call before goroutine! + go func() { + if p.validate(subs, msg) { + p.sendMsg <- sendReq{ + from: p.host.ID(), + msg: msg, + } + + } + }() + case req := <-p.sendMsg: + p.maybePublishMessage(req.from, req.msg.Message) + case <-ctx.Done(): log.Info("pubsub processloop shutting down") return @@ -210,24 +226,22 @@ func (p *PubSub) handleRemoveSubscription(sub *Subscription) { // subscribes to the topic. // Only called from processLoop. func (p *PubSub) handleAddSubscription(req *addSubReq) { - subs := p.myTopics[req.topic] + sub := req.sub + subs := p.myTopics[sub.topic] // announce we want this topic if len(subs) == 0 { - p.announce(req.topic, true) + p.announce(sub.topic, true) } // make new if not there if subs == nil { - p.myTopics[req.topic] = make(map[*Subscription]struct{}) - subs = p.myTopics[req.topic] + p.myTopics[sub.topic] = make(map[*Subscription]struct{}) + subs = p.myTopics[sub.topic] } - sub := &Subscription{ - ch: make(chan *Message, 32), - topic: req.topic, - cancelCh: p.cancelCh, - } + sub.ch = make(chan *Message, 32) + sub.cancelCh = p.cancelCh p.myTopics[sub.topic][sub] = struct{}{} @@ -314,7 +328,15 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error { continue } - p.maybePublishMessage(rpc.from, pmsg) + subs := p.getSubscriptions(&Message{pmsg}) // call before goroutine! + go func() { + if p.validate(subs, &Message{pmsg}) { + p.sendMsg <- sendReq{ + from: rpc.from, + msg: &*Message{pmsg}, + } + } + }() } return nil } @@ -324,6 +346,17 @@ func msgID(pmsg *pb.Message) string { return string(pmsg.GetFrom()) + string(pmsg.GetSeqno()) } +// validate is called in a goroutine and calls the validate functions of all subs with msg as parameter. +func (p *PubSub) validate(subs []*Subscription, msg *Message) bool { + for _, sub := range subs { + if sub.validate != nil && !sub.validate(msg) { + return false + } + } + + return true +} + func (p *PubSub) maybePublishMessage(from peer.ID, pmsg *pb.Message) { id := msgID(pmsg) if p.seenMessage(id) { @@ -375,20 +408,47 @@ func (p *PubSub) publishMessage(from peer.ID, msg *pb.Message) error { return nil } +// getSubscriptions returns all subscriptions the would receive the given message. +func (p *PubSub) getSubscriptions(msg *Message) []*Subscription { + var subs []*Subscription + + for _, topic := range msg.GetTopicIDs() { + tSubs, ok := p.myTopics[topic] + if !ok { + continue + } + + for sub := range tSubs { + subs = append(subs, sub) + } + } + + return subs +} + type addSubReq struct { - topic string - resp chan *Subscription + sub *Subscription + resp chan *Subscription +} + +// WithValidator is an option that can be supplied to Subscribe. The argument is a function that returns whether or not a given message should be propagated further. +func WithValidator(validate func(*Message) bool) func(*Subscription) error { + return func(sub *Subscription) error { + sub.validate = validate + return nil + } + } // Subscribe returns a new Subscription for the given topic -func (p *PubSub) Subscribe(topic string) (*Subscription, error) { +func (p *PubSub) Subscribe(topic string, opts ...func(*Subscription) error) (*Subscription, error) { td := pb.TopicDescriptor{Name: &topic} - return p.SubscribeByTopicDescriptor(&td) + return p.SubscribeByTopicDescriptor(&td, opts...) } // SubscribeByTopicDescriptor lets you subscribe a topic using a pb.TopicDescriptor -func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor) (*Subscription, error) { +func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...func(*Subscription) error) (*Subscription, error) { if td.GetAuth().GetMode() != pb.TopicDescriptor_AuthOpts_NONE { return nil, fmt.Errorf("auth mode not yet supported") } @@ -397,10 +457,21 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor) (*Subscripti return nil, fmt.Errorf("encryption mode not yet supported") } + sub := &Subscription{ + topic: td.GetName(), + } + + for _, opt := range opts { + err := opt(sub) + if err != nil { + return nil, err + } + } + out := make(chan *Subscription, 1) p.addSub <- &addSubReq{ - topic: td.GetName(), - resp: out, + sub: sub, + resp: out, } return <-out, nil @@ -439,6 +510,12 @@ type listPeerReq struct { topic string } +// sendReq is a request to call maybePublishMessage. It is issued after the subscription verification is done. +type sendReq struct { + from peer.ID + msg *Message +} + // ListPeers returns a list of peers we are connected to. func (p *PubSub) ListPeers(topic string) []peer.ID { out := make(chan []peer.ID) diff --git a/floodsub_test.go b/floodsub_test.go index 51ee0b8..d520680 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -323,14 +323,53 @@ func TestOneToOne(t *testing.T) { connect(t, hosts[0], hosts[1]) - ch, err := psubs[1].Subscribe("foobar") + sub, err := psubs[1].Subscribe("foobar") if err != nil { t.Fatal(err) } time.Sleep(time.Millisecond * 50) - checkMessageRouting(t, "foobar", psubs, []*Subscription{ch}) + checkMessageRouting(t, "foobar", psubs, []*Subscription{sub}) +} + +func TestValidate(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getNetHosts(t, ctx, 2) + psubs := getPubsubs(ctx, hosts) + + connect(t, hosts[0], hosts[1]) + topic := "foobar" + + sub, err := psubs[1].Subscribe(topic, WithValidator(func(msg *Message) bool { + return !bytes.Contains(msg.Data, []byte("illegal")) + })) + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 50) + + data := make([]byte, 16) + rand.Read(data) + + data = append(data, []byte("illegal")...) + + for _, p := range psubs { + err := p.Publish(topic, data) + if err != nil { + t.Fatal(err) + } + + select { + case msg := <-sub.ch: + t.Log(msg) + t.Fatal("expected message validation to filter out the message") + case <-time.After(333 * time.Millisecond): + } + } } func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) { diff --git a/subscription.go b/subscription.go index d6e930c..d7a7b35 100644 --- a/subscription.go +++ b/subscription.go @@ -9,6 +9,7 @@ type Subscription struct { ch chan *Message cancelCh chan<- *Subscription err error + validate func(*Message) bool } func (sub *Subscription) Topic() string { From 930f264a271c5089cdad7e7fec368815f3d51656 Mon Sep 17 00:00:00 2001 From: keks Date: Thu, 16 Nov 2017 11:48:13 +0100 Subject: [PATCH 02/27] typedef subscription options and fix typo --- floodsub.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/floodsub.go b/floodsub.go index 8bfa79c..ab41cee 100644 --- a/floodsub.go +++ b/floodsub.go @@ -333,7 +333,7 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error { if p.validate(subs, &Message{pmsg}) { p.sendMsg <- sendReq{ from: rpc.from, - msg: &*Message{pmsg}, + msg: &Message{pmsg}, } } }() @@ -431,6 +431,8 @@ type addSubReq struct { resp chan *Subscription } +type SubOpt func(*Subscription) error + // WithValidator is an option that can be supplied to Subscribe. The argument is a function that returns whether or not a given message should be propagated further. func WithValidator(validate func(*Message) bool) func(*Subscription) error { return func(sub *Subscription) error { @@ -441,14 +443,14 @@ func WithValidator(validate func(*Message) bool) func(*Subscription) error { } // Subscribe returns a new Subscription for the given topic -func (p *PubSub) Subscribe(topic string, opts ...func(*Subscription) error) (*Subscription, error) { +func (p *PubSub) Subscribe(topic string, opts ...SubOpt) (*Subscription, error) { td := pb.TopicDescriptor{Name: &topic} return p.SubscribeByTopicDescriptor(&td, opts...) } // SubscribeByTopicDescriptor lets you subscribe a topic using a pb.TopicDescriptor -func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...func(*Subscription) error) (*Subscription, error) { +func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubOpt) (*Subscription, error) { if td.GetAuth().GetMode() != pb.TopicDescriptor_AuthOpts_NONE { return nil, fmt.Errorf("auth mode not yet supported") } From 1945f895a2307a43ceb7641d1c1d464c7e85812a Mon Sep 17 00:00:00 2001 From: keks Date: Thu, 16 Nov 2017 13:03:33 +0100 Subject: [PATCH 03/27] log when validator discards message --- floodsub.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/floodsub.go b/floodsub.go index ab41cee..7a22d8d 100644 --- a/floodsub.go +++ b/floodsub.go @@ -187,7 +187,6 @@ func (p *PubSub) processLoop(ctx context.Context) { from: p.host.ID(), msg: msg, } - } }() case req := <-p.sendMsg: @@ -350,6 +349,7 @@ func msgID(pmsg *pb.Message) string { func (p *PubSub) validate(subs []*Subscription, msg *Message) bool { for _, sub := range subs { if sub.validate != nil && !sub.validate(msg) { + log.Debugf("validator for topic %s returned false", sub.topic) return false } } From 7dd4e0bfebe440031880793366c24edbea129305 Mon Sep 17 00:00:00 2001 From: keks Date: Thu, 16 Nov 2017 14:21:21 +0100 Subject: [PATCH 04/27] vet used for range variable inside goroutine, now passed as argument --- comm.go | 6 +++--- floodsub.go | 16 ++++++++-------- floodsub_test.go | 8 ++++---- notify.go | 4 ++-- pb/rpc.pb.go | 2 +- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/comm.go b/comm.go index adda4d4..cde23c5 100644 --- a/comm.go +++ b/comm.go @@ -7,9 +7,9 @@ import ( pb "github.com/libp2p/go-floodsub/pb" - ggio "github.com/gogo/protobuf/io" - proto "github.com/gogo/protobuf/proto" - inet "github.com/libp2p/go-libp2p-net" + inet "gx/ipfs/QmNa31VPzC561NWwRsJLE7nGYZYuuD2QfpK2b1q9BK54J1/go-libp2p-net" + ggio "gx/ipfs/QmZ4Qi3GaRbjcx28Sme5eMH7RQjGkt8wHxt2a65oLaeFEV/gogo-protobuf/io" + proto "gx/ipfs/QmZ4Qi3GaRbjcx28Sme5eMH7RQjGkt8wHxt2a65oLaeFEV/gogo-protobuf/proto" ) // get the initial RPC containing all of our subscriptions to send to new peers diff --git a/floodsub.go b/floodsub.go index 7a22d8d..244c68f 100644 --- a/floodsub.go +++ b/floodsub.go @@ -9,12 +9,12 @@ import ( pb "github.com/libp2p/go-floodsub/pb" - logging "github.com/ipfs/go-log" - host "github.com/libp2p/go-libp2p-host" - inet "github.com/libp2p/go-libp2p-net" - peer "github.com/libp2p/go-libp2p-peer" - protocol "github.com/libp2p/go-libp2p-protocol" - timecache "github.com/whyrusleeping/timecache" + inet "gx/ipfs/QmNa31VPzC561NWwRsJLE7nGYZYuuD2QfpK2b1q9BK54J1/go-libp2p-net" + logging "gx/ipfs/QmSpJByNKFX1sCsHBEp3R73FL4NF6FnQTEGyNAXHm2GS52/go-log" + peer "gx/ipfs/QmXYjuNuxVzXKJCfWasQk1RqkhVLDM9jtUKhqc2WPQmFSB/go-libp2p-peer" + timecache "gx/ipfs/QmYftoT56eEfUBTD3erR6heXuPSUhGRezSmhSU8LeczP8b/timecache" + protocol "gx/ipfs/QmZNkThpqfVXs9GNbexPrfBbXSLNYeKrE7jwFM2oqHbyqN/go-libp2p-protocol" + host "gx/ipfs/Qmc1XhrFEiSeBNn3mpfg6gEuYCt5im2gYmNVmncsvmpeAk/go-libp2p-host" ) const ID = protocol.ID("/floodsub/1.0.0") @@ -328,14 +328,14 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error { } subs := p.getSubscriptions(&Message{pmsg}) // call before goroutine! - go func() { + go func(pmsg *pb.Message) { if p.validate(subs, &Message{pmsg}) { p.sendMsg <- sendReq{ from: rpc.from, msg: &Message{pmsg}, } } - }() + }(pmsg) } return nil } diff --git a/floodsub_test.go b/floodsub_test.go index d520680..c43ecd2 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -9,11 +9,11 @@ import ( "testing" "time" - host "github.com/libp2p/go-libp2p-host" - netutil "github.com/libp2p/go-libp2p-netutil" - peer "github.com/libp2p/go-libp2p-peer" + netutil "gx/ipfs/QmQGX417WoxKxDJeHqouMEmmH4G1RCENNSzkZYHrXy3Xb3/go-libp2p-netutil" + peer "gx/ipfs/QmXYjuNuxVzXKJCfWasQk1RqkhVLDM9jtUKhqc2WPQmFSB/go-libp2p-peer" + host "gx/ipfs/Qmc1XhrFEiSeBNn3mpfg6gEuYCt5im2gYmNVmncsvmpeAk/go-libp2p-host" //bhost "github.com/libp2p/go-libp2p/p2p/host/basic" - bhost "github.com/libp2p/go-libp2p-blankhost" + bhost "gx/ipfs/QmQkeGXc9ZuQ5upVFpd2EjKvgw9aVh1BbtAgvNGVcebmmX/go-libp2p-blankhost" ) func checkMessageRouting(t *testing.T, topic string, pubs []*PubSub, subs []*Subscription) { diff --git a/notify.go b/notify.go index 11cb4e5..19492f6 100644 --- a/notify.go +++ b/notify.go @@ -1,8 +1,8 @@ package floodsub import ( - inet "github.com/libp2p/go-libp2p-net" - ma "github.com/multiformats/go-multiaddr" + inet "gx/ipfs/QmNa31VPzC561NWwRsJLE7nGYZYuuD2QfpK2b1q9BK54J1/go-libp2p-net" + ma "gx/ipfs/QmXY77cVe7rVRQXZZQRioukUM7aRW3BTcAgJe12MCtb3Ji/go-multiaddr" ) var _ inet.Notifiee = (*PubSubNotif)(nil) diff --git a/pb/rpc.pb.go b/pb/rpc.pb.go index a5933c0..feab0c7 100644 --- a/pb/rpc.pb.go +++ b/pb/rpc.pb.go @@ -15,7 +15,7 @@ It has these top-level messages: */ package floodsub_pb -import proto "github.com/gogo/protobuf/proto" +import proto "gx/ipfs/QmZ4Qi3GaRbjcx28Sme5eMH7RQjGkt8wHxt2a65oLaeFEV/gogo-protobuf/proto" import fmt "fmt" import math "math" From 197a5982a4c1f20ba654b03c605edc4d0f03006c Mon Sep 17 00:00:00 2001 From: keks Date: Thu, 16 Nov 2017 16:13:31 +0100 Subject: [PATCH 05/27] ungxify --- comm.go | 6 +++--- floodsub.go | 12 ++++++------ floodsub_test.go | 8 ++++---- notify.go | 4 ++-- pb/rpc.pb.go | 2 +- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/comm.go b/comm.go index cde23c5..adda4d4 100644 --- a/comm.go +++ b/comm.go @@ -7,9 +7,9 @@ import ( pb "github.com/libp2p/go-floodsub/pb" - inet "gx/ipfs/QmNa31VPzC561NWwRsJLE7nGYZYuuD2QfpK2b1q9BK54J1/go-libp2p-net" - ggio "gx/ipfs/QmZ4Qi3GaRbjcx28Sme5eMH7RQjGkt8wHxt2a65oLaeFEV/gogo-protobuf/io" - proto "gx/ipfs/QmZ4Qi3GaRbjcx28Sme5eMH7RQjGkt8wHxt2a65oLaeFEV/gogo-protobuf/proto" + ggio "github.com/gogo/protobuf/io" + proto "github.com/gogo/protobuf/proto" + inet "github.com/libp2p/go-libp2p-net" ) // get the initial RPC containing all of our subscriptions to send to new peers diff --git a/floodsub.go b/floodsub.go index 244c68f..3eec281 100644 --- a/floodsub.go +++ b/floodsub.go @@ -9,12 +9,12 @@ import ( pb "github.com/libp2p/go-floodsub/pb" - inet "gx/ipfs/QmNa31VPzC561NWwRsJLE7nGYZYuuD2QfpK2b1q9BK54J1/go-libp2p-net" - logging "gx/ipfs/QmSpJByNKFX1sCsHBEp3R73FL4NF6FnQTEGyNAXHm2GS52/go-log" - peer "gx/ipfs/QmXYjuNuxVzXKJCfWasQk1RqkhVLDM9jtUKhqc2WPQmFSB/go-libp2p-peer" - timecache "gx/ipfs/QmYftoT56eEfUBTD3erR6heXuPSUhGRezSmhSU8LeczP8b/timecache" - protocol "gx/ipfs/QmZNkThpqfVXs9GNbexPrfBbXSLNYeKrE7jwFM2oqHbyqN/go-libp2p-protocol" - host "gx/ipfs/Qmc1XhrFEiSeBNn3mpfg6gEuYCt5im2gYmNVmncsvmpeAk/go-libp2p-host" + logging "github.com/ipfs/go-log" + host "github.com/libp2p/go-libp2p-host" + inet "github.com/libp2p/go-libp2p-net" + peer "github.com/libp2p/go-libp2p-peer" + protocol "github.com/libp2p/go-libp2p-protocol" + timecache "github.com/whyrusleeping/timecache" ) const ID = protocol.ID("/floodsub/1.0.0") diff --git a/floodsub_test.go b/floodsub_test.go index c43ecd2..d520680 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -9,11 +9,11 @@ import ( "testing" "time" - netutil "gx/ipfs/QmQGX417WoxKxDJeHqouMEmmH4G1RCENNSzkZYHrXy3Xb3/go-libp2p-netutil" - peer "gx/ipfs/QmXYjuNuxVzXKJCfWasQk1RqkhVLDM9jtUKhqc2WPQmFSB/go-libp2p-peer" - host "gx/ipfs/Qmc1XhrFEiSeBNn3mpfg6gEuYCt5im2gYmNVmncsvmpeAk/go-libp2p-host" + host "github.com/libp2p/go-libp2p-host" + netutil "github.com/libp2p/go-libp2p-netutil" + peer "github.com/libp2p/go-libp2p-peer" //bhost "github.com/libp2p/go-libp2p/p2p/host/basic" - bhost "gx/ipfs/QmQkeGXc9ZuQ5upVFpd2EjKvgw9aVh1BbtAgvNGVcebmmX/go-libp2p-blankhost" + bhost "github.com/libp2p/go-libp2p-blankhost" ) func checkMessageRouting(t *testing.T, topic string, pubs []*PubSub, subs []*Subscription) { diff --git a/notify.go b/notify.go index 19492f6..11cb4e5 100644 --- a/notify.go +++ b/notify.go @@ -1,8 +1,8 @@ package floodsub import ( - inet "gx/ipfs/QmNa31VPzC561NWwRsJLE7nGYZYuuD2QfpK2b1q9BK54J1/go-libp2p-net" - ma "gx/ipfs/QmXY77cVe7rVRQXZZQRioukUM7aRW3BTcAgJe12MCtb3Ji/go-multiaddr" + inet "github.com/libp2p/go-libp2p-net" + ma "github.com/multiformats/go-multiaddr" ) var _ inet.Notifiee = (*PubSubNotif)(nil) diff --git a/pb/rpc.pb.go b/pb/rpc.pb.go index feab0c7..a5933c0 100644 --- a/pb/rpc.pb.go +++ b/pb/rpc.pb.go @@ -15,7 +15,7 @@ It has these top-level messages: */ package floodsub_pb -import proto "gx/ipfs/QmZ4Qi3GaRbjcx28Sme5eMH7RQjGkt8wHxt2a65oLaeFEV/gogo-protobuf/proto" +import proto "github.com/gogo/protobuf/proto" import fmt "fmt" import math "math" From 89e6a06f3c79b47da1f1c0e35cea7e3144466f03 Mon Sep 17 00:00:00 2001 From: keks Date: Wed, 22 Nov 2017 17:37:01 +0100 Subject: [PATCH 06/27] better tests for validation --- floodsub_test.go | 40 ++++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/floodsub_test.go b/floodsub_test.go index d520680..6393f93 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -352,22 +352,34 @@ func TestValidate(t *testing.T) { time.Sleep(time.Millisecond * 50) - data := make([]byte, 16) - rand.Read(data) + msgs := []struct { + msg []byte + validates bool + }{ + {msg: []byte("this is a legal message"), validates: true}, + {msg: []byte("there also is nothing controversial about this message"), validates: true}, + {msg: []byte("openly illegal content will be censored"), validates: false}, + {msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true}, + } - data = append(data, []byte("illegal")...) + for _, tc := range msgs { + for _, p := range psubs { + err := p.Publish(topic, tc.msg) + if err != nil { + t.Fatal(err) + } - for _, p := range psubs { - err := p.Publish(topic, data) - if err != nil { - t.Fatal(err) - } - - select { - case msg := <-sub.ch: - t.Log(msg) - t.Fatal("expected message validation to filter out the message") - case <-time.After(333 * time.Millisecond): + select { + case msg := <-sub.ch: + if !tc.validates { + t.Log(msg) + t.Error("expected message validation to filter out the message") + } + case <-time.After(333 * time.Millisecond): + if tc.validates { + t.Error("expected message validation to accept the message") + } + } } } } From 02877cda7199f6733a6c1331c99207764f4e5342 Mon Sep 17 00:00:00 2001 From: keks Date: Wed, 22 Nov 2017 18:40:45 +0100 Subject: [PATCH 07/27] complete validator functions - make validators time out after 100ms - add context param to validator functions - add type Validator func(context.Context, *Message) bool - drop message if more than 10 messages are already being validated --- floodsub.go | 95 +++++++++++++++++++++++------------ floodsub_test.go | 128 ++++++++++++++++++++++++++++++++++++++++++++++- subscription.go | 2 +- 3 files changed, 190 insertions(+), 35 deletions(-) diff --git a/floodsub.go b/floodsub.go index 3eec281..802393d 100644 --- a/floodsub.go +++ b/floodsub.go @@ -17,7 +17,11 @@ import ( timecache "github.com/whyrusleeping/timecache" ) -const ID = protocol.ID("/floodsub/1.0.0") +const ( + ID = protocol.ID("/floodsub/1.0.0") + maxConcurrency = 10 + validateTimeoutMillis = 100 +) var log = logging.Logger("floodsub") @@ -57,6 +61,9 @@ type PubSub struct { // sendMsg handles messages that have been validated sendMsg chan sendReq + // throttleValidate bounds the number of goroutines concurrently validating messages + throttleValidate chan struct{} + peers map[peer.ID]chan *RPC seenMessages *timecache.TimeCache @@ -84,22 +91,23 @@ type RPC struct { // NewFloodSub returns a new FloodSub management object func NewFloodSub(ctx context.Context, h host.Host) *PubSub { ps := &PubSub{ - host: h, - ctx: ctx, - incoming: make(chan *RPC, 32), - publish: make(chan *Message), - newPeers: make(chan inet.Stream), - peerDead: make(chan peer.ID), - cancelCh: make(chan *Subscription), - getPeers: make(chan *listPeerReq), - addSub: make(chan *addSubReq), - getTopics: make(chan *topicReq), - sendMsg: make(chan sendReq), - myTopics: make(map[string]map[*Subscription]struct{}), - topics: make(map[string]map[peer.ID]struct{}), - peers: make(map[peer.ID]chan *RPC), - seenMessages: timecache.NewTimeCache(time.Second * 30), - counter: uint64(time.Now().UnixNano()), + host: h, + ctx: ctx, + incoming: make(chan *RPC, 32), + publish: make(chan *Message), + newPeers: make(chan inet.Stream), + peerDead: make(chan peer.ID), + cancelCh: make(chan *Subscription), + getPeers: make(chan *listPeerReq), + addSub: make(chan *addSubReq), + getTopics: make(chan *topicReq), + sendMsg: make(chan sendReq), + myTopics: make(map[string]map[*Subscription]struct{}), + topics: make(map[string]map[peer.ID]struct{}), + peers: make(map[peer.ID]chan *RPC), + seenMessages: timecache.NewTimeCache(time.Second * 30), + counter: uint64(time.Now().UnixNano()), + throttleValidate: make(chan struct{}, maxConcurrency), } h.SetStreamHandler(ID, ps.handleNewStream) @@ -181,14 +189,23 @@ func (p *PubSub) processLoop(ctx context.Context) { } case msg := <-p.publish: subs := p.getSubscriptions(msg) // call before goroutine! - go func() { - if p.validate(subs, msg) { - p.sendMsg <- sendReq{ - from: p.host.ID(), - msg: msg, + + select { + case p.throttleValidate <- struct{}{}: + go func() { + defer func() { <-p.throttleValidate }() + + if p.validate(subs, msg) { + p.sendMsg <- sendReq{ + from: p.host.ID(), + msg: msg, + } + } - } - }() + }() + default: + log.Warning("could not acquire validator; dropping message") + } case req := <-p.sendMsg: p.maybePublishMessage(req.from, req.msg.Message) @@ -328,14 +345,22 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error { } subs := p.getSubscriptions(&Message{pmsg}) // call before goroutine! - go func(pmsg *pb.Message) { - if p.validate(subs, &Message{pmsg}) { - p.sendMsg <- sendReq{ - from: rpc.from, - msg: &Message{pmsg}, + + select { + case p.throttleValidate <- struct{}{}: + go func(pmsg *pb.Message) { + defer func() { <-p.throttleValidate }() + + if p.validate(subs, &Message{pmsg}) { + p.sendMsg <- sendReq{ + from: rpc.from, + msg: &Message{pmsg}, + } } - } - }(pmsg) + }(pmsg) + default: + log.Warning("could not acquire validator; dropping message") + } } return nil } @@ -348,7 +373,10 @@ func msgID(pmsg *pb.Message) string { // validate is called in a goroutine and calls the validate functions of all subs with msg as parameter. func (p *PubSub) validate(subs []*Subscription, msg *Message) bool { for _, sub := range subs { - if sub.validate != nil && !sub.validate(msg) { + ctx, cancel := context.WithTimeout(p.ctx, validateTimeoutMillis*time.Millisecond) + defer cancel() + + if sub.validate != nil && !sub.validate(ctx, msg) { log.Debugf("validator for topic %s returned false", sub.topic) return false } @@ -432,9 +460,10 @@ type addSubReq struct { } type SubOpt func(*Subscription) error +type Validator func(context.Context, *Message) bool // WithValidator is an option that can be supplied to Subscribe. The argument is a function that returns whether or not a given message should be propagated further. -func WithValidator(validate func(*Message) bool) func(*Subscription) error { +func WithValidator(validate Validator) func(*Subscription) error { return func(sub *Subscription) error { sub.validate = validate return nil diff --git a/floodsub_test.go b/floodsub_test.go index 6393f93..fb1b2b9 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -6,6 +6,7 @@ import ( "fmt" "math/rand" "sort" + "sync" "testing" "time" @@ -343,7 +344,7 @@ func TestValidate(t *testing.T) { connect(t, hosts[0], hosts[1]) topic := "foobar" - sub, err := psubs[1].Subscribe(topic, WithValidator(func(msg *Message) bool { + sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool { return !bytes.Contains(msg.Data, []byte("illegal")) })) if err != nil { @@ -384,6 +385,131 @@ func TestValidate(t *testing.T) { } } +func TestValidateCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getNetHosts(t, ctx, 2) + psubs := getPubsubs(ctx, hosts) + + connect(t, hosts[0], hosts[1]) + topic := "foobar" + + sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool { + <-ctx.Done() + return true + })) + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 50) + + testmsg := []byte("this is a legal message") + validates := true + + p := psubs[0] + + err = p.Publish(topic, testmsg) + if err != nil { + t.Fatal(err) + } + + select { + case msg := <-sub.ch: + if !validates { + t.Log(msg) + t.Error("expected message validation to filter out the message") + } + case <-time.After(333 * time.Millisecond): + if validates { + t.Error("expected message validation to accept the message") + } + } +} + +func TestValidateOverload(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getNetHosts(t, ctx, 2) + psubs := getPubsubs(ctx, hosts) + + connect(t, hosts[0], hosts[1]) + topic := "foobar" + + block := make(chan struct{}) + + sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool { + _, _ = <-block + return true + })) + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 50) + + msgs := []struct { + msg []byte + validates bool + }{ + {msg: []byte("this is a legal message"), validates: true}, + {msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true}, + {msg: []byte("there also is nothing controversial about this message"), validates: true}, + {msg: []byte("also fine"), validates: true}, + {msg: []byte("still, all good"), validates: true}, + {msg: []byte("this is getting boring"), validates: true}, + {msg: []byte("foo"), validates: true}, + {msg: []byte("foobar"), validates: true}, + {msg: []byte("foofoo"), validates: true}, + {msg: []byte("barfoo"), validates: true}, + {msg: []byte("barbar"), validates: false}, + } + + if len(msgs) != maxConcurrency+1 { + t.Fatalf("expected number of messages sent to be maxConcurrency+1. Got %d, expected %d", len(msgs), maxConcurrency+1) + } + + p := psubs[0] + + var wg sync.WaitGroup + wg.Add(1) + go func() { + for _, tc := range msgs { + select { + case msg := <-sub.ch: + if !tc.validates { + t.Log(msg) + t.Error("expected message validation to drop the message because all validator goroutines are taken") + } + case <-time.After(333 * time.Millisecond): + if tc.validates { + t.Error("expected message validation to accept the message") + } + } + } + wg.Done() + }() + + for i, tc := range msgs { + err := p.Publish(topic, tc.msg) + if err != nil { + t.Fatal(err) + } + + // wait a bit to let pubsub's internal state machine start validating the message + time.Sleep(10 * time.Millisecond) + + // unblock validator goroutines after we sent one too many + if i == len(msgs)-1 { + close(block) + } + } + + wg.Wait() +} + func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) { peers := ps.ListPeers("") set := make(map[peer.ID]struct{}) diff --git a/subscription.go b/subscription.go index d7a7b35..a3f97dd 100644 --- a/subscription.go +++ b/subscription.go @@ -9,7 +9,7 @@ type Subscription struct { ch chan *Message cancelCh chan<- *Subscription err error - validate func(*Message) bool + validate Validator } func (sub *Subscription) Topic() string { From 6e8b9f2d5c9e7490179a0702582ec008b92b8858 Mon Sep 17 00:00:00 2001 From: keks Date: Thu, 23 Nov 2017 14:39:14 +0100 Subject: [PATCH 08/27] fix timeout --- floodsub.go | 29 ++++++++++++++++++++--------- floodsub_test.go | 4 ++-- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/floodsub.go b/floodsub.go index 802393d..99441fd 100644 --- a/floodsub.go +++ b/floodsub.go @@ -18,9 +18,9 @@ import ( ) const ( - ID = protocol.ID("/floodsub/1.0.0") - maxConcurrency = 10 - validateTimeoutMillis = 100 + ID = protocol.ID("/floodsub/1.0.0") + maxConcurrency = 10 + validateTimeout = 150 * time.Millisecond ) var log = logging.Logger("floodsub") @@ -192,7 +192,7 @@ func (p *PubSub) processLoop(ctx context.Context) { select { case p.throttleValidate <- struct{}{}: - go func() { + go func(msg *Message) { defer func() { <-p.throttleValidate }() if p.validate(subs, msg) { @@ -202,7 +202,7 @@ func (p *PubSub) processLoop(ctx context.Context) { } } - }() + }(msg) default: log.Warning("could not acquire validator; dropping message") } @@ -373,11 +373,22 @@ func msgID(pmsg *pb.Message) string { // validate is called in a goroutine and calls the validate functions of all subs with msg as parameter. func (p *PubSub) validate(subs []*Subscription, msg *Message) bool { for _, sub := range subs { - ctx, cancel := context.WithTimeout(p.ctx, validateTimeoutMillis*time.Millisecond) + ctx, cancel := context.WithTimeout(p.ctx, validateTimeout) defer cancel() - if sub.validate != nil && !sub.validate(ctx, msg) { - log.Debugf("validator for topic %s returned false", sub.topic) + result := make(chan bool) + go func(sub *Subscription) { + result <- sub.validate == nil || sub.validate(ctx, msg) + }(sub) + + select { + case valid := <-result: + if !valid { + log.Debugf("validator for topic %s returned false", sub.topic) + return false + } + case <-ctx.Done(): + log.Debugf("validator for topic %s timed out. msg: %s", sub.topic, msg) return false } } @@ -409,7 +420,7 @@ func (p *PubSub) publishMessage(from peer.ID, msg *pb.Message) error { continue } - for p, _ := range tmap { + for p := range tmap { tosend[p] = struct{}{} } } diff --git a/floodsub_test.go b/floodsub_test.go index fb1b2b9..57cf31e 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -406,7 +406,7 @@ func TestValidateCancel(t *testing.T) { time.Sleep(time.Millisecond * 50) testmsg := []byte("this is a legal message") - validates := true + validates := false // message for which the validator times our are discarded p := psubs[0] @@ -441,7 +441,7 @@ func TestValidateOverload(t *testing.T) { block := make(chan struct{}) sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool { - _, _ = <-block + <-block return true })) if err != nil { From fe09d1eea3dfa2f469acf4c21de2e57a0688e79c Mon Sep 17 00:00:00 2001 From: keks Date: Thu, 23 Nov 2017 19:12:59 +0100 Subject: [PATCH 09/27] make validator timeout configurable --- floodsub.go | 18 +++++++++++----- floodsub_test.go | 55 ++++++++++++++++++++++++++++++++++++++++++++++++ subscription.go | 5 ++++- 3 files changed, 72 insertions(+), 6 deletions(-) diff --git a/floodsub.go b/floodsub.go index 99441fd..f7c2eb0 100644 --- a/floodsub.go +++ b/floodsub.go @@ -18,9 +18,9 @@ import ( ) const ( - ID = protocol.ID("/floodsub/1.0.0") - maxConcurrency = 10 - validateTimeout = 150 * time.Millisecond + ID = protocol.ID("/floodsub/1.0.0") + maxConcurrency = 10 + defaultValidateTimeout = 150 * time.Millisecond ) var log = logging.Logger("floodsub") @@ -373,7 +373,7 @@ func msgID(pmsg *pb.Message) string { // validate is called in a goroutine and calls the validate functions of all subs with msg as parameter. func (p *PubSub) validate(subs []*Subscription, msg *Message) bool { for _, sub := range subs { - ctx, cancel := context.WithTimeout(p.ctx, validateTimeout) + ctx, cancel := context.WithTimeout(p.ctx, sub.validateTimeout) defer cancel() result := make(chan bool) @@ -479,7 +479,14 @@ func WithValidator(validate Validator) func(*Subscription) error { sub.validate = validate return nil } +} +// WithValidatorTimeout is an option that can be supplied to Subscribe. The argument is a duration after which long-running validators are canceled. +func WithValidatorTimeout(timeout time.Duration) func(*Subscription) error { + return func(sub *Subscription) error { + sub.validateTimeout = timeout + return nil + } } // Subscribe returns a new Subscription for the given topic @@ -500,7 +507,8 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubO } sub := &Subscription{ - topic: td.GetName(), + topic: td.GetName(), + validateTimeout: defaultValidateTimeout, } for _, opt := range opts { diff --git a/floodsub_test.go b/floodsub_test.go index 57cf31e..35a2400 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -385,6 +385,61 @@ func TestValidate(t *testing.T) { } } +func TestValidateTimeout(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getNetHosts(t, ctx, 2) + psubs := getPubsubs(ctx, hosts) + + connect(t, hosts[0], hosts[1]) + topic := "foobar" + + cases := []struct { + timeout time.Duration + msg []byte + validates bool + }{ + {75 * time.Millisecond, []byte("this better time out"), false}, + {150 * time.Millisecond, []byte("this should work"), true}, + } + + for _, tc := range cases { + sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool { + time.Sleep(100 * time.Millisecond) + return true + }), WithValidatorTimeout(tc.timeout)) + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 50) + + p := psubs[0] + err = p.Publish(topic, tc.msg) + if err != nil { + t.Fatal(err) + } + + select { + case msg := <-sub.ch: + if !tc.validates { + t.Log(msg) + t.Error("expected message validation to filter out the message") + } + case <-time.After(333 * time.Millisecond): + if tc.validates { + t.Error("expected message validation to accept the message") + } + } + + // important: cancel! + // otherwise the message will still be filtered by the other subscription + sub.Cancel() + } + +} + func TestValidateCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/subscription.go b/subscription.go index a3f97dd..cc76413 100644 --- a/subscription.go +++ b/subscription.go @@ -2,6 +2,7 @@ package floodsub import ( "context" + "time" ) type Subscription struct { @@ -9,7 +10,9 @@ type Subscription struct { ch chan *Message cancelCh chan<- *Subscription err error - validate Validator + + validate Validator + validateTimeout time.Duration } func (sub *Subscription) Topic() string { From 88274db0bba733a0e8067d7e4933e6a4b1fbfe53 Mon Sep 17 00:00:00 2001 From: keks Date: Sat, 16 Dec 2017 13:12:23 +0100 Subject: [PATCH 10/27] make maximum concurrency configurable, split loop --- floodsub.go | 37 +++++++++-- floodsub_test.go | 163 ++++++++++++++++++++++++++++------------------- 2 files changed, 129 insertions(+), 71 deletions(-) diff --git a/floodsub.go b/floodsub.go index f7c2eb0..70a207c 100644 --- a/floodsub.go +++ b/floodsub.go @@ -19,7 +19,7 @@ import ( const ( ID = protocol.ID("/floodsub/1.0.0") - maxConcurrency = 10 + defaultMaxConcurrency = 10 defaultValidateTimeout = 150 * time.Millisecond ) @@ -88,8 +88,17 @@ type RPC struct { from peer.ID } +type Option func(*PubSub) error + +func WithMaxConcurrency(n int) Option { + return func(ps *PubSub) error { + ps.throttleValidate = make(chan struct{}, n) + return nil + } +} + // NewFloodSub returns a new FloodSub management object -func NewFloodSub(ctx context.Context, h host.Host) *PubSub { +func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, error) { ps := &PubSub{ host: h, ctx: ctx, @@ -110,12 +119,19 @@ func NewFloodSub(ctx context.Context, h host.Host) *PubSub { throttleValidate: make(chan struct{}, maxConcurrency), } + for _, opt := range opts { + err := opt(ps) + if err != nil { + return nil, err + } + } + h.SetStreamHandler(ID, ps.handleNewStream) h.Network().Notify((*PubSubNotif)(ps)) go ps.processLoop(ctx) - return ps + return ps, nil } // processLoop handles all inputs arriving on the channels @@ -372,14 +388,25 @@ func msgID(pmsg *pb.Message) string { // validate is called in a goroutine and calls the validate functions of all subs with msg as parameter. func (p *PubSub) validate(subs []*Subscription, msg *Message) bool { - for _, sub := range subs { + results := make([]chan bool, len(subs)) + ctxs := make([]context.Context, len(subs)) + + for i, sub := range subs { + result := make(chan bool) ctx, cancel := context.WithTimeout(p.ctx, sub.validateTimeout) defer cancel() - result := make(chan bool) + ctxs[i] = ctx + results[i] = result + go func(sub *Subscription) { result <- sub.validate == nil || sub.validate(ctx, msg) }(sub) + } + + for i, sub := range subs { + ctx := ctxs[i] + result := results[i] select { case valid := <-result: diff --git a/floodsub_test.go b/floodsub_test.go index 35a2400..68357ab 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -81,10 +81,14 @@ func connectAll(t *testing.T, hosts []host.Host) { } } -func getPubsubs(ctx context.Context, hs []host.Host) []*PubSub { +func getPubsubs(ctx context.Context, hs []host.Host, opts ...Option) []*PubSub { var psubs []*PubSub for _, h := range hs { - psubs = append(psubs, NewFloodSub(ctx, h)) + ps, err := NewFloodSub(ctx, h, opts...) + if err != nil { + panic(err) + } + psubs = append(psubs, ps) } return psubs } @@ -290,11 +294,14 @@ func TestSelfReceive(t *testing.T) { host := getNetHosts(t, ctx, 1)[0] - psub := NewFloodSub(ctx, host) + psub, err := NewFloodSub(ctx, host) + if err != nil { + t.Fatal(err) + } msg := []byte("hello world") - err := psub.Publish("foobar", msg) + err = psub.Publish("foobar", msg) if err != nil { t.Fatal(err) } @@ -487,82 +494,103 @@ func TestValidateOverload(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hosts := getNetHosts(t, ctx, 2) - psubs := getPubsubs(ctx, hosts) - - connect(t, hosts[0], hosts[1]) - topic := "foobar" - - block := make(chan struct{}) - - sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool { - <-block - return true - })) - if err != nil { - t.Fatal(err) - } - - time.Sleep(time.Millisecond * 50) - - msgs := []struct { + type msg struct { msg []byte validates bool + } + + tcs := []struct { + msgs []msg + + maxConcurrency int }{ - {msg: []byte("this is a legal message"), validates: true}, - {msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true}, - {msg: []byte("there also is nothing controversial about this message"), validates: true}, - {msg: []byte("also fine"), validates: true}, - {msg: []byte("still, all good"), validates: true}, - {msg: []byte("this is getting boring"), validates: true}, - {msg: []byte("foo"), validates: true}, - {msg: []byte("foobar"), validates: true}, - {msg: []byte("foofoo"), validates: true}, - {msg: []byte("barfoo"), validates: true}, - {msg: []byte("barbar"), validates: false}, + { + maxConcurrency: 10, + msgs: []msg{ + {msg: []byte("this is a legal message"), validates: true}, + {msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true}, + {msg: []byte("there also is nothing controversial about this message"), validates: true}, + {msg: []byte("also fine"), validates: true}, + {msg: []byte("still, all good"), validates: true}, + {msg: []byte("this is getting boring"), validates: true}, + {msg: []byte("foo"), validates: true}, + {msg: []byte("foobar"), validates: true}, + {msg: []byte("foofoo"), validates: true}, + {msg: []byte("barfoo"), validates: true}, + {msg: []byte("oh no!"), validates: false}, + }, + }, + { + maxConcurrency: 2, + msgs: []msg{ + {msg: []byte("this is a legal message"), validates: true}, + {msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true}, + {msg: []byte("oh no!"), validates: false}, + }, + }, } - if len(msgs) != maxConcurrency+1 { - t.Fatalf("expected number of messages sent to be maxConcurrency+1. Got %d, expected %d", len(msgs), maxConcurrency+1) - } + for _, tc := range tcs { - p := psubs[0] + hosts := getNetHosts(t, ctx, 2) + psubs := getPubsubs(ctx, hosts, WithMaxConcurrency(tc.maxConcurrency)) - var wg sync.WaitGroup - wg.Add(1) - go func() { - for _, tc := range msgs { - select { - case msg := <-sub.ch: - if !tc.validates { - t.Log(msg) - t.Error("expected message validation to drop the message because all validator goroutines are taken") - } - case <-time.After(333 * time.Millisecond): - if tc.validates { - t.Error("expected message validation to accept the message") - } - } - } - wg.Done() - }() + connect(t, hosts[0], hosts[1]) + topic := "foobar" - for i, tc := range msgs { - err := p.Publish(topic, tc.msg) + block := make(chan struct{}) + + sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool { + <-block + return true + })) if err != nil { t.Fatal(err) } - // wait a bit to let pubsub's internal state machine start validating the message - time.Sleep(10 * time.Millisecond) + time.Sleep(time.Millisecond * 50) - // unblock validator goroutines after we sent one too many - if i == len(msgs)-1 { - close(block) + if len(tc.msgs) != tc.maxConcurrency+1 { + t.Fatalf("expected number of messages sent to be defaultMaxConcurrency+1. Got %d, expected %d", len(tc.msgs), tc.maxConcurrency+1) } - } - wg.Wait() + p := psubs[0] + + var wg sync.WaitGroup + wg.Add(1) + go func() { + for _, tmsg := range tc.msgs { + select { + case msg := <-sub.ch: + if !tmsg.validates { + t.Log(msg) + t.Error("expected message validation to drop the message because all validator goroutines are taken") + } + case <-time.After(333 * time.Millisecond): + if tmsg.validates { + t.Error("expected message validation to accept the message") + } + } + } + wg.Done() + }() + + for i, tmsg := range tc.msgs { + err := p.Publish(topic, tmsg.msg) + if err != nil { + t.Fatal(err) + } + + // wait a bit to let pubsub's internal state machine start validating the message + time.Sleep(10 * time.Millisecond) + + // unblock validator goroutines after we sent one too many + if i == len(tc.msgs)-1 { + close(block) + } + } + wg.Wait() + } } func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) { @@ -646,7 +674,10 @@ func TestSubReporting(t *testing.T) { defer cancel() host := getNetHosts(t, ctx, 1)[0] - psub := NewFloodSub(ctx, host) + psub, err := NewFloodSub(ctx, host) + if err != nil { + t.Fatal(err) + } fooSub, err := psub.Subscribe("foo") if err != nil { From 4241241031cdad175ff8d7b0f6ef7502912b04af Mon Sep 17 00:00:00 2001 From: vyzo Date: Sat, 13 Jan 2018 12:24:31 +0200 Subject: [PATCH 11/27] fix dangling maxConcurrency reference --- floodsub.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/floodsub.go b/floodsub.go index 70a207c..1664418 100644 --- a/floodsub.go +++ b/floodsub.go @@ -116,7 +116,7 @@ func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, err peers: make(map[peer.ID]chan *RPC), seenMessages: timecache.NewTimeCache(time.Second * 30), counter: uint64(time.Now().UnixNano()), - throttleValidate: make(chan struct{}, maxConcurrency), + throttleValidate: make(chan struct{}, defaultMaxConcurrency), } for _, opt := range opts { From d2f6a0050fa1745d0e8ed583a9d138fd5b44186e Mon Sep 17 00:00:00 2001 From: vyzo Date: Sat, 13 Jan 2018 12:33:03 +0200 Subject: [PATCH 12/27] WithValidator and WithValidatorTimeout are subscription options --- floodsub.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/floodsub.go b/floodsub.go index 1664418..907fb47 100644 --- a/floodsub.go +++ b/floodsub.go @@ -501,7 +501,7 @@ type SubOpt func(*Subscription) error type Validator func(context.Context, *Message) bool // WithValidator is an option that can be supplied to Subscribe. The argument is a function that returns whether or not a given message should be propagated further. -func WithValidator(validate Validator) func(*Subscription) error { +func WithValidator(validate Validator) SubOpt { return func(sub *Subscription) error { sub.validate = validate return nil @@ -509,7 +509,7 @@ func WithValidator(validate Validator) func(*Subscription) error { } // WithValidatorTimeout is an option that can be supplied to Subscribe. The argument is a duration after which long-running validators are canceled. -func WithValidatorTimeout(timeout time.Duration) func(*Subscription) error { +func WithValidatorTimeout(timeout time.Duration) SubOpt { return func(sub *Subscription) error { sub.validateTimeout = timeout return nil From 982c4de960ceae9122c4f888abfac58de2f69752 Mon Sep 17 00:00:00 2001 From: vyzo Date: Sat, 13 Jan 2018 14:31:34 +0200 Subject: [PATCH 13/27] per subscription validation throttle and more efficient dispatch logic --- floodsub.go | 197 +++++++++++++++++++++++++---------------------- floodsub_test.go | 13 ++-- subscription.go | 26 ++++++- 3 files changed, 136 insertions(+), 100 deletions(-) diff --git a/floodsub.go b/floodsub.go index 907fb47..35a3cca 100644 --- a/floodsub.go +++ b/floodsub.go @@ -61,9 +61,6 @@ type PubSub struct { // sendMsg handles messages that have been validated sendMsg chan sendReq - // throttleValidate bounds the number of goroutines concurrently validating messages - throttleValidate chan struct{} - peers map[peer.ID]chan *RPC seenMessages *timecache.TimeCache @@ -90,33 +87,25 @@ type RPC struct { type Option func(*PubSub) error -func WithMaxConcurrency(n int) Option { - return func(ps *PubSub) error { - ps.throttleValidate = make(chan struct{}, n) - return nil - } -} - // NewFloodSub returns a new FloodSub management object func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, error) { ps := &PubSub{ - host: h, - ctx: ctx, - incoming: make(chan *RPC, 32), - publish: make(chan *Message), - newPeers: make(chan inet.Stream), - peerDead: make(chan peer.ID), - cancelCh: make(chan *Subscription), - getPeers: make(chan *listPeerReq), - addSub: make(chan *addSubReq), - getTopics: make(chan *topicReq), - sendMsg: make(chan sendReq), - myTopics: make(map[string]map[*Subscription]struct{}), - topics: make(map[string]map[peer.ID]struct{}), - peers: make(map[peer.ID]chan *RPC), - seenMessages: timecache.NewTimeCache(time.Second * 30), - counter: uint64(time.Now().UnixNano()), - throttleValidate: make(chan struct{}, defaultMaxConcurrency), + host: h, + ctx: ctx, + incoming: make(chan *RPC, 32), + publish: make(chan *Message), + newPeers: make(chan inet.Stream), + peerDead: make(chan peer.ID), + cancelCh: make(chan *Subscription), + getPeers: make(chan *listPeerReq), + addSub: make(chan *addSubReq), + getTopics: make(chan *topicReq), + sendMsg: make(chan sendReq), + myTopics: make(map[string]map[*Subscription]struct{}), + topics: make(map[string]map[peer.ID]struct{}), + peers: make(map[peer.ID]chan *RPC), + seenMessages: timecache.NewTimeCache(time.Second * 30), + counter: uint64(time.Now().UnixNano()), } for _, opt := range opts { @@ -204,24 +193,9 @@ func (p *PubSub) processLoop(ctx context.Context) { continue } case msg := <-p.publish: - subs := p.getSubscriptions(msg) // call before goroutine! + subs := p.getSubscriptions(msg) + p.pushMsg(subs, p.host.ID(), msg) - select { - case p.throttleValidate <- struct{}{}: - go func(msg *Message) { - defer func() { <-p.throttleValidate }() - - if p.validate(subs, msg) { - p.sendMsg <- sendReq{ - from: p.host.ID(), - msg: msg, - } - - } - }(msg) - default: - log.Warning("could not acquire validator; dropping message") - } case req := <-p.sendMsg: p.maybePublishMessage(req.from, req.msg.Message) @@ -360,24 +334,11 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error { continue } - subs := p.getSubscriptions(&Message{pmsg}) // call before goroutine! - - select { - case p.throttleValidate <- struct{}{}: - go func(pmsg *pb.Message) { - defer func() { <-p.throttleValidate }() - - if p.validate(subs, &Message{pmsg}) { - p.sendMsg <- sendReq{ - from: rpc.from, - msg: &Message{pmsg}, - } - } - }(pmsg) - default: - log.Warning("could not acquire validator; dropping message") - } + msg := &Message{pmsg} + subs := p.getSubscriptions(msg) + p.pushMsg(subs, rpc.from, msg) } + return nil } @@ -386,41 +347,80 @@ func msgID(pmsg *pb.Message) string { return string(pmsg.GetFrom()) + string(pmsg.GetSeqno()) } -// validate is called in a goroutine and calls the validate functions of all subs with msg as parameter. -func (p *PubSub) validate(subs []*Subscription, msg *Message) bool { - results := make([]chan bool, len(subs)) - ctxs := make([]context.Context, len(subs)) - - for i, sub := range subs { - result := make(chan bool) - ctx, cancel := context.WithTimeout(p.ctx, sub.validateTimeout) - defer cancel() - - ctxs[i] = ctx - results[i] = result - - go func(sub *Subscription) { - result <- sub.validate == nil || sub.validate(ctx, msg) - }(sub) - } - - for i, sub := range subs { - ctx := ctxs[i] - result := results[i] - - select { - case valid := <-result: - if !valid { - log.Debugf("validator for topic %s returned false", sub.topic) - return false - } - case <-ctx.Done(): - log.Debugf("validator for topic %s timed out. msg: %s", sub.topic, msg) - return false +// pushMsg pushes a message to a number of subscriptions, performing validation +// as necessary +func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) { + // we perform validation if _any_ of the subscriptions has a validator + // because the message is sent once for all topics + needval := false + for _, sub := range subs { + if sub.validate != nil { + needval = true + break } } - return true + if !needval { + go func() { + p.sendMsg <- sendReq{ + from: src, + msg: msg, + } + }() + return + } + + // validation is asynchronous + // XXX vyzo: do we want a global validation throttle here? + go p.validate(subs, src, msg) +} + +// validate performs validation and only sends the message if all validators succeed +func (p *PubSub) validate(subs []*Subscription, src peer.ID, msg *Message) { + results := make([]chan bool, 0, len(subs)) + throttle := false + +loop: + for _, sub := range subs { + if sub.validate == nil { + continue + } + + rch := make(chan bool, 1) + results = append(results, rch) + + select { + case sub.validateThrottle <- struct{}{}: + go func(sub *Subscription, msg *Message, rch chan bool) { + rch <- sub.validateMsg(p.ctx, msg) + <-sub.validateThrottle + }(sub, msg, rch) + + default: + log.Debugf("validation throttled for topic %s", sub.topic) + throttle = true + break loop + } + } + + if throttle { + log.Warningf("message validation throttled; dropping message from %s", src) + return + } + + for _, rch := range results { + valid := <-rch + if !valid { + log.Warningf("message validation failed; dropping message from %s", src) + return + } + } + + // all validators were successful, send the message + p.sendMsg <- sendReq{ + from: src, + msg: msg, + } } func (p *PubSub) maybePublishMessage(from peer.ID, pmsg *pb.Message) { @@ -516,6 +516,13 @@ func WithValidatorTimeout(timeout time.Duration) SubOpt { } } +func WithMaxConcurrency(n int) SubOpt { + return func(sub *Subscription) error { + sub.validateThrottle = make(chan struct{}, n) + return nil + } +} + // Subscribe returns a new Subscription for the given topic func (p *PubSub) Subscribe(topic string, opts ...SubOpt) (*Subscription, error) { td := pb.TopicDescriptor{Name: &topic} @@ -545,6 +552,10 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubO } } + if sub.validate != nil && sub.validateThrottle == nil { + sub.validateThrottle = make(chan struct{}, defaultMaxConcurrency) + } + out := make(chan *Subscription, 1) p.addSub <- &addSubReq{ sub: sub, diff --git a/floodsub_test.go b/floodsub_test.go index 68357ab..351f584 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -533,17 +533,20 @@ func TestValidateOverload(t *testing.T) { for _, tc := range tcs { hosts := getNetHosts(t, ctx, 2) - psubs := getPubsubs(ctx, hosts, WithMaxConcurrency(tc.maxConcurrency)) + psubs := getPubsubs(ctx, hosts) connect(t, hosts[0], hosts[1]) topic := "foobar" block := make(chan struct{}) - sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool { - <-block - return true - })) + sub, err := psubs[1].Subscribe(topic, + WithMaxConcurrency(tc.maxConcurrency), + WithValidator(func(ctx context.Context, msg *Message) bool { + <-block + return true + })) + if err != nil { t.Fatal(err) } diff --git a/subscription.go b/subscription.go index cc76413..3aa51c8 100644 --- a/subscription.go +++ b/subscription.go @@ -11,8 +11,9 @@ type Subscription struct { cancelCh chan<- *Subscription err error - validate Validator - validateTimeout time.Duration + validate Validator + validateTimeout time.Duration + validateThrottle chan struct{} } func (sub *Subscription) Topic() string { @@ -35,3 +36,24 @@ func (sub *Subscription) Next(ctx context.Context) (*Message, error) { func (sub *Subscription) Cancel() { sub.cancelCh <- sub } + +func (sub *Subscription) validateMsg(ctx context.Context, msg *Message) bool { + result := make(chan bool, 1) + vctx, cancel := context.WithTimeout(ctx, sub.validateTimeout) + defer cancel() + + go func() { + result <- sub.validate(vctx, msg) + }() + + select { + case valid := <-result: + if !valid { + log.Debugf("validation failed for topic %s", sub.topic) + } + return valid + case <-vctx.Done(): + log.Debugf("validation timeout for topic %s", sub.topic) + return false + } +} From fba445bc6ddbf4ee0caf854db45e636d0221db04 Mon Sep 17 00:00:00 2001 From: vyzo Date: Sat, 13 Jan 2018 14:44:33 +0200 Subject: [PATCH 14/27] code cosmetics reword pushMsg for less indentation nesting. --- floodsub.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/floodsub.go b/floodsub.go index 35a3cca..68341e3 100644 --- a/floodsub.go +++ b/floodsub.go @@ -360,19 +360,19 @@ func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) { } } - if !needval { - go func() { - p.sendMsg <- sendReq{ - from: src, - msg: msg, - } - }() + if needval { + // validation is asynchronous + // XXX vyzo: do we want a global validation throttle here? + go p.validate(subs, src, msg) return } - // validation is asynchronous - // XXX vyzo: do we want a global validation throttle here? - go p.validate(subs, src, msg) + go func() { + p.sendMsg <- sendReq{ + from: src, + msg: msg, + } + }() } // validate performs validation and only sends the message if all validators succeed From c95ed2849686a49c978230124e58efcc66dea5d2 Mon Sep 17 00:00:00 2001 From: vyzo Date: Sat, 13 Jan 2018 18:34:00 +0200 Subject: [PATCH 15/27] add validation context for cancelation on aborts --- floodsub.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/floodsub.go b/floodsub.go index 68341e3..3ecfb33 100644 --- a/floodsub.go +++ b/floodsub.go @@ -377,6 +377,9 @@ func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) { // validate performs validation and only sends the message if all validators succeed func (p *PubSub) validate(subs []*Subscription, src peer.ID, msg *Message) { + ctx, cancel := context.WithCancel(p.ctx) + defer cancel() + results := make([]chan bool, 0, len(subs)) throttle := false @@ -391,10 +394,10 @@ loop: select { case sub.validateThrottle <- struct{}{}: - go func(sub *Subscription, msg *Message, rch chan bool) { - rch <- sub.validateMsg(p.ctx, msg) + go func(sub *Subscription, rch chan bool) { + rch <- sub.validateMsg(ctx, msg) <-sub.validateThrottle - }(sub, msg, rch) + }(sub, rch) default: log.Debugf("validation throttled for topic %s", sub.topic) From 5ef13c764e885f4004c71d3f561d93c9d986f979 Mon Sep 17 00:00:00 2001 From: vyzo Date: Sat, 13 Jan 2018 20:11:32 +0200 Subject: [PATCH 16/27] don't always spawn a goroutine for sending a new message --- floodsub.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/floodsub.go b/floodsub.go index 3ecfb33..563f70c 100644 --- a/floodsub.go +++ b/floodsub.go @@ -367,12 +367,14 @@ func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) { return } - go func() { - p.sendMsg <- sendReq{ - from: src, - msg: msg, - } - }() + sreq := sendReq{from: src, msg: msg} + select { + case p.sendMsg <- sreq: + default: + go func() { + p.sendMsg <- sreq + }() + } } // validate performs validation and only sends the message if all validators succeed From bf2151ba5f0851d3f0db5faf7708ae10d7847356 Mon Sep 17 00:00:00 2001 From: vyzo Date: Sat, 13 Jan 2018 20:47:28 +0200 Subject: [PATCH 17/27] the sendMsg channel should yield pointers for consistency --- floodsub.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/floodsub.go b/floodsub.go index 563f70c..627a863 100644 --- a/floodsub.go +++ b/floodsub.go @@ -59,7 +59,7 @@ type PubSub struct { topics map[string]map[peer.ID]struct{} // sendMsg handles messages that have been validated - sendMsg chan sendReq + sendMsg chan *sendReq peers map[peer.ID]chan *RPC seenMessages *timecache.TimeCache @@ -100,7 +100,7 @@ func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, err getPeers: make(chan *listPeerReq), addSub: make(chan *addSubReq), getTopics: make(chan *topicReq), - sendMsg: make(chan sendReq), + sendMsg: make(chan *sendReq), myTopics: make(map[string]map[*Subscription]struct{}), topics: make(map[string]map[peer.ID]struct{}), peers: make(map[peer.ID]chan *RPC), @@ -367,7 +367,7 @@ func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) { return } - sreq := sendReq{from: src, msg: msg} + sreq := &sendReq{from: src, msg: msg} select { case p.sendMsg <- sreq: default: @@ -422,7 +422,7 @@ loop: } // all validators were successful, send the message - p.sendMsg <- sendReq{ + p.sendMsg <- &sendReq{ from: src, msg: msg, } From 856a25c8eb7b740a135d48e78cf0fac41360592c Mon Sep 17 00:00:00 2001 From: vyzo Date: Sat, 13 Jan 2018 20:52:38 +0200 Subject: [PATCH 18/27] WithMaxConcurrency is WithValidatorConcurrency and defaultMaxConcurrency is defaultValidateConcurrency. --- floodsub.go | 10 +++++----- floodsub_test.go | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/floodsub.go b/floodsub.go index 627a863..f32dbf4 100644 --- a/floodsub.go +++ b/floodsub.go @@ -18,9 +18,9 @@ import ( ) const ( - ID = protocol.ID("/floodsub/1.0.0") - defaultMaxConcurrency = 10 - defaultValidateTimeout = 150 * time.Millisecond + ID = protocol.ID("/floodsub/1.0.0") + defaultValidateConcurrency = 10 + defaultValidateTimeout = 150 * time.Millisecond ) var log = logging.Logger("floodsub") @@ -521,7 +521,7 @@ func WithValidatorTimeout(timeout time.Duration) SubOpt { } } -func WithMaxConcurrency(n int) SubOpt { +func WithValidatorConcurrency(n int) SubOpt { return func(sub *Subscription) error { sub.validateThrottle = make(chan struct{}, n) return nil @@ -558,7 +558,7 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubO } if sub.validate != nil && sub.validateThrottle == nil { - sub.validateThrottle = make(chan struct{}, defaultMaxConcurrency) + sub.validateThrottle = make(chan struct{}, defaultValidateConcurrency) } out := make(chan *Subscription, 1) diff --git a/floodsub_test.go b/floodsub_test.go index 351f584..6d58657 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -541,7 +541,7 @@ func TestValidateOverload(t *testing.T) { block := make(chan struct{}) sub, err := psubs[1].Subscribe(topic, - WithMaxConcurrency(tc.maxConcurrency), + WithValidatorConcurrency(tc.maxConcurrency), WithValidator(func(ctx context.Context, msg *Message) bool { <-block return true @@ -554,7 +554,7 @@ func TestValidateOverload(t *testing.T) { time.Sleep(time.Millisecond * 50) if len(tc.msgs) != tc.maxConcurrency+1 { - t.Fatalf("expected number of messages sent to be defaultMaxConcurrency+1. Got %d, expected %d", len(tc.msgs), tc.maxConcurrency+1) + t.Fatalf("expected number of messages sent to be maxConcurrency+1. Got %d, expected %d", len(tc.msgs), tc.maxConcurrency+1) } p := psubs[0] From edcb251ad1476c90ad0c52e675e30589099d460d Mon Sep 17 00:00:00 2001 From: vyzo Date: Sat, 13 Jan 2018 21:15:40 +0200 Subject: [PATCH 19/27] install global validation throttle, use reasonable defaults. --- floodsub.go | 59 +++++++++++++++++++++++++++++++++++------------------ 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/floodsub.go b/floodsub.go index f32dbf4..45ade9c 100644 --- a/floodsub.go +++ b/floodsub.go @@ -19,8 +19,9 @@ import ( const ( ID = protocol.ID("/floodsub/1.0.0") - defaultValidateConcurrency = 10 defaultValidateTimeout = 150 * time.Millisecond + defaultValidateConcurrency = 100 + defaultValidateThrottle = 8192 ) var log = logging.Logger("floodsub") @@ -61,6 +62,9 @@ type PubSub struct { // sendMsg handles messages that have been validated sendMsg chan *sendReq + // validateThrottle limits the number of active validation goroutines + validateThrottle chan struct{} + peers map[peer.ID]chan *RPC seenMessages *timecache.TimeCache @@ -90,22 +94,23 @@ type Option func(*PubSub) error // NewFloodSub returns a new FloodSub management object func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, error) { ps := &PubSub{ - host: h, - ctx: ctx, - incoming: make(chan *RPC, 32), - publish: make(chan *Message), - newPeers: make(chan inet.Stream), - peerDead: make(chan peer.ID), - cancelCh: make(chan *Subscription), - getPeers: make(chan *listPeerReq), - addSub: make(chan *addSubReq), - getTopics: make(chan *topicReq), - sendMsg: make(chan *sendReq), - myTopics: make(map[string]map[*Subscription]struct{}), - topics: make(map[string]map[peer.ID]struct{}), - peers: make(map[peer.ID]chan *RPC), - seenMessages: timecache.NewTimeCache(time.Second * 30), - counter: uint64(time.Now().UnixNano()), + host: h, + ctx: ctx, + incoming: make(chan *RPC, 32), + publish: make(chan *Message), + newPeers: make(chan inet.Stream), + peerDead: make(chan peer.ID), + cancelCh: make(chan *Subscription), + getPeers: make(chan *listPeerReq), + addSub: make(chan *addSubReq), + getTopics: make(chan *topicReq), + sendMsg: make(chan *sendReq), + validateThrottle: make(chan struct{}, defaultValidateThrottle), + myTopics: make(map[string]map[*Subscription]struct{}), + topics: make(map[string]map[peer.ID]struct{}), + peers: make(map[peer.ID]chan *RPC), + seenMessages: timecache.NewTimeCache(time.Second * 30), + counter: uint64(time.Now().UnixNano()), } for _, opt := range opts { @@ -123,6 +128,13 @@ func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, err return ps, nil } +func WithValidateThrottle(n int) Option { + return func(ps *PubSub) error { + ps.validateThrottle = make(chan struct{}, n) + return nil + } +} + // processLoop handles all inputs arriving on the channels func (p *PubSub) processLoop(ctx context.Context) { defer func() { @@ -361,9 +373,16 @@ func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) { } if needval { - // validation is asynchronous - // XXX vyzo: do we want a global validation throttle here? - go p.validate(subs, src, msg) + // validation is asynchronous and globally throttled with the throttleValidate semaphore + select { + case p.validateThrottle <- struct{}{}: + go func() { + p.validate(subs, src, msg) + <-p.validateThrottle + }() + default: + log.Warningf("message validation throttled; dropping message from %s", src) + } return } From 473a5d287349ce8e8506daded9c2a713d81f0dd5 Mon Sep 17 00:00:00 2001 From: vyzo Date: Sat, 13 Jan 2018 21:39:35 +0200 Subject: [PATCH 20/27] sendMsg should have a buffer --- floodsub.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/floodsub.go b/floodsub.go index 45ade9c..8455dad 100644 --- a/floodsub.go +++ b/floodsub.go @@ -104,7 +104,7 @@ func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, err getPeers: make(chan *listPeerReq), addSub: make(chan *addSubReq), getTopics: make(chan *topicReq), - sendMsg: make(chan *sendReq), + sendMsg: make(chan *sendReq, 32), validateThrottle: make(chan struct{}, defaultValidateThrottle), myTopics: make(map[string]map[*Subscription]struct{}), topics: make(map[string]map[peer.ID]struct{}), From f6081fb061b2a59227ed58a5678eb3c849c553c5 Mon Sep 17 00:00:00 2001 From: vyzo Date: Sat, 13 Jan 2018 21:56:57 +0200 Subject: [PATCH 21/27] pushMsg should just call maybePublishMessage when it doesn't need validation --- floodsub.go | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/floodsub.go b/floodsub.go index 8455dad..c3f6aa5 100644 --- a/floodsub.go +++ b/floodsub.go @@ -386,14 +386,7 @@ func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) { return } - sreq := &sendReq{from: src, msg: msg} - select { - case p.sendMsg <- sreq: - default: - go func() { - p.sendMsg <- sreq - }() - } + p.maybePublishMessage(src, msg.Message) } // validate performs validation and only sends the message if all validators succeed From 145a84a33b05c79819aa08b1752475c5c0c23e03 Mon Sep 17 00:00:00 2001 From: vyzo Date: Sat, 13 Jan 2018 22:02:28 +0200 Subject: [PATCH 22/27] use a single channel for all validation results --- floodsub.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/floodsub.go b/floodsub.go index c3f6aa5..5e2fea5 100644 --- a/floodsub.go +++ b/floodsub.go @@ -394,7 +394,8 @@ func (p *PubSub) validate(subs []*Subscription, src peer.ID, msg *Message) { ctx, cancel := context.WithCancel(p.ctx) defer cancel() - results := make([]chan bool, 0, len(subs)) + rch := make(chan bool, len(subs)) + rcount := 0 throttle := false loop: @@ -403,15 +404,14 @@ loop: continue } - rch := make(chan bool, 1) - results = append(results, rch) + rcount++ select { case sub.validateThrottle <- struct{}{}: - go func(sub *Subscription, rch chan bool) { + go func(sub *Subscription) { rch <- sub.validateMsg(ctx, msg) <-sub.validateThrottle - }(sub, rch) + }(sub) default: log.Debugf("validation throttled for topic %s", sub.topic) @@ -425,7 +425,7 @@ loop: return } - for _, rch := range results { + for i := 0; i < rcount; i++ { valid := <-rch if !valid { log.Warningf("message validation failed; dropping message from %s", src) From f1be0f12966b357474b260166482a3892a192601 Mon Sep 17 00:00:00 2001 From: vyzo Date: Sat, 13 Jan 2018 22:14:01 +0200 Subject: [PATCH 23/27] don't spawn an extra goroutine for the validator context --- subscription.go | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/subscription.go b/subscription.go index 3aa51c8..a15466a 100644 --- a/subscription.go +++ b/subscription.go @@ -38,22 +38,13 @@ func (sub *Subscription) Cancel() { } func (sub *Subscription) validateMsg(ctx context.Context, msg *Message) bool { - result := make(chan bool, 1) vctx, cancel := context.WithTimeout(ctx, sub.validateTimeout) defer cancel() - go func() { - result <- sub.validate(vctx, msg) - }() - - select { - case valid := <-result: - if !valid { - log.Debugf("validation failed for topic %s", sub.topic) - } - return valid - case <-vctx.Done(): - log.Debugf("validation timeout for topic %s", sub.topic) - return false + valid := sub.validate(vctx, msg) + if !valid { + log.Debugf("validation failed for topic %s", sub.topic) } + + return valid } From cb365a5feef21ee1b618f8fb6a3652a7c3dbd2e2 Mon Sep 17 00:00:00 2001 From: vyzo Date: Sun, 14 Jan 2018 02:01:42 +0200 Subject: [PATCH 24/27] remove faulty tests --- floodsub_test.go | 98 ------------------------------------------------ 1 file changed, 98 deletions(-) diff --git a/floodsub_test.go b/floodsub_test.go index 6d58657..1bf1847 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -392,104 +392,6 @@ func TestValidate(t *testing.T) { } } -func TestValidateTimeout(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - hosts := getNetHosts(t, ctx, 2) - psubs := getPubsubs(ctx, hosts) - - connect(t, hosts[0], hosts[1]) - topic := "foobar" - - cases := []struct { - timeout time.Duration - msg []byte - validates bool - }{ - {75 * time.Millisecond, []byte("this better time out"), false}, - {150 * time.Millisecond, []byte("this should work"), true}, - } - - for _, tc := range cases { - sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool { - time.Sleep(100 * time.Millisecond) - return true - }), WithValidatorTimeout(tc.timeout)) - if err != nil { - t.Fatal(err) - } - - time.Sleep(time.Millisecond * 50) - - p := psubs[0] - err = p.Publish(topic, tc.msg) - if err != nil { - t.Fatal(err) - } - - select { - case msg := <-sub.ch: - if !tc.validates { - t.Log(msg) - t.Error("expected message validation to filter out the message") - } - case <-time.After(333 * time.Millisecond): - if tc.validates { - t.Error("expected message validation to accept the message") - } - } - - // important: cancel! - // otherwise the message will still be filtered by the other subscription - sub.Cancel() - } - -} - -func TestValidateCancel(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - hosts := getNetHosts(t, ctx, 2) - psubs := getPubsubs(ctx, hosts) - - connect(t, hosts[0], hosts[1]) - topic := "foobar" - - sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool { - <-ctx.Done() - return true - })) - if err != nil { - t.Fatal(err) - } - - time.Sleep(time.Millisecond * 50) - - testmsg := []byte("this is a legal message") - validates := false // message for which the validator times our are discarded - - p := psubs[0] - - err = p.Publish(topic, testmsg) - if err != nil { - t.Fatal(err) - } - - select { - case msg := <-sub.ch: - if !validates { - t.Log(msg) - t.Error("expected message validation to filter out the message") - } - case <-time.After(333 * time.Millisecond): - if validates { - t.Error("expected message validation to accept the message") - } - } -} - func TestValidateOverload(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() From fceb00d2346c7cc4541e66201691259d4ed45e70 Mon Sep 17 00:00:00 2001 From: vyzo Date: Sun, 14 Jan 2018 02:24:13 +0200 Subject: [PATCH 25/27] improved comment about global validation throttle --- floodsub.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/floodsub.go b/floodsub.go index 5e2fea5..2f4df69 100644 --- a/floodsub.go +++ b/floodsub.go @@ -373,7 +373,10 @@ func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) { } if needval { - // validation is asynchronous and globally throttled with the throttleValidate semaphore + // validation is asynchronous and globally throttled with the throttleValidate semaphore. + // the purpose of the global throttle is to bound the goncurrency possible from incoming + // network traffic; each subscription also has an individual throttle to preclude + // slow (or faulty) validators from starving other topics; see validate below. select { case p.validateThrottle <- struct{}{}: go func() { From bbdec3fda2e29840bc68b1ead82ec965576bd968 Mon Sep 17 00:00:00 2001 From: vyzo Date: Thu, 18 Jan 2018 19:12:36 +0200 Subject: [PATCH 26/27] implement per topic validators --- floodsub.go | 204 ++++++++++++++++++++++++++++++----------------- floodsub_test.go | 22 +++-- subscription.go | 17 ---- 3 files changed, 149 insertions(+), 94 deletions(-) diff --git a/floodsub.go b/floodsub.go index 2f4df69..ab2c6b2 100644 --- a/floodsub.go +++ b/floodsub.go @@ -62,6 +62,12 @@ type PubSub struct { // sendMsg handles messages that have been validated sendMsg chan *sendReq + // addVal handles validator registration requests + addVal chan *addValReq + + // topicVals tracks per topic validators + topicVals map[string]*topicVal + // validateThrottle limits the number of active validation goroutines validateThrottle chan struct{} @@ -105,10 +111,12 @@ func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, err addSub: make(chan *addSubReq), getTopics: make(chan *topicReq), sendMsg: make(chan *sendReq, 32), + addVal: make(chan *addValReq), validateThrottle: make(chan struct{}, defaultValidateThrottle), myTopics: make(map[string]map[*Subscription]struct{}), topics: make(map[string]map[peer.ID]struct{}), peers: make(map[peer.ID]chan *RPC), + topicVals: make(map[string]*topicVal), seenMessages: timecache.NewTimeCache(time.Second * 30), counter: uint64(time.Now().UnixNano()), } @@ -205,12 +213,15 @@ func (p *PubSub) processLoop(ctx context.Context) { continue } case msg := <-p.publish: - subs := p.getSubscriptions(msg) - p.pushMsg(subs, p.host.ID(), msg) + vals := p.getValidators(msg) + p.pushMsg(vals, p.host.ID(), msg) case req := <-p.sendMsg: p.maybePublishMessage(req.from, req.msg.Message) + case req := <-p.addVal: + p.addValidator(req) + case <-ctx.Done(): log.Info("pubsub processloop shutting down") return @@ -347,8 +358,8 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error { } msg := &Message{pmsg} - subs := p.getSubscriptions(msg) - p.pushMsg(subs, rpc.from, msg) + vals := p.getValidators(msg) + p.pushMsg(vals, rpc.from, msg) } return nil @@ -359,20 +370,9 @@ func msgID(pmsg *pb.Message) string { return string(pmsg.GetFrom()) + string(pmsg.GetSeqno()) } -// pushMsg pushes a message to a number of subscriptions, performing validation -// as necessary -func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) { - // we perform validation if _any_ of the subscriptions has a validator - // because the message is sent once for all topics - needval := false - for _, sub := range subs { - if sub.validate != nil { - needval = true - break - } - } - - if needval { +// pushMsg pushes a message performing validation as necessary +func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) { + if len(vals) > 0 { // validation is asynchronous and globally throttled with the throttleValidate semaphore. // the purpose of the global throttle is to bound the goncurrency possible from incoming // network traffic; each subscription also has an individual throttle to preclude @@ -380,7 +380,7 @@ func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) { select { case p.validateThrottle <- struct{}{}: go func() { - p.validate(subs, src, msg) + p.validate(vals, src, msg) <-p.validateThrottle }() default: @@ -393,31 +393,27 @@ func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) { } // validate performs validation and only sends the message if all validators succeed -func (p *PubSub) validate(subs []*Subscription, src peer.ID, msg *Message) { +func (p *PubSub) validate(vals []*topicVal, src peer.ID, msg *Message) { ctx, cancel := context.WithCancel(p.ctx) defer cancel() - rch := make(chan bool, len(subs)) + rch := make(chan bool, len(vals)) rcount := 0 throttle := false loop: - for _, sub := range subs { - if sub.validate == nil { - continue - } - + for _, val := range vals { rcount++ select { - case sub.validateThrottle <- struct{}{}: - go func(sub *Subscription) { - rch <- sub.validateMsg(ctx, msg) - <-sub.validateThrottle - }(sub) + case val.validateThrottle <- struct{}{}: + go func(val *topicVal) { + rch <- val.validateMsg(ctx, msg) + <-val.validateThrottle + }(val) default: - log.Debugf("validation throttled for topic %s", sub.topic) + log.Debugf("validation throttled for topic %s", val.topic) throttle = true break loop } @@ -494,22 +490,20 @@ func (p *PubSub) publishMessage(from peer.ID, msg *pb.Message) error { return nil } -// getSubscriptions returns all subscriptions the would receive the given message. -func (p *PubSub) getSubscriptions(msg *Message) []*Subscription { - var subs []*Subscription +// getValidators returns all validators that apply to a given message +func (p *PubSub) getValidators(msg *Message) []*topicVal { + var vals []*topicVal for _, topic := range msg.GetTopicIDs() { - tSubs, ok := p.myTopics[topic] + val, ok := p.topicVals[topic] if !ok { continue } - for sub := range tSubs { - subs = append(subs, sub) - } + vals = append(vals, val) } - return subs + return vals } type addSubReq struct { @@ -517,31 +511,7 @@ type addSubReq struct { resp chan *Subscription } -type SubOpt func(*Subscription) error -type Validator func(context.Context, *Message) bool - -// WithValidator is an option that can be supplied to Subscribe. The argument is a function that returns whether or not a given message should be propagated further. -func WithValidator(validate Validator) SubOpt { - return func(sub *Subscription) error { - sub.validate = validate - return nil - } -} - -// WithValidatorTimeout is an option that can be supplied to Subscribe. The argument is a duration after which long-running validators are canceled. -func WithValidatorTimeout(timeout time.Duration) SubOpt { - return func(sub *Subscription) error { - sub.validateTimeout = timeout - return nil - } -} - -func WithValidatorConcurrency(n int) SubOpt { - return func(sub *Subscription) error { - sub.validateThrottle = make(chan struct{}, n) - return nil - } -} +type SubOpt func(sub *Subscription) error // Subscribe returns a new Subscription for the given topic func (p *PubSub) Subscribe(topic string, opts ...SubOpt) (*Subscription, error) { @@ -561,8 +531,7 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubO } sub := &Subscription{ - topic: td.GetName(), - validateTimeout: defaultValidateTimeout, + topic: td.GetName(), } for _, opt := range opts { @@ -572,10 +541,6 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubO } } - if sub.validate != nil && sub.validateThrottle == nil { - sub.validateThrottle = make(chan struct{}, defaultValidateConcurrency) - } - out := make(chan *Subscription, 1) p.addSub <- &addSubReq{ sub: sub, @@ -633,3 +598,100 @@ func (p *PubSub) ListPeers(topic string) []peer.ID { } return <-out } + +// per topic validators +type addValReq struct { + topic string + validate Validator + timeout time.Duration + throttle int + resp chan error +} + +type topicVal struct { + topic string + validate Validator + validateTimeout time.Duration + validateThrottle chan struct{} +} + +// Validator is a function that validates a message +type Validator func(context.Context, *Message) bool + +// ValidatorOpt is an option for RegisterTopicValidator +type ValidatorOpt func(addVal *addValReq) error + +// WithValidatorTimeout is an option that sets the topic validator timeout +func WithValidatorTimeout(timeout time.Duration) ValidatorOpt { + return func(addVal *addValReq) error { + addVal.timeout = timeout + return nil + } +} + +// WithValidatorConcurrency is an option that sets topic validator throttle +func WithValidatorConcurrency(n int) ValidatorOpt { + return func(addVal *addValReq) error { + addVal.throttle = n + return nil + } +} + +// RegisterTopicValidator registers a validator for topic +func (p *PubSub) RegisterTopicValidator(topic string, val Validator, opts ...ValidatorOpt) error { + addVal := &addValReq{ + topic: topic, + validate: val, + resp: make(chan error, 1), + } + + for _, opt := range opts { + err := opt(addVal) + if err != nil { + return err + } + } + + p.addVal <- addVal + return <-addVal.resp +} + +func (ps *PubSub) addValidator(req *addValReq) { + topic := req.topic + + _, ok := ps.topicVals[topic] + if ok { + req.resp <- fmt.Errorf("Duplicate validator for topic %s", topic) + return + } + + val := &topicVal{ + topic: topic, + validate: req.validate, + validateTimeout: defaultValidateTimeout, + validateThrottle: make(chan struct{}, defaultValidateConcurrency), + } + + if req.timeout > 0 { + val.validateTimeout = req.timeout + } + + if req.throttle > 0 { + val.validateThrottle = make(chan struct{}, req.throttle) + } + + ps.topicVals[topic] = val + req.resp <- nil +} + +func (val *topicVal) validateMsg(ctx context.Context, msg *Message) bool { + vctx, cancel := context.WithTimeout(ctx, val.validateTimeout) + defer cancel() + + valid := val.validate(vctx, msg) + if !valid { + log.Debugf("validation failed for topic %s", val.topic) + } + + return valid +} diff --git a/floodsub_test.go b/floodsub_test.go index 1bf1847..3d6a68a 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -351,9 +351,14 @@ func TestValidate(t *testing.T) { connect(t, hosts[0], hosts[1]) topic := "foobar" - sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool { + err := psubs[1].RegisterTopicValidator(topic, func(ctx context.Context, msg *Message) bool { return !bytes.Contains(msg.Data, []byte("illegal")) - })) + }) + if err != nil { + t.Fatal(err) + } + + sub, err := psubs[1].Subscribe(topic) if err != nil { t.Fatal(err) } @@ -442,17 +447,22 @@ func TestValidateOverload(t *testing.T) { block := make(chan struct{}) - sub, err := psubs[1].Subscribe(topic, - WithValidatorConcurrency(tc.maxConcurrency), - WithValidator(func(ctx context.Context, msg *Message) bool { + err := psubs[1].RegisterTopicValidator(topic, + func(ctx context.Context, msg *Message) bool { <-block return true - })) + }, + WithValidatorConcurrency(tc.maxConcurrency)) if err != nil { t.Fatal(err) } + sub, err := psubs[1].Subscribe(topic) + if err != nil { + t.Fatal(err) + } + time.Sleep(time.Millisecond * 50) if len(tc.msgs) != tc.maxConcurrency+1 { diff --git a/subscription.go b/subscription.go index a15466a..d6e930c 100644 --- a/subscription.go +++ b/subscription.go @@ -2,7 +2,6 @@ package floodsub import ( "context" - "time" ) type Subscription struct { @@ -10,10 +9,6 @@ type Subscription struct { ch chan *Message cancelCh chan<- *Subscription err error - - validate Validator - validateTimeout time.Duration - validateThrottle chan struct{} } func (sub *Subscription) Topic() string { @@ -36,15 +31,3 @@ func (sub *Subscription) Next(ctx context.Context) (*Message, error) { func (sub *Subscription) Cancel() { sub.cancelCh <- sub } - -func (sub *Subscription) validateMsg(ctx context.Context, msg *Message) bool { - vctx, cancel := context.WithTimeout(ctx, sub.validateTimeout) - defer cancel() - - valid := sub.validate(vctx, msg) - if !valid { - log.Debugf("validation failed for topic %s", sub.topic) - } - - return valid -} From 3f4fc21228683abc06950729736bcd1b09dc95ff Mon Sep 17 00:00:00 2001 From: vyzo Date: Thu, 18 Jan 2018 20:09:09 +0200 Subject: [PATCH 27/27] fix comment, subscriptions don't have validators any more. --- floodsub.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/floodsub.go b/floodsub.go index ab2c6b2..3ba607e 100644 --- a/floodsub.go +++ b/floodsub.go @@ -375,7 +375,7 @@ func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) { if len(vals) > 0 { // validation is asynchronous and globally throttled with the throttleValidate semaphore. // the purpose of the global throttle is to bound the goncurrency possible from incoming - // network traffic; each subscription also has an individual throttle to preclude + // network traffic; each validator also has an individual throttle to preclude // slow (or faulty) validators from starving other topics; see validate below. select { case p.validateThrottle <- struct{}{}: