From bbdec3fda2e29840bc68b1ead82ec965576bd968 Mon Sep 17 00:00:00 2001 From: vyzo Date: Thu, 18 Jan 2018 19:12:36 +0200 Subject: [PATCH] implement per topic validators --- floodsub.go | 204 ++++++++++++++++++++++++++++++----------------- floodsub_test.go | 22 +++-- subscription.go | 17 ---- 3 files changed, 149 insertions(+), 94 deletions(-) diff --git a/floodsub.go b/floodsub.go index 2f4df69..ab2c6b2 100644 --- a/floodsub.go +++ b/floodsub.go @@ -62,6 +62,12 @@ type PubSub struct { // sendMsg handles messages that have been validated sendMsg chan *sendReq + // addVal handles validator registration requests + addVal chan *addValReq + + // topicVals tracks per topic validators + topicVals map[string]*topicVal + // validateThrottle limits the number of active validation goroutines validateThrottle chan struct{} @@ -105,10 +111,12 @@ func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, err addSub: make(chan *addSubReq), getTopics: make(chan *topicReq), sendMsg: make(chan *sendReq, 32), + addVal: make(chan *addValReq), validateThrottle: make(chan struct{}, defaultValidateThrottle), myTopics: make(map[string]map[*Subscription]struct{}), topics: make(map[string]map[peer.ID]struct{}), peers: make(map[peer.ID]chan *RPC), + topicVals: make(map[string]*topicVal), seenMessages: timecache.NewTimeCache(time.Second * 30), counter: uint64(time.Now().UnixNano()), } @@ -205,12 +213,15 @@ func (p *PubSub) processLoop(ctx context.Context) { continue } case msg := <-p.publish: - subs := p.getSubscriptions(msg) - p.pushMsg(subs, p.host.ID(), msg) + vals := p.getValidators(msg) + p.pushMsg(vals, p.host.ID(), msg) case req := <-p.sendMsg: p.maybePublishMessage(req.from, req.msg.Message) + case req := <-p.addVal: + p.addValidator(req) + case <-ctx.Done(): log.Info("pubsub processloop shutting down") return @@ -347,8 +358,8 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error { } msg := &Message{pmsg} - subs := p.getSubscriptions(msg) - p.pushMsg(subs, rpc.from, msg) + vals := p.getValidators(msg) + p.pushMsg(vals, rpc.from, msg) } return nil @@ -359,20 +370,9 @@ func msgID(pmsg *pb.Message) string { return string(pmsg.GetFrom()) + string(pmsg.GetSeqno()) } -// pushMsg pushes a message to a number of subscriptions, performing validation -// as necessary -func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) { - // we perform validation if _any_ of the subscriptions has a validator - // because the message is sent once for all topics - needval := false - for _, sub := range subs { - if sub.validate != nil { - needval = true - break - } - } - - if needval { +// pushMsg pushes a message performing validation as necessary +func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) { + if len(vals) > 0 { // validation is asynchronous and globally throttled with the throttleValidate semaphore. // the purpose of the global throttle is to bound the goncurrency possible from incoming // network traffic; each subscription also has an individual throttle to preclude @@ -380,7 +380,7 @@ func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) { select { case p.validateThrottle <- struct{}{}: go func() { - p.validate(subs, src, msg) + p.validate(vals, src, msg) <-p.validateThrottle }() default: @@ -393,31 +393,27 @@ func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) { } // validate performs validation and only sends the message if all validators succeed -func (p *PubSub) validate(subs []*Subscription, src peer.ID, msg *Message) { +func (p *PubSub) validate(vals []*topicVal, src peer.ID, msg *Message) { ctx, cancel := context.WithCancel(p.ctx) defer cancel() - rch := make(chan bool, len(subs)) + rch := make(chan bool, len(vals)) rcount := 0 throttle := false loop: - for _, sub := range subs { - if sub.validate == nil { - continue - } - + for _, val := range vals { rcount++ select { - case sub.validateThrottle <- struct{}{}: - go func(sub *Subscription) { - rch <- sub.validateMsg(ctx, msg) - <-sub.validateThrottle - }(sub) + case val.validateThrottle <- struct{}{}: + go func(val *topicVal) { + rch <- val.validateMsg(ctx, msg) + <-val.validateThrottle + }(val) default: - log.Debugf("validation throttled for topic %s", sub.topic) + log.Debugf("validation throttled for topic %s", val.topic) throttle = true break loop } @@ -494,22 +490,20 @@ func (p *PubSub) publishMessage(from peer.ID, msg *pb.Message) error { return nil } -// getSubscriptions returns all subscriptions the would receive the given message. -func (p *PubSub) getSubscriptions(msg *Message) []*Subscription { - var subs []*Subscription +// getValidators returns all validators that apply to a given message +func (p *PubSub) getValidators(msg *Message) []*topicVal { + var vals []*topicVal for _, topic := range msg.GetTopicIDs() { - tSubs, ok := p.myTopics[topic] + val, ok := p.topicVals[topic] if !ok { continue } - for sub := range tSubs { - subs = append(subs, sub) - } + vals = append(vals, val) } - return subs + return vals } type addSubReq struct { @@ -517,31 +511,7 @@ type addSubReq struct { resp chan *Subscription } -type SubOpt func(*Subscription) error -type Validator func(context.Context, *Message) bool - -// WithValidator is an option that can be supplied to Subscribe. The argument is a function that returns whether or not a given message should be propagated further. -func WithValidator(validate Validator) SubOpt { - return func(sub *Subscription) error { - sub.validate = validate - return nil - } -} - -// WithValidatorTimeout is an option that can be supplied to Subscribe. The argument is a duration after which long-running validators are canceled. -func WithValidatorTimeout(timeout time.Duration) SubOpt { - return func(sub *Subscription) error { - sub.validateTimeout = timeout - return nil - } -} - -func WithValidatorConcurrency(n int) SubOpt { - return func(sub *Subscription) error { - sub.validateThrottle = make(chan struct{}, n) - return nil - } -} +type SubOpt func(sub *Subscription) error // Subscribe returns a new Subscription for the given topic func (p *PubSub) Subscribe(topic string, opts ...SubOpt) (*Subscription, error) { @@ -561,8 +531,7 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubO } sub := &Subscription{ - topic: td.GetName(), - validateTimeout: defaultValidateTimeout, + topic: td.GetName(), } for _, opt := range opts { @@ -572,10 +541,6 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubO } } - if sub.validate != nil && sub.validateThrottle == nil { - sub.validateThrottle = make(chan struct{}, defaultValidateConcurrency) - } - out := make(chan *Subscription, 1) p.addSub <- &addSubReq{ sub: sub, @@ -633,3 +598,100 @@ func (p *PubSub) ListPeers(topic string) []peer.ID { } return <-out } + +// per topic validators +type addValReq struct { + topic string + validate Validator + timeout time.Duration + throttle int + resp chan error +} + +type topicVal struct { + topic string + validate Validator + validateTimeout time.Duration + validateThrottle chan struct{} +} + +// Validator is a function that validates a message +type Validator func(context.Context, *Message) bool + +// ValidatorOpt is an option for RegisterTopicValidator +type ValidatorOpt func(addVal *addValReq) error + +// WithValidatorTimeout is an option that sets the topic validator timeout +func WithValidatorTimeout(timeout time.Duration) ValidatorOpt { + return func(addVal *addValReq) error { + addVal.timeout = timeout + return nil + } +} + +// WithValidatorConcurrency is an option that sets topic validator throttle +func WithValidatorConcurrency(n int) ValidatorOpt { + return func(addVal *addValReq) error { + addVal.throttle = n + return nil + } +} + +// RegisterTopicValidator registers a validator for topic +func (p *PubSub) RegisterTopicValidator(topic string, val Validator, opts ...ValidatorOpt) error { + addVal := &addValReq{ + topic: topic, + validate: val, + resp: make(chan error, 1), + } + + for _, opt := range opts { + err := opt(addVal) + if err != nil { + return err + } + } + + p.addVal <- addVal + return <-addVal.resp +} + +func (ps *PubSub) addValidator(req *addValReq) { + topic := req.topic + + _, ok := ps.topicVals[topic] + if ok { + req.resp <- fmt.Errorf("Duplicate validator for topic %s", topic) + return + } + + val := &topicVal{ + topic: topic, + validate: req.validate, + validateTimeout: defaultValidateTimeout, + validateThrottle: make(chan struct{}, defaultValidateConcurrency), + } + + if req.timeout > 0 { + val.validateTimeout = req.timeout + } + + if req.throttle > 0 { + val.validateThrottle = make(chan struct{}, req.throttle) + } + + ps.topicVals[topic] = val + req.resp <- nil +} + +func (val *topicVal) validateMsg(ctx context.Context, msg *Message) bool { + vctx, cancel := context.WithTimeout(ctx, val.validateTimeout) + defer cancel() + + valid := val.validate(vctx, msg) + if !valid { + log.Debugf("validation failed for topic %s", val.topic) + } + + return valid +} diff --git a/floodsub_test.go b/floodsub_test.go index 1bf1847..3d6a68a 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -351,9 +351,14 @@ func TestValidate(t *testing.T) { connect(t, hosts[0], hosts[1]) topic := "foobar" - sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool { + err := psubs[1].RegisterTopicValidator(topic, func(ctx context.Context, msg *Message) bool { return !bytes.Contains(msg.Data, []byte("illegal")) - })) + }) + if err != nil { + t.Fatal(err) + } + + sub, err := psubs[1].Subscribe(topic) if err != nil { t.Fatal(err) } @@ -442,17 +447,22 @@ func TestValidateOverload(t *testing.T) { block := make(chan struct{}) - sub, err := psubs[1].Subscribe(topic, - WithValidatorConcurrency(tc.maxConcurrency), - WithValidator(func(ctx context.Context, msg *Message) bool { + err := psubs[1].RegisterTopicValidator(topic, + func(ctx context.Context, msg *Message) bool { <-block return true - })) + }, + WithValidatorConcurrency(tc.maxConcurrency)) if err != nil { t.Fatal(err) } + sub, err := psubs[1].Subscribe(topic) + if err != nil { + t.Fatal(err) + } + time.Sleep(time.Millisecond * 50) if len(tc.msgs) != tc.maxConcurrency+1 { diff --git a/subscription.go b/subscription.go index a15466a..d6e930c 100644 --- a/subscription.go +++ b/subscription.go @@ -2,7 +2,6 @@ package floodsub import ( "context" - "time" ) type Subscription struct { @@ -10,10 +9,6 @@ type Subscription struct { ch chan *Message cancelCh chan<- *Subscription err error - - validate Validator - validateTimeout time.Duration - validateThrottle chan struct{} } func (sub *Subscription) Topic() string { @@ -36,15 +31,3 @@ func (sub *Subscription) Next(ctx context.Context) (*Message, error) { func (sub *Subscription) Cancel() { sub.cancelCh <- sub } - -func (sub *Subscription) validateMsg(ctx context.Context, msg *Message) bool { - vctx, cancel := context.WithTimeout(ctx, sub.validateTimeout) - defer cancel() - - valid := sub.validate(vctx, msg) - if !valid { - log.Debugf("validation failed for topic %s", sub.topic) - } - - return valid -}