diff --git a/floodsub.go b/floodsub.go index 907fb47..35a3cca 100644 --- a/floodsub.go +++ b/floodsub.go @@ -61,9 +61,6 @@ type PubSub struct { // sendMsg handles messages that have been validated sendMsg chan sendReq - // throttleValidate bounds the number of goroutines concurrently validating messages - throttleValidate chan struct{} - peers map[peer.ID]chan *RPC seenMessages *timecache.TimeCache @@ -90,33 +87,25 @@ type RPC struct { type Option func(*PubSub) error -func WithMaxConcurrency(n int) Option { - return func(ps *PubSub) error { - ps.throttleValidate = make(chan struct{}, n) - return nil - } -} - // NewFloodSub returns a new FloodSub management object func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, error) { ps := &PubSub{ - host: h, - ctx: ctx, - incoming: make(chan *RPC, 32), - publish: make(chan *Message), - newPeers: make(chan inet.Stream), - peerDead: make(chan peer.ID), - cancelCh: make(chan *Subscription), - getPeers: make(chan *listPeerReq), - addSub: make(chan *addSubReq), - getTopics: make(chan *topicReq), - sendMsg: make(chan sendReq), - myTopics: make(map[string]map[*Subscription]struct{}), - topics: make(map[string]map[peer.ID]struct{}), - peers: make(map[peer.ID]chan *RPC), - seenMessages: timecache.NewTimeCache(time.Second * 30), - counter: uint64(time.Now().UnixNano()), - throttleValidate: make(chan struct{}, defaultMaxConcurrency), + host: h, + ctx: ctx, + incoming: make(chan *RPC, 32), + publish: make(chan *Message), + newPeers: make(chan inet.Stream), + peerDead: make(chan peer.ID), + cancelCh: make(chan *Subscription), + getPeers: make(chan *listPeerReq), + addSub: make(chan *addSubReq), + getTopics: make(chan *topicReq), + sendMsg: make(chan sendReq), + myTopics: make(map[string]map[*Subscription]struct{}), + topics: make(map[string]map[peer.ID]struct{}), + peers: make(map[peer.ID]chan *RPC), + seenMessages: timecache.NewTimeCache(time.Second * 30), + counter: uint64(time.Now().UnixNano()), } for _, opt := range opts { @@ -204,24 +193,9 @@ func (p *PubSub) processLoop(ctx context.Context) { continue } case msg := <-p.publish: - subs := p.getSubscriptions(msg) // call before goroutine! + subs := p.getSubscriptions(msg) + p.pushMsg(subs, p.host.ID(), msg) - select { - case p.throttleValidate <- struct{}{}: - go func(msg *Message) { - defer func() { <-p.throttleValidate }() - - if p.validate(subs, msg) { - p.sendMsg <- sendReq{ - from: p.host.ID(), - msg: msg, - } - - } - }(msg) - default: - log.Warning("could not acquire validator; dropping message") - } case req := <-p.sendMsg: p.maybePublishMessage(req.from, req.msg.Message) @@ -360,24 +334,11 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error { continue } - subs := p.getSubscriptions(&Message{pmsg}) // call before goroutine! - - select { - case p.throttleValidate <- struct{}{}: - go func(pmsg *pb.Message) { - defer func() { <-p.throttleValidate }() - - if p.validate(subs, &Message{pmsg}) { - p.sendMsg <- sendReq{ - from: rpc.from, - msg: &Message{pmsg}, - } - } - }(pmsg) - default: - log.Warning("could not acquire validator; dropping message") - } + msg := &Message{pmsg} + subs := p.getSubscriptions(msg) + p.pushMsg(subs, rpc.from, msg) } + return nil } @@ -386,41 +347,80 @@ func msgID(pmsg *pb.Message) string { return string(pmsg.GetFrom()) + string(pmsg.GetSeqno()) } -// validate is called in a goroutine and calls the validate functions of all subs with msg as parameter. -func (p *PubSub) validate(subs []*Subscription, msg *Message) bool { - results := make([]chan bool, len(subs)) - ctxs := make([]context.Context, len(subs)) - - for i, sub := range subs { - result := make(chan bool) - ctx, cancel := context.WithTimeout(p.ctx, sub.validateTimeout) - defer cancel() - - ctxs[i] = ctx - results[i] = result - - go func(sub *Subscription) { - result <- sub.validate == nil || sub.validate(ctx, msg) - }(sub) - } - - for i, sub := range subs { - ctx := ctxs[i] - result := results[i] - - select { - case valid := <-result: - if !valid { - log.Debugf("validator for topic %s returned false", sub.topic) - return false - } - case <-ctx.Done(): - log.Debugf("validator for topic %s timed out. msg: %s", sub.topic, msg) - return false +// pushMsg pushes a message to a number of subscriptions, performing validation +// as necessary +func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) { + // we perform validation if _any_ of the subscriptions has a validator + // because the message is sent once for all topics + needval := false + for _, sub := range subs { + if sub.validate != nil { + needval = true + break } } - return true + if !needval { + go func() { + p.sendMsg <- sendReq{ + from: src, + msg: msg, + } + }() + return + } + + // validation is asynchronous + // XXX vyzo: do we want a global validation throttle here? + go p.validate(subs, src, msg) +} + +// validate performs validation and only sends the message if all validators succeed +func (p *PubSub) validate(subs []*Subscription, src peer.ID, msg *Message) { + results := make([]chan bool, 0, len(subs)) + throttle := false + +loop: + for _, sub := range subs { + if sub.validate == nil { + continue + } + + rch := make(chan bool, 1) + results = append(results, rch) + + select { + case sub.validateThrottle <- struct{}{}: + go func(sub *Subscription, msg *Message, rch chan bool) { + rch <- sub.validateMsg(p.ctx, msg) + <-sub.validateThrottle + }(sub, msg, rch) + + default: + log.Debugf("validation throttled for topic %s", sub.topic) + throttle = true + break loop + } + } + + if throttle { + log.Warningf("message validation throttled; dropping message from %s", src) + return + } + + for _, rch := range results { + valid := <-rch + if !valid { + log.Warningf("message validation failed; dropping message from %s", src) + return + } + } + + // all validators were successful, send the message + p.sendMsg <- sendReq{ + from: src, + msg: msg, + } } func (p *PubSub) maybePublishMessage(from peer.ID, pmsg *pb.Message) { @@ -516,6 +516,13 @@ func WithValidatorTimeout(timeout time.Duration) SubOpt { } } +func WithMaxConcurrency(n int) SubOpt { + return func(sub *Subscription) error { + sub.validateThrottle = make(chan struct{}, n) + return nil + } +} + // Subscribe returns a new Subscription for the given topic func (p *PubSub) Subscribe(topic string, opts ...SubOpt) (*Subscription, error) { td := pb.TopicDescriptor{Name: &topic} @@ -545,6 +552,10 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubO } } + if sub.validate != nil && sub.validateThrottle == nil { + sub.validateThrottle = make(chan struct{}, defaultMaxConcurrency) + } + out := make(chan *Subscription, 1) p.addSub <- &addSubReq{ sub: sub, diff --git a/floodsub_test.go b/floodsub_test.go index 68357ab..351f584 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -533,17 +533,20 @@ func TestValidateOverload(t *testing.T) { for _, tc := range tcs { hosts := getNetHosts(t, ctx, 2) - psubs := getPubsubs(ctx, hosts, WithMaxConcurrency(tc.maxConcurrency)) + psubs := getPubsubs(ctx, hosts) connect(t, hosts[0], hosts[1]) topic := "foobar" block := make(chan struct{}) - sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool { - <-block - return true - })) + sub, err := psubs[1].Subscribe(topic, + WithMaxConcurrency(tc.maxConcurrency), + WithValidator(func(ctx context.Context, msg *Message) bool { + <-block + return true + })) + if err != nil { t.Fatal(err) } diff --git a/subscription.go b/subscription.go index cc76413..3aa51c8 100644 --- a/subscription.go +++ b/subscription.go @@ -11,8 +11,9 @@ type Subscription struct { cancelCh chan<- *Subscription err error - validate Validator - validateTimeout time.Duration + validate Validator + validateTimeout time.Duration + validateThrottle chan struct{} } func (sub *Subscription) Topic() string { @@ -35,3 +36,24 @@ func (sub *Subscription) Next(ctx context.Context) (*Message, error) { func (sub *Subscription) Cancel() { sub.cancelCh <- sub } + +func (sub *Subscription) validateMsg(ctx context.Context, msg *Message) bool { + result := make(chan bool, 1) + vctx, cancel := context.WithTimeout(ctx, sub.validateTimeout) + defer cancel() + + go func() { + result <- sub.validate(vctx, msg) + }() + + select { + case valid := <-result: + if !valid { + log.Debugf("validation failed for topic %s", sub.topic) + } + return valid + case <-vctx.Done(): + log.Debugf("validation timeout for topic %s", sub.topic) + return false + } +}