diff --git a/pubsub.go b/pubsub.go index c8a0e91..9875a3c 100644 --- a/pubsub.go +++ b/pubsub.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "fmt" "math/rand" + "sync" "sync/atomic" "time" @@ -19,12 +20,6 @@ import ( timecache "github.com/whyrusleeping/timecache" ) -const ( - defaultValidateTimeout = 150 * time.Millisecond - defaultValidateConcurrency = 100 - defaultValidateThrottle = 8192 -) - var ( TimeCacheDuration = 120 * time.Second ) @@ -44,6 +39,8 @@ type PubSub struct { rt PubSubRouter + val *validation + // incoming messages from other peers incoming chan *RPC @@ -89,12 +86,6 @@ type PubSub struct { // rmVal handles validator unregistration requests rmVal chan *rmValReq - // topicVals tracks per topic validators - topicVals map[string]*topicVal - - // validateThrottle limits the number of active validation goroutines - validateThrottle chan struct{} - // eval thunk in event loop eval chan func() @@ -102,8 +93,10 @@ type PubSub struct { blacklist Blacklist blacklistPeer chan peer.ID - peers map[peer.ID]chan *RPC - seenMessages *timecache.TimeCache + peers map[peer.ID]chan *RPC + + seenMessagesMx sync.Mutex + seenMessages *timecache.TimeCache // key for signing messages; nil when signing is disabled (default for now) signKey crypto.PrivKey @@ -159,35 +152,34 @@ type Option func(*PubSub) error // NewPubSub returns a new PubSub management object. 
func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option) (*PubSub, error) { ps := &PubSub{ - host: h, - ctx: ctx, - rt: rt, - signID: h.ID(), - signKey: h.Peerstore().PrivKey(h.ID()), - signStrict: true, - incoming: make(chan *RPC, 32), - publish: make(chan *Message), - newPeers: make(chan peer.ID), - newPeerStream: make(chan inet.Stream), - newPeerError: make(chan peer.ID), - peerDead: make(chan peer.ID), - cancelCh: make(chan *Subscription), - getPeers: make(chan *listPeerReq), - addSub: make(chan *addSubReq), - getTopics: make(chan *topicReq), - sendMsg: make(chan *sendReq, 32), - addVal: make(chan *addValReq), - rmVal: make(chan *rmValReq), - validateThrottle: make(chan struct{}, defaultValidateThrottle), - eval: make(chan func()), - myTopics: make(map[string]map[*Subscription]struct{}), - topics: make(map[string]map[peer.ID]struct{}), - peers: make(map[peer.ID]chan *RPC), - topicVals: make(map[string]*topicVal), - blacklist: NewMapBlacklist(), - blacklistPeer: make(chan peer.ID), - seenMessages: timecache.NewTimeCache(TimeCacheDuration), - counter: uint64(time.Now().UnixNano()), + host: h, + ctx: ctx, + rt: rt, + val: newValidation(), + signID: h.ID(), + signKey: h.Peerstore().PrivKey(h.ID()), + signStrict: true, + incoming: make(chan *RPC, 32), + publish: make(chan *Message), + newPeers: make(chan peer.ID), + newPeerStream: make(chan inet.Stream), + newPeerError: make(chan peer.ID), + peerDead: make(chan peer.ID), + cancelCh: make(chan *Subscription), + getPeers: make(chan *listPeerReq), + addSub: make(chan *addSubReq), + getTopics: make(chan *topicReq), + sendMsg: make(chan *sendReq, 32), + addVal: make(chan *addValReq), + rmVal: make(chan *rmValReq), + eval: make(chan func()), + myTopics: make(map[string]map[*Subscription]struct{}), + topics: make(map[string]map[peer.ID]struct{}), + peers: make(map[peer.ID]chan *RPC), + blacklist: NewMapBlacklist(), + blacklistPeer: make(chan peer.ID), + seenMessages: 
timecache.NewTimeCache(TimeCacheDuration), + counter: uint64(time.Now().UnixNano()), } for _, opt := range opts { @@ -208,20 +200,13 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option } h.Network().Notify((*PubSubNotif)(ps)) + ps.val.Start(ps) + go ps.processLoop(ctx) return ps, nil } -// WithValidateThrottle sets the upper bound on the number of active validation -// goroutines. -func WithValidateThrottle(n int) Option { - return func(ps *PubSub) error { - ps.validateThrottle = make(chan struct{}, n) - return nil - } -} - // WithMessageSigning enables or disables message signing (enabled by default). func WithMessageSigning(enabled bool) Option { return func(p *PubSub) error { @@ -384,17 +369,16 @@ func (p *PubSub) processLoop(ctx context.Context) { p.handleIncomingRPC(rpc) case msg := <-p.publish: - vals := p.getValidators(msg) - p.pushMsg(vals, p.host.ID(), msg) + p.pushMsg(p.host.ID(), msg) case req := <-p.sendMsg: p.publishMessage(req.from, req.msg.Message) case req := <-p.addVal: - p.addValidator(req) + p.val.AddValidator(req) case req := <-p.rmVal: - p.rmValidator(req) + p.val.RemoveValidator(req) case thunk := <-p.eval: thunk() @@ -542,12 +526,22 @@ func (p *PubSub) notifySubs(msg *pb.Message) { // seenMessage returns whether we already saw this message before func (p *PubSub) seenMessage(id string) bool { + p.seenMessagesMx.Lock() + defer p.seenMessagesMx.Unlock() return p.seenMessages.Has(id) } // markSeen marks a message as seen such that seenMessage returns `true' for the given id -func (p *PubSub) markSeen(id string) { +// returns true if the message was freshly marked +func (p *PubSub) markSeen(id string) bool { + p.seenMessagesMx.Lock() + defer p.seenMessagesMx.Unlock() + if p.seenMessages.Has(id) { + return false + } + p.seenMessages.Add(id) + return true } // subscribedToMessage returns whether we are subscribed to one of the topics @@ -592,8 +586,7 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) { } msg := 
&Message{pmsg} - vals := p.getValidators(msg) - p.pushMsg(vals, rpc.from, msg) + p.pushMsg(rpc.from, msg) } p.rt.HandleRPC(rpc) @@ -605,7 +598,7 @@ func msgID(pmsg *pb.Message) string { } // pushMsg pushes a message performing validation as necessary -func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) { +func (p *PubSub) pushMsg(src peer.ID, msg *Message) { // reject messages from blacklisted peers if p.blacklist.Contains(src) { log.Warningf("dropping message from blacklisted peer %s", src) @@ -630,148 +623,20 @@ func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) { return } - if len(vals) > 0 || msg.Signature != nil { - // validation is asynchronous and globally throttled with the throttleValidate semaphore. - // the purpose of the global throttle is to bound the goncurrency possible from incoming - // network traffic; each validator also has an individual throttle to preclude - // slow (or faulty) validators from starving other topics; see validate below. 
- select { - case p.validateThrottle <- struct{}{}: - go func() { - p.validate(vals, src, msg) - <-p.validateThrottle - }() - default: - log.Warningf("message validation throttled; dropping message from %s", src) - } + if !p.val.Push(src, msg) { return } - p.publishMessage(src, msg.Message) -} - -// validate performs validation and only sends the message if all validators succeed -func (p *PubSub) validate(vals []*topicVal, src peer.ID, msg *Message) { - if msg.Signature != nil { - if !p.validateSignature(msg) { - log.Warningf("message signature validation failed; dropping message from %s", src) - return - } - } - - if len(vals) > 0 { - if !p.validateTopic(vals, src, msg) { - log.Warningf("message validation failed; dropping message from %s", src) - return - } - } - - // all validators were successful, send the message - p.sendMsg <- &sendReq{ - from: src, - msg: msg, - } -} - -func (p *PubSub) validateSignature(msg *Message) bool { - err := verifyMessageSignature(msg.Message) - if err != nil { - log.Debugf("signature verification error: %s", err.Error()) - return false - } - - return true -} - -func (p *PubSub) validateTopic(vals []*topicVal, src peer.ID, msg *Message) bool { - if len(vals) == 1 { - return p.validateSingleTopic(vals[0], src, msg) - } - - ctx, cancel := context.WithCancel(p.ctx) - defer cancel() - - rch := make(chan bool, len(vals)) - rcount := 0 - throttle := false - -loop: - for _, val := range vals { - rcount++ - - select { - case val.validateThrottle <- struct{}{}: - go func(val *topicVal) { - rch <- val.validateMsg(ctx, src, msg) - <-val.validateThrottle - }(val) - - default: - log.Debugf("validation throttled for topic %s", val.topic) - throttle = true - break loop - } - } - - if throttle { - return false - } - - for i := 0; i < rcount; i++ { - valid := <-rch - if !valid { - return false - } - } - - return true -} - -// fast path for single topic validation that avoids the extra goroutine -func (p *PubSub) validateSingleTopic(val *topicVal, 
src peer.ID, msg *Message) bool { - select { - case val.validateThrottle <- struct{}{}: - ctx, cancel := context.WithCancel(p.ctx) - defer cancel() - - res := val.validateMsg(ctx, src, msg) - <-val.validateThrottle - - return res - - default: - log.Debugf("validation throttled for topic %s", val.topic) - return false + if p.markSeen(id) { + p.publishMessage(src, msg.Message) } } func (p *PubSub) publishMessage(from peer.ID, pmsg *pb.Message) { - id := msgID(pmsg) - if p.seenMessage(id) { - return - } - p.markSeen(id) - p.notifySubs(pmsg) p.rt.Publish(from, pmsg) } -// getValidators returns all validators that apply to a given message -func (p *PubSub) getValidators(msg *Message) []*topicVal { - var vals []*topicVal - - for _, topic := range msg.GetTopicIDs() { - val, ok := p.topicVals[topic] - if !ok { - continue - } - - vals = append(vals, val) - } - - return vals -} - type addSubReq struct { sub *Subscription resp chan *Subscription @@ -883,50 +748,10 @@ func (p *PubSub) BlacklistPeer(pid peer.ID) { p.blacklistPeer <- pid } -// per topic validators -type addValReq struct { - topic string - validate Validator - timeout time.Duration - throttle int - resp chan error -} - -type rmValReq struct { - topic string - resp chan error -} - -type topicVal struct { - topic string - validate Validator - validateTimeout time.Duration - validateThrottle chan struct{} -} - -// Validator is a function that validates a message. -type Validator func(context.Context, peer.ID, *Message) bool - -// ValidatorOpt is an option for RegisterTopicValidator. -type ValidatorOpt func(addVal *addValReq) error - -// WithValidatorTimeout is an option that sets the topic validator timeout. -func WithValidatorTimeout(timeout time.Duration) ValidatorOpt { - return func(addVal *addValReq) error { - addVal.timeout = timeout - return nil - } -} - -// WithValidatorConcurrency is an option that sets topic validator throttle. 
-func WithValidatorConcurrency(n int) ValidatorOpt { - return func(addVal *addValReq) error { - addVal.throttle = n - return nil - } -} - // RegisterTopicValidator registers a validator for topic. +// By default validators are asynchronous, which means they will run in a separate goroutine. +// The number of active goroutines is controlled by global and per topic validator +// throttles; if it exceeds the throttle threshold, messages will be dropped. func (p *PubSub) RegisterTopicValidator(topic string, val Validator, opts ...ValidatorOpt) error { addVal := &addValReq{ topic: topic, @@ -945,34 +770,6 @@ func (p *PubSub) RegisterTopicValidator(topic string, val Validator, opts ...Val return <-addVal.resp } -func (ps *PubSub) addValidator(req *addValReq) { - topic := req.topic - - _, ok := ps.topicVals[topic] - if ok { - req.resp <- fmt.Errorf("Duplicate validator for topic %s", topic) - return - } - - val := &topicVal{ - topic: topic, - validate: req.validate, - validateTimeout: defaultValidateTimeout, - validateThrottle: make(chan struct{}, defaultValidateConcurrency), - } - - if req.timeout > 0 { - val.validateTimeout = req.timeout - } - - if req.throttle > 0 { - val.validateThrottle = make(chan struct{}, req.throttle) - } - - ps.topicVals[topic] = val - req.resp <- nil -} - // UnregisterTopicValidator removes a validator from a topic. // Returns an error if there was no validator registered with the topic. 
func (p *PubSub) UnregisterTopicValidator(topic string) error { @@ -984,27 +781,3 @@ func (p *PubSub) UnregisterTopicValidator(topic string) error { p.rmVal <- rmVal return <-rmVal.resp } - -func (ps *PubSub) rmValidator(req *rmValReq) { - topic := req.topic - - _, ok := ps.topicVals[topic] - if ok { - delete(ps.topicVals, topic) - req.resp <- nil - } else { - req.resp <- fmt.Errorf("No validator for topic %s", topic) - } -} - -func (val *topicVal) validateMsg(ctx context.Context, src peer.ID, msg *Message) bool { - vctx, cancel := context.WithTimeout(ctx, val.validateTimeout) - defer cancel() - - valid := val.validate(vctx, src, msg) - if !valid { - log.Debugf("validation failed for topic %s", val.topic) - } - - return valid -} diff --git a/validation.go b/validation.go new file mode 100644 index 0000000..023ee86 --- /dev/null +++ b/validation.go @@ -0,0 +1,391 @@ +package pubsub + +import ( + "context" + "fmt" + "runtime" + "time" + + peer "github.com/libp2p/go-libp2p-peer" +) + +const ( + defaultValidateConcurrency = 1024 + defaultValidateThrottle = 8192 +) + +// Validator is a function that validates a message. +type Validator func(context.Context, peer.ID, *Message) bool + +// ValidatorOpt is an option for RegisterTopicValidator. +type ValidatorOpt func(addVal *addValReq) error + +// validation represents the validator pipeline. +// The validator pipeline performs signature validation and runs a +// sequence of user-configured validators per-topic. It is possible to +// adjust various concurrency parameters, such as the number of +// workers and the max number of simultaneous validations. The user +// can also attach inline validators that will be executed +// synchronously; this may be useful to prevent superfluous +// context-switching for lightweight tasks. 
+type validation struct { + p *PubSub + + // topicVals tracks per topic validators + topicVals map[string]*topicVal + + // validateQ is the front-end to the validation pipeline + validateQ chan *validateReq + + // validateThrottle limits the number of active validation goroutines + validateThrottle chan struct{} + + // this is the number of synchronous validation workers + validateWorkers int +} + +// validation requests +type validateReq struct { + vals []*topicVal + src peer.ID + msg *Message +} + +// representation of topic validators +type topicVal struct { + topic string + validate Validator + validateTimeout time.Duration + validateThrottle chan struct{} + validateInline bool +} + +// async request to add a topic validator +type addValReq struct { + topic string + validate Validator + timeout time.Duration + throttle int + inline bool + resp chan error +} + +// async request to remove a topic validator +type rmValReq struct { + topic string + resp chan error +} + +// newValidation creates a new validation pipeline +func newValidation() *validation { + return &validation{ + topicVals: make(map[string]*topicVal), + validateQ: make(chan *validateReq, 32), + validateThrottle: make(chan struct{}, defaultValidateThrottle), + validateWorkers: runtime.NumCPU(), + } +} + +// Start attaches the validation pipeline to a pubsub instance and starts background +// workers +func (v *validation) Start(p *PubSub) { + v.p = p + for i := 0; i < v.validateWorkers; i++ { + go v.validateWorker() + } +} + +// AddValidator adds a new validator +func (v *validation) AddValidator(req *addValReq) { + topic := req.topic + + _, ok := v.topicVals[topic] + if ok { + req.resp <- fmt.Errorf("Duplicate validator for topic %s", topic) + return + } + + val := &topicVal{ + topic: topic, + validate: req.validate, + validateTimeout: 0, + validateThrottle: make(chan struct{}, defaultValidateConcurrency), + validateInline: req.inline, + } + + if req.timeout > 0 { + val.validateTimeout = req.timeout 
+ } + + if req.throttle > 0 { + val.validateThrottle = make(chan struct{}, req.throttle) + } + + v.topicVals[topic] = val + req.resp <- nil +} + +// RemoveValidator removes an existing validator +func (v *validation) RemoveValidator(req *rmValReq) { + topic := req.topic + + _, ok := v.topicVals[topic] + if ok { + delete(v.topicVals, topic) + req.resp <- nil + } else { + req.resp <- fmt.Errorf("No validator for topic %s", topic) + } +} + +// Push pushes a message into the validation pipeline. +// It returns true if the message can be forwarded immediately without validation. +func (v *validation) Push(src peer.ID, msg *Message) bool { + vals := v.getValidators(msg) + + if len(vals) > 0 || msg.Signature != nil { + select { + case v.validateQ <- &validateReq{vals, src, msg}: + default: + log.Warningf("message validation throttled; dropping message from %s", src) + } + return false + } + + return true +} + +// getValidators returns all validators that apply to a given message +func (v *validation) getValidators(msg *Message) []*topicVal { + var vals []*topicVal + + for _, topic := range msg.GetTopicIDs() { + val, ok := v.topicVals[topic] + if !ok { + continue + } + + vals = append(vals, val) + } + + return vals +} + +// validateWorker is an active goroutine performing inline validation +func (v *validation) validateWorker() { + for { + select { + case req := <-v.validateQ: + v.validate(req.vals, req.src, req.msg) + case <-v.p.ctx.Done(): + return + } + } +} + +// validate performs validation and only sends the message if all validators succeed. +// Signature validation and inline user validators run synchronously, while the remaining +// user validators are invoked asynchronously, throttled by the global validation throttle. 
+func (v *validation) validate(vals []*topicVal, src peer.ID, msg *Message) { + if msg.Signature != nil { + if !v.validateSignature(msg) { + log.Warningf("message signature validation failed; dropping message from %s", src) + return + } + } + + // we can mark the message as seen now that we have verified the signature + // and avoid invoking user validators more than once + id := msgID(msg.Message) + if !v.p.markSeen(id) { + return + } + + var inline, async []*topicVal + for _, val := range vals { + if val.validateInline { + inline = append(inline, val) + } else { + async = append(async, val) + } + } + + // apply inline (synchronous) validators + for _, val := range inline { + if !val.validateMsg(v.p.ctx, src, msg) { + log.Debugf("message validation failed; dropping message from %s", src) + return + } + } + + // apply async validators + if len(async) > 0 { + select { + case v.validateThrottle <- struct{}{}: + go func() { + v.doValidateTopic(async, src, msg) + <-v.validateThrottle + }() + default: + log.Warningf("message validation throttled; dropping message from %s", src) + } + return + } + + // no async validators, send the message + v.p.sendMsg <- &sendReq{ + from: src, + msg: msg, + } +} + +func (v *validation) validateSignature(msg *Message) bool { + err := verifyMessageSignature(msg.Message) + if err != nil { + log.Debugf("signature verification error: %s", err.Error()) + return false + } + + return true +} + +func (v *validation) doValidateTopic(vals []*topicVal, src peer.ID, msg *Message) { + if !v.validateTopic(vals, src, msg) { + log.Warningf("message validation failed; dropping message from %s", src) + return + } + + v.p.sendMsg <- &sendReq{ + from: src, + msg: msg, + } +} + +func (v *validation) validateTopic(vals []*topicVal, src peer.ID, msg *Message) bool { + if len(vals) == 1 { + return v.validateSingleTopic(vals[0], src, msg) + } + + ctx, cancel := context.WithCancel(v.p.ctx) + defer cancel() + + rch := make(chan bool, len(vals)) + rcount := 0 + 
throttle := false + +loop: + for _, val := range vals { + rcount++ + + select { + case val.validateThrottle <- struct{}{}: + go func(val *topicVal) { + rch <- val.validateMsg(ctx, src, msg) + <-val.validateThrottle + }(val) + + default: + log.Debugf("validation throttled for topic %s", val.topic) + throttle = true + break loop + } + } + + if throttle { + return false + } + + for i := 0; i < rcount; i++ { + valid := <-rch + if !valid { + return false + } + } + + return true +} + +// fast path for single topic validation that avoids the extra goroutine +func (v *validation) validateSingleTopic(val *topicVal, src peer.ID, msg *Message) bool { + select { + case val.validateThrottle <- struct{}{}: + res := val.validateMsg(v.p.ctx, src, msg) + <-val.validateThrottle + + return res + + default: + log.Debugf("validation throttled for topic %s", val.topic) + return false + } +} + +func (val *topicVal) validateMsg(ctx context.Context, src peer.ID, msg *Message) bool { + if val.validateTimeout > 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, val.validateTimeout) + defer cancel() + } + + valid := val.validate(ctx, src, msg) + if !valid { + log.Debugf("validation failed for topic %s", val.topic) + } + + return valid +} + +/// Options + +// WithValidateThrottle sets the upper bound on the number of active validation +// goroutines across all topics. The default is 8192. +func WithValidateThrottle(n int) Option { + return func(ps *PubSub) error { + ps.val.validateThrottle = make(chan struct{}, n) + return nil + } +} + +// WithValidateWorkers sets the number of synchronous validation worker goroutines. +// Defaults to NumCPU. +// +// The synchronous validation workers perform signature validation, apply inline +// user validators, and schedule asynchronous user validators. +// You can adjust this parameter to devote less cpu time to synchronous validation. 
+func WithValidateWorkers(n int) Option { + return func(ps *PubSub) error { + if n > 0 { + ps.val.validateWorkers = n + return nil + } + return fmt.Errorf("number of validation workers must be > 0") + } +} + +// WithValidatorTimeout is an option that sets a timeout for an (asynchronous) topic validator. +// By default there is no timeout in asynchronous validators. +func WithValidatorTimeout(timeout time.Duration) ValidatorOpt { + return func(addVal *addValReq) error { + addVal.timeout = timeout + return nil + } +} + +// WithValidatorConcurrency is an option that sets the topic validator throttle. +// This controls the number of active validation goroutines for the topic; the default is 1024. +func WithValidatorConcurrency(n int) ValidatorOpt { + return func(addVal *addValReq) error { + addVal.throttle = n + return nil + } +} + +// WithValidatorInline is an option that sets the validation disposition to synchronous: +// it will be executed inline in the validation front-end, without spawning a new goroutine. +// This is suitable for simple or cpu-bound validators that do not block. +func WithValidatorInline(inline bool) ValidatorOpt { + return func(addVal *addValReq) error { + addVal.inline = inline + return nil + } +}