implement per topic validators

This commit is contained in:
vyzo 2018-01-18 19:12:36 +02:00
parent fceb00d234
commit bbdec3fda2
3 changed files with 149 additions and 94 deletions

View File

@ -62,6 +62,12 @@ type PubSub struct {
// sendMsg handles messages that have been validated
sendMsg chan *sendReq
// addVal handles validator registration requests
addVal chan *addValReq
// topicVals tracks per topic validators
topicVals map[string]*topicVal
// validateThrottle limits the number of active validation goroutines
validateThrottle chan struct{}
@ -105,10 +111,12 @@ func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, err
addSub: make(chan *addSubReq),
getTopics: make(chan *topicReq),
sendMsg: make(chan *sendReq, 32),
addVal: make(chan *addValReq),
validateThrottle: make(chan struct{}, defaultValidateThrottle),
myTopics: make(map[string]map[*Subscription]struct{}),
topics: make(map[string]map[peer.ID]struct{}),
peers: make(map[peer.ID]chan *RPC),
topicVals: make(map[string]*topicVal),
seenMessages: timecache.NewTimeCache(time.Second * 30),
counter: uint64(time.Now().UnixNano()),
}
@ -205,12 +213,15 @@ func (p *PubSub) processLoop(ctx context.Context) {
continue
}
case msg := <-p.publish:
subs := p.getSubscriptions(msg)
p.pushMsg(subs, p.host.ID(), msg)
vals := p.getValidators(msg)
p.pushMsg(vals, p.host.ID(), msg)
case req := <-p.sendMsg:
p.maybePublishMessage(req.from, req.msg.Message)
case req := <-p.addVal:
p.addValidator(req)
case <-ctx.Done():
log.Info("pubsub processloop shutting down")
return
@ -347,8 +358,8 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error {
}
msg := &Message{pmsg}
subs := p.getSubscriptions(msg)
p.pushMsg(subs, rpc.from, msg)
vals := p.getValidators(msg)
p.pushMsg(vals, rpc.from, msg)
}
return nil
@ -359,20 +370,9 @@ func msgID(pmsg *pb.Message) string {
return string(pmsg.GetFrom()) + string(pmsg.GetSeqno())
}
// pushMsg pushes a message to a number of subscriptions, performing validation
// as necessary
func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) {
// we perform validation if _any_ of the subscriptions has a validator
// because the message is sent once for all topics
needval := false
for _, sub := range subs {
if sub.validate != nil {
needval = true
break
}
}
if needval {
// pushMsg pushes a message performing validation as necessary
func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) {
if len(vals) > 0 {
// validation is asynchronous and globally throttled with the throttleValidate semaphore.
// the purpose of the global throttle is to bound the goncurrency possible from incoming
// network traffic; each subscription also has an individual throttle to preclude
@ -380,7 +380,7 @@ func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) {
select {
case p.validateThrottle <- struct{}{}:
go func() {
p.validate(subs, src, msg)
p.validate(vals, src, msg)
<-p.validateThrottle
}()
default:
@ -393,31 +393,27 @@ func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) {
}
// validate performs validation and only sends the message if all validators succeed
func (p *PubSub) validate(subs []*Subscription, src peer.ID, msg *Message) {
func (p *PubSub) validate(vals []*topicVal, src peer.ID, msg *Message) {
ctx, cancel := context.WithCancel(p.ctx)
defer cancel()
rch := make(chan bool, len(subs))
rch := make(chan bool, len(vals))
rcount := 0
throttle := false
loop:
for _, sub := range subs {
if sub.validate == nil {
continue
}
for _, val := range vals {
rcount++
select {
case sub.validateThrottle <- struct{}{}:
go func(sub *Subscription) {
rch <- sub.validateMsg(ctx, msg)
<-sub.validateThrottle
}(sub)
case val.validateThrottle <- struct{}{}:
go func(val *topicVal) {
rch <- val.validateMsg(ctx, msg)
<-val.validateThrottle
}(val)
default:
log.Debugf("validation throttled for topic %s", sub.topic)
log.Debugf("validation throttled for topic %s", val.topic)
throttle = true
break loop
}
@ -494,22 +490,20 @@ func (p *PubSub) publishMessage(from peer.ID, msg *pb.Message) error {
return nil
}
// getSubscriptions returns all subscriptions the would receive the given message.
func (p *PubSub) getSubscriptions(msg *Message) []*Subscription {
var subs []*Subscription
// getValidators returns all validators that apply to a given message
func (p *PubSub) getValidators(msg *Message) []*topicVal {
var vals []*topicVal
for _, topic := range msg.GetTopicIDs() {
tSubs, ok := p.myTopics[topic]
val, ok := p.topicVals[topic]
if !ok {
continue
}
for sub := range tSubs {
subs = append(subs, sub)
}
vals = append(vals, val)
}
return subs
return vals
}
type addSubReq struct {
@ -517,31 +511,7 @@ type addSubReq struct {
resp chan *Subscription
}
type SubOpt func(*Subscription) error
type Validator func(context.Context, *Message) bool
// WithValidator is an option that can be supplied to Subscribe. The argument is a function that returns whether or not a given message should be propagated further.
func WithValidator(validate Validator) SubOpt {
return func(sub *Subscription) error {
sub.validate = validate
return nil
}
}
// WithValidatorTimeout is an option that can be supplied to Subscribe. The argument is a duration after which long-running validators are canceled.
func WithValidatorTimeout(timeout time.Duration) SubOpt {
return func(sub *Subscription) error {
sub.validateTimeout = timeout
return nil
}
}
func WithValidatorConcurrency(n int) SubOpt {
return func(sub *Subscription) error {
sub.validateThrottle = make(chan struct{}, n)
return nil
}
}
type SubOpt func(sub *Subscription) error
// Subscribe returns a new Subscription for the given topic
func (p *PubSub) Subscribe(topic string, opts ...SubOpt) (*Subscription, error) {
@ -561,8 +531,7 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubO
}
sub := &Subscription{
topic: td.GetName(),
validateTimeout: defaultValidateTimeout,
topic: td.GetName(),
}
for _, opt := range opts {
@ -572,10 +541,6 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubO
}
}
if sub.validate != nil && sub.validateThrottle == nil {
sub.validateThrottle = make(chan struct{}, defaultValidateConcurrency)
}
out := make(chan *Subscription, 1)
p.addSub <- &addSubReq{
sub: sub,
@ -633,3 +598,100 @@ func (p *PubSub) ListPeers(topic string) []peer.ID {
}
return <-out
}
// per topic validators
type addValReq struct {
topic string
validate Validator
timeout time.Duration
throttle int
resp chan error
}
type topicVal struct {
topic string
validate Validator
validateTimeout time.Duration
validateThrottle chan struct{}
}
// Validator is a function that validates a message
type Validator func(context.Context, *Message) bool
// ValidatorOpt is an option for RegisterTopicValidator
type ValidatorOpt func(addVal *addValReq) error
// WithValidatorTimeout is an option that sets the topic validator timeout
func WithValidatorTimeout(timeout time.Duration) ValidatorOpt {
return func(addVal *addValReq) error {
addVal.timeout = timeout
return nil
}
}
// WithValidatorConcurrency is an option that sets topic validator throttle
func WithValidatorConcurrency(n int) ValidatorOpt {
return func(addVal *addValReq) error {
addVal.throttle = n
return nil
}
}
// RegisterTopicValidator registers a validator for topic
func (p *PubSub) RegisterTopicValidator(topic string, val Validator, opts ...ValidatorOpt) error {
addVal := &addValReq{
topic: topic,
validate: val,
resp: make(chan error, 1),
}
for _, opt := range opts {
err := opt(addVal)
if err != nil {
return err
}
}
p.addVal <- addVal
return <-addVal.resp
}
func (ps *PubSub) addValidator(req *addValReq) {
topic := req.topic
_, ok := ps.topicVals[topic]
if ok {
req.resp <- fmt.Errorf("Duplicate validator for topic %s", topic)
return
}
val := &topicVal{
topic: topic,
validate: req.validate,
validateTimeout: defaultValidateTimeout,
validateThrottle: make(chan struct{}, defaultValidateConcurrency),
}
if req.timeout > 0 {
val.validateTimeout = req.timeout
}
if req.throttle > 0 {
val.validateThrottle = make(chan struct{}, req.throttle)
}
ps.topicVals[topic] = val
req.resp <- nil
}
func (val *topicVal) validateMsg(ctx context.Context, msg *Message) bool {
vctx, cancel := context.WithTimeout(ctx, val.validateTimeout)
defer cancel()
valid := val.validate(vctx, msg)
if !valid {
log.Debugf("validation failed for topic %s", val.topic)
}
return valid
}

View File

@ -351,9 +351,14 @@ func TestValidate(t *testing.T) {
connect(t, hosts[0], hosts[1])
topic := "foobar"
sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool {
err := psubs[1].RegisterTopicValidator(topic, func(ctx context.Context, msg *Message) bool {
return !bytes.Contains(msg.Data, []byte("illegal"))
}))
})
if err != nil {
t.Fatal(err)
}
sub, err := psubs[1].Subscribe(topic)
if err != nil {
t.Fatal(err)
}
@ -442,17 +447,22 @@ func TestValidateOverload(t *testing.T) {
block := make(chan struct{})
sub, err := psubs[1].Subscribe(topic,
WithValidatorConcurrency(tc.maxConcurrency),
WithValidator(func(ctx context.Context, msg *Message) bool {
err := psubs[1].RegisterTopicValidator(topic,
func(ctx context.Context, msg *Message) bool {
<-block
return true
}))
},
WithValidatorConcurrency(tc.maxConcurrency))
if err != nil {
t.Fatal(err)
}
sub, err := psubs[1].Subscribe(topic)
if err != nil {
t.Fatal(err)
}
time.Sleep(time.Millisecond * 50)
if len(tc.msgs) != tc.maxConcurrency+1 {

View File

@ -2,7 +2,6 @@ package floodsub
import (
"context"
"time"
)
type Subscription struct {
@ -10,10 +9,6 @@ type Subscription struct {
ch chan *Message
cancelCh chan<- *Subscription
err error
validate Validator
validateTimeout time.Duration
validateThrottle chan struct{}
}
func (sub *Subscription) Topic() string {
@ -36,15 +31,3 @@ func (sub *Subscription) Next(ctx context.Context) (*Message, error) {
func (sub *Subscription) Cancel() {
sub.cancelCh <- sub
}
func (sub *Subscription) validateMsg(ctx context.Context, msg *Message) bool {
vctx, cancel := context.WithTimeout(ctx, sub.validateTimeout)
defer cancel()
valid := sub.validate(vctx, msg)
if !valid {
log.Debugf("validation failed for topic %s", sub.topic)
}
return valid
}