mirror of
https://github.com/logos-messaging/go-libp2p-pubsub.git
synced 2026-01-02 12:53:09 +00:00
split off validation into its own type
This commit is contained in:
parent
8d0c8d60b1
commit
cb423f474d
412
pubsub.go
412
pubsub.go
@ -5,7 +5,6 @@ import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@ -21,11 +20,6 @@ import (
|
||||
timecache "github.com/whyrusleeping/timecache"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultValidateConcurrency = 1024
|
||||
defaultValidateThrottle = 8192
|
||||
)
|
||||
|
||||
var (
|
||||
TimeCacheDuration = 120 * time.Second
|
||||
)
|
||||
@ -45,6 +39,8 @@ type PubSub struct {
|
||||
|
||||
rt PubSubRouter
|
||||
|
||||
val *validation
|
||||
|
||||
// incoming messages from other peers
|
||||
incoming chan *RPC
|
||||
|
||||
@ -90,18 +86,6 @@ type PubSub struct {
|
||||
// rmVal handles validator unregistration requests
|
||||
rmVal chan *rmValReq
|
||||
|
||||
// topicVals tracks per topic validators
|
||||
topicVals map[string]*topicVal
|
||||
|
||||
// validateQ is the front-end to the validation pipeline
|
||||
validateQ chan *validateReq
|
||||
|
||||
// validateThrottle limits the number of active validation goroutines
|
||||
validateThrottle chan struct{}
|
||||
|
||||
// this is the number of synchronous validation workers
|
||||
validateWorkers int
|
||||
|
||||
// eval thunk in event loop
|
||||
eval chan func()
|
||||
|
||||
@ -168,37 +152,34 @@ type Option func(*PubSub) error
|
||||
// NewPubSub returns a new PubSub management object.
|
||||
func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option) (*PubSub, error) {
|
||||
ps := &PubSub{
|
||||
host: h,
|
||||
ctx: ctx,
|
||||
rt: rt,
|
||||
signID: h.ID(),
|
||||
signKey: h.Peerstore().PrivKey(h.ID()),
|
||||
signStrict: true,
|
||||
incoming: make(chan *RPC, 32),
|
||||
publish: make(chan *Message),
|
||||
newPeers: make(chan peer.ID),
|
||||
newPeerStream: make(chan inet.Stream),
|
||||
newPeerError: make(chan peer.ID),
|
||||
peerDead: make(chan peer.ID),
|
||||
cancelCh: make(chan *Subscription),
|
||||
getPeers: make(chan *listPeerReq),
|
||||
addSub: make(chan *addSubReq),
|
||||
getTopics: make(chan *topicReq),
|
||||
sendMsg: make(chan *sendReq, 32),
|
||||
addVal: make(chan *addValReq),
|
||||
rmVal: make(chan *rmValReq),
|
||||
validateThrottle: make(chan struct{}, defaultValidateThrottle),
|
||||
eval: make(chan func()),
|
||||
myTopics: make(map[string]map[*Subscription]struct{}),
|
||||
topics: make(map[string]map[peer.ID]struct{}),
|
||||
peers: make(map[peer.ID]chan *RPC),
|
||||
topicVals: make(map[string]*topicVal),
|
||||
validateQ: make(chan *validateReq, 32),
|
||||
blacklist: NewMapBlacklist(),
|
||||
blacklistPeer: make(chan peer.ID),
|
||||
seenMessages: timecache.NewTimeCache(TimeCacheDuration),
|
||||
counter: uint64(time.Now().UnixNano()),
|
||||
validateWorkers: runtime.NumCPU(),
|
||||
host: h,
|
||||
ctx: ctx,
|
||||
rt: rt,
|
||||
val: newValidation(),
|
||||
signID: h.ID(),
|
||||
signKey: h.Peerstore().PrivKey(h.ID()),
|
||||
signStrict: true,
|
||||
incoming: make(chan *RPC, 32),
|
||||
publish: make(chan *Message),
|
||||
newPeers: make(chan peer.ID),
|
||||
newPeerStream: make(chan inet.Stream),
|
||||
newPeerError: make(chan peer.ID),
|
||||
peerDead: make(chan peer.ID),
|
||||
cancelCh: make(chan *Subscription),
|
||||
getPeers: make(chan *listPeerReq),
|
||||
addSub: make(chan *addSubReq),
|
||||
getTopics: make(chan *topicReq),
|
||||
sendMsg: make(chan *sendReq, 32),
|
||||
addVal: make(chan *addValReq),
|
||||
rmVal: make(chan *rmValReq),
|
||||
eval: make(chan func()),
|
||||
myTopics: make(map[string]map[*Subscription]struct{}),
|
||||
topics: make(map[string]map[peer.ID]struct{}),
|
||||
peers: make(map[peer.ID]chan *RPC),
|
||||
blacklist: NewMapBlacklist(),
|
||||
blacklistPeer: make(chan peer.ID),
|
||||
seenMessages: timecache.NewTimeCache(TimeCacheDuration),
|
||||
counter: uint64(time.Now().UnixNano()),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
@ -219,36 +200,13 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option
|
||||
}
|
||||
h.Network().Notify((*PubSubNotif)(ps))
|
||||
|
||||
ps.val.Start(ps)
|
||||
|
||||
go ps.processLoop(ctx)
|
||||
|
||||
for i := 0; i < ps.validateWorkers; i++ {
|
||||
go ps.validateWorker()
|
||||
}
|
||||
|
||||
return ps, nil
|
||||
}
|
||||
|
||||
// WithValidateThrottle sets the upper bound on the number of active validation
|
||||
// goroutines across all topics. The default is 8192.
|
||||
func WithValidateThrottle(n int) Option {
|
||||
return func(ps *PubSub) error {
|
||||
ps.validateThrottle = make(chan struct{}, n)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithValidateWorkers sets the number of synchronous validation worker goroutines.
|
||||
// Defaults to NumCPU.
|
||||
func WithValidateWorkers(n int) Option {
|
||||
return func(ps *PubSub) error {
|
||||
if n > 0 {
|
||||
ps.validateWorkers = n
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("number of validation workers must be > 0")
|
||||
}
|
||||
}
|
||||
|
||||
// WithMessageSigning enables or disables message signing (enabled by default).
|
||||
func WithMessageSigning(enabled bool) Option {
|
||||
return func(p *PubSub) error {
|
||||
@ -411,17 +369,16 @@ func (p *PubSub) processLoop(ctx context.Context) {
|
||||
p.handleIncomingRPC(rpc)
|
||||
|
||||
case msg := <-p.publish:
|
||||
vals := p.getValidators(msg)
|
||||
p.pushMsg(vals, p.host.ID(), msg)
|
||||
p.pushMsg(p.host.ID(), msg)
|
||||
|
||||
case req := <-p.sendMsg:
|
||||
p.publishMessage(req.from, req.msg.Message)
|
||||
|
||||
case req := <-p.addVal:
|
||||
p.addValidator(req)
|
||||
p.val.AddValidator(req)
|
||||
|
||||
case req := <-p.rmVal:
|
||||
p.rmValidator(req)
|
||||
p.val.RemoveValidator(req)
|
||||
|
||||
case thunk := <-p.eval:
|
||||
thunk()
|
||||
@ -629,8 +586,7 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) {
|
||||
}
|
||||
|
||||
msg := &Message{pmsg}
|
||||
vals := p.getValidators(msg)
|
||||
p.pushMsg(vals, rpc.from, msg)
|
||||
p.pushMsg(rpc.from, msg)
|
||||
}
|
||||
|
||||
p.rt.HandleRPC(rpc)
|
||||
@ -642,7 +598,7 @@ func msgID(pmsg *pb.Message) string {
|
||||
}
|
||||
|
||||
// pushMsg pushes a message performing validation as necessary
|
||||
func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) {
|
||||
func (p *PubSub) pushMsg(src peer.ID, msg *Message) {
|
||||
// reject messages from blacklisted peers
|
||||
if p.blacklist.Contains(src) {
|
||||
log.Warningf("dropping message from blacklisted peer %s", src)
|
||||
@ -667,12 +623,7 @@ func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(vals) > 0 || msg.Signature != nil {
|
||||
select {
|
||||
case p.validateQ <- &validateReq{vals, src, msg}:
|
||||
default:
|
||||
log.Warningf("message validation throttled; dropping message from %s", src)
|
||||
}
|
||||
if !p.val.Push(src, msg) {
|
||||
return
|
||||
}
|
||||
|
||||
@ -681,178 +632,11 @@ func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PubSub) validateWorker() {
|
||||
for {
|
||||
select {
|
||||
case req := <-p.validateQ:
|
||||
p.validate(req.vals, req.src, req.msg)
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validate performs validation and only sends the message if all validators succeed
|
||||
// signature validation is performed synchronously, while user validators are invoked
|
||||
// asynchronously, throttled by the global validation throttle.
|
||||
func (p *PubSub) validate(vals []*topicVal, src peer.ID, msg *Message) {
|
||||
if msg.Signature != nil {
|
||||
if !p.validateSignature(msg) {
|
||||
log.Warningf("message signature validation failed; dropping message from %s", src)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// we can mark the message as seen now that we have verified the signature
|
||||
// and avoid invoking user validators more than once
|
||||
id := msgID(msg.Message)
|
||||
if !p.markSeen(id) {
|
||||
return
|
||||
}
|
||||
|
||||
var inline, async []*topicVal
|
||||
for _, val := range vals {
|
||||
if val.validateInline {
|
||||
inline = append(inline, val)
|
||||
} else {
|
||||
async = append(async, val)
|
||||
}
|
||||
}
|
||||
|
||||
// apply inline (synchronous) validators
|
||||
for _, val := range inline {
|
||||
if !val.validateMsg(p.ctx, src, msg) {
|
||||
log.Debugf("message validation failed; dropping message from %s", src)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// apply async validators
|
||||
if len(async) > 0 {
|
||||
select {
|
||||
case p.validateThrottle <- struct{}{}:
|
||||
go func() {
|
||||
p.doValidateTopic(async, src, msg)
|
||||
<-p.validateThrottle
|
||||
}()
|
||||
default:
|
||||
log.Warningf("message validation throttled; dropping message from %s", src)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// no async validators, send the message
|
||||
p.sendMsg <- &sendReq{
|
||||
from: src,
|
||||
msg: msg,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PubSub) validateSignature(msg *Message) bool {
|
||||
err := verifyMessageSignature(msg.Message)
|
||||
if err != nil {
|
||||
log.Debugf("signature verification error: %s", err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *PubSub) doValidateTopic(vals []*topicVal, src peer.ID, msg *Message) {
|
||||
if !p.validateTopic(vals, src, msg) {
|
||||
log.Warningf("message validation failed; dropping message from %s", src)
|
||||
return
|
||||
}
|
||||
|
||||
p.sendMsg <- &sendReq{
|
||||
from: src,
|
||||
msg: msg,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PubSub) validateTopic(vals []*topicVal, src peer.ID, msg *Message) bool {
|
||||
if len(vals) == 1 {
|
||||
return p.validateSingleTopic(vals[0], src, msg)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(p.ctx)
|
||||
defer cancel()
|
||||
|
||||
rch := make(chan bool, len(vals))
|
||||
rcount := 0
|
||||
throttle := false
|
||||
|
||||
loop:
|
||||
for _, val := range vals {
|
||||
rcount++
|
||||
|
||||
select {
|
||||
case val.validateThrottle <- struct{}{}:
|
||||
go func(val *topicVal) {
|
||||
rch <- val.validateMsg(ctx, src, msg)
|
||||
<-val.validateThrottle
|
||||
}(val)
|
||||
|
||||
default:
|
||||
log.Debugf("validation throttled for topic %s", val.topic)
|
||||
throttle = true
|
||||
break loop
|
||||
}
|
||||
}
|
||||
|
||||
if throttle {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := 0; i < rcount; i++ {
|
||||
valid := <-rch
|
||||
if !valid {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// fast path for single topic validation that avoids the extra goroutine
|
||||
func (p *PubSub) validateSingleTopic(val *topicVal, src peer.ID, msg *Message) bool {
|
||||
select {
|
||||
case val.validateThrottle <- struct{}{}:
|
||||
ctx, cancel := context.WithCancel(p.ctx)
|
||||
defer cancel()
|
||||
|
||||
res := val.validateMsg(ctx, src, msg)
|
||||
<-val.validateThrottle
|
||||
|
||||
return res
|
||||
|
||||
default:
|
||||
log.Debugf("validation throttled for topic %s", val.topic)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PubSub) publishMessage(from peer.ID, pmsg *pb.Message) {
|
||||
p.notifySubs(pmsg)
|
||||
p.rt.Publish(from, pmsg)
|
||||
}
|
||||
|
||||
// getValidators returns all validators that apply to a given message
|
||||
func (p *PubSub) getValidators(msg *Message) []*topicVal {
|
||||
var vals []*topicVal
|
||||
|
||||
for _, topic := range msg.GetTopicIDs() {
|
||||
val, ok := p.topicVals[topic]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
vals = append(vals, val)
|
||||
}
|
||||
|
||||
return vals
|
||||
}
|
||||
|
||||
type addSubReq struct {
|
||||
sub *Subscription
|
||||
resp chan *Subscription
|
||||
@ -964,70 +748,6 @@ func (p *PubSub) BlacklistPeer(pid peer.ID) {
|
||||
p.blacklistPeer <- pid
|
||||
}
|
||||
|
||||
// validation requests
|
||||
type validateReq struct {
|
||||
vals []*topicVal
|
||||
src peer.ID
|
||||
msg *Message
|
||||
}
|
||||
|
||||
// per topic validators
|
||||
type addValReq struct {
|
||||
topic string
|
||||
validate Validator
|
||||
timeout time.Duration
|
||||
throttle int
|
||||
inline bool
|
||||
resp chan error
|
||||
}
|
||||
|
||||
type rmValReq struct {
|
||||
topic string
|
||||
resp chan error
|
||||
}
|
||||
|
||||
type topicVal struct {
|
||||
topic string
|
||||
validate Validator
|
||||
validateTimeout time.Duration
|
||||
validateThrottle chan struct{}
|
||||
validateInline bool
|
||||
}
|
||||
|
||||
// Validator is a function that validates a message.
|
||||
type Validator func(context.Context, peer.ID, *Message) bool
|
||||
|
||||
// ValidatorOpt is an option for RegisterTopicValidator.
|
||||
type ValidatorOpt func(addVal *addValReq) error
|
||||
|
||||
// WithValidatorTimeout is an option that sets a timeout for an (asynchronous) topic validator.
|
||||
// By default there is no timeout in asynchronous validators.
|
||||
func WithValidatorTimeout(timeout time.Duration) ValidatorOpt {
|
||||
return func(addVal *addValReq) error {
|
||||
addVal.timeout = timeout
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithValidatorConcurrency is an option that sets the topic validator throttle.
|
||||
// This controls the number of active validation goroutines for the topic; the default is 1024.
|
||||
func WithValidatorConcurrency(n int) ValidatorOpt {
|
||||
return func(addVal *addValReq) error {
|
||||
addVal.throttle = n
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithValidatorInline is an option that sets the validation disposition to synchronous:
|
||||
// it will be executed inline in validation front-end, without spawning a new goroutine.
|
||||
// This is suitable for simple or cpu-bound validators that do not block.
|
||||
func WithValidatorInline(inline bool) ValidatorOpt {
|
||||
return func(addVal *addValReq) error {
|
||||
addVal.inline = inline
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterTopicValidator registers a validator for topic.
|
||||
// By default validators are asynchronous, which means they will run in a separate goroutine.
|
||||
// The number of active goroutines is controlled by global and per topic validator
|
||||
@ -1050,35 +770,6 @@ func (p *PubSub) RegisterTopicValidator(topic string, val Validator, opts ...Val
|
||||
return <-addVal.resp
|
||||
}
|
||||
|
||||
func (ps *PubSub) addValidator(req *addValReq) {
|
||||
topic := req.topic
|
||||
|
||||
_, ok := ps.topicVals[topic]
|
||||
if ok {
|
||||
req.resp <- fmt.Errorf("Duplicate validator for topic %s", topic)
|
||||
return
|
||||
}
|
||||
|
||||
val := &topicVal{
|
||||
topic: topic,
|
||||
validate: req.validate,
|
||||
validateTimeout: 0,
|
||||
validateThrottle: make(chan struct{}, defaultValidateConcurrency),
|
||||
validateInline: req.inline,
|
||||
}
|
||||
|
||||
if req.timeout > 0 {
|
||||
val.validateTimeout = req.timeout
|
||||
}
|
||||
|
||||
if req.throttle > 0 {
|
||||
val.validateThrottle = make(chan struct{}, req.throttle)
|
||||
}
|
||||
|
||||
ps.topicVals[topic] = val
|
||||
req.resp <- nil
|
||||
}
|
||||
|
||||
// UnregisterTopicValidator removes a validator from a topic.
|
||||
// Returns an error if there was no validator registered with the topic.
|
||||
func (p *PubSub) UnregisterTopicValidator(topic string) error {
|
||||
@ -1090,30 +781,3 @@ func (p *PubSub) UnregisterTopicValidator(topic string) error {
|
||||
p.rmVal <- rmVal
|
||||
return <-rmVal.resp
|
||||
}
|
||||
|
||||
func (ps *PubSub) rmValidator(req *rmValReq) {
|
||||
topic := req.topic
|
||||
|
||||
_, ok := ps.topicVals[topic]
|
||||
if ok {
|
||||
delete(ps.topicVals, topic)
|
||||
req.resp <- nil
|
||||
} else {
|
||||
req.resp <- fmt.Errorf("No validator for topic %s", topic)
|
||||
}
|
||||
}
|
||||
|
||||
func (val *topicVal) validateMsg(ctx context.Context, src peer.ID, msg *Message) bool {
|
||||
if val.validateTimeout > 0 {
|
||||
var cancel func()
|
||||
ctx, cancel = context.WithTimeout(ctx, val.validateTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
valid := val.validate(ctx, src, msg)
|
||||
if !valid {
|
||||
log.Debugf("validation failed for topic %s", val.topic)
|
||||
}
|
||||
|
||||
return valid
|
||||
}
|
||||
|
||||
383
validation.go
Normal file
383
validation.go
Normal file
@ -0,0 +1,383 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
peer "github.com/libp2p/go-libp2p-peer"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultValidateConcurrency = 1024
|
||||
defaultValidateThrottle = 8192
|
||||
)
|
||||
|
||||
// Validator is a function that validates a message.
|
||||
type Validator func(context.Context, peer.ID, *Message) bool
|
||||
|
||||
// ValidatorOpt is an option for RegisterTopicValidator.
|
||||
type ValidatorOpt func(addVal *addValReq) error
|
||||
|
||||
// validation represents the validator pipeline
|
||||
type validation struct {
|
||||
p *PubSub
|
||||
|
||||
// topicVals tracks per topic validators
|
||||
topicVals map[string]*topicVal
|
||||
|
||||
// validateQ is the front-end to the validation pipeline
|
||||
validateQ chan *validateReq
|
||||
|
||||
// validateThrottle limits the number of active validation goroutines
|
||||
validateThrottle chan struct{}
|
||||
|
||||
// this is the number of synchronous validation workers
|
||||
validateWorkers int
|
||||
}
|
||||
|
||||
// validation requests
|
||||
type validateReq struct {
|
||||
vals []*topicVal
|
||||
src peer.ID
|
||||
msg *Message
|
||||
}
|
||||
|
||||
// representation of topic validators
|
||||
type topicVal struct {
|
||||
topic string
|
||||
validate Validator
|
||||
validateTimeout time.Duration
|
||||
validateThrottle chan struct{}
|
||||
validateInline bool
|
||||
}
|
||||
|
||||
// async request to add a topic validators
|
||||
type addValReq struct {
|
||||
topic string
|
||||
validate Validator
|
||||
timeout time.Duration
|
||||
throttle int
|
||||
inline bool
|
||||
resp chan error
|
||||
}
|
||||
|
||||
// async request to remove a topic validator
|
||||
type rmValReq struct {
|
||||
topic string
|
||||
resp chan error
|
||||
}
|
||||
|
||||
// newValidation creates a new validation pipeline
|
||||
func newValidation() *validation {
|
||||
return &validation{
|
||||
topicVals: make(map[string]*topicVal),
|
||||
validateQ: make(chan *validateReq, 32),
|
||||
validateThrottle: make(chan struct{}, defaultValidateThrottle),
|
||||
validateWorkers: runtime.NumCPU(),
|
||||
}
|
||||
}
|
||||
|
||||
// Start attaches the validation pipeline to a pubsub instance and starts background
|
||||
// workers
|
||||
func (v *validation) Start(p *PubSub) {
|
||||
v.p = p
|
||||
for i := 0; i < v.validateWorkers; i++ {
|
||||
go v.validateWorker()
|
||||
}
|
||||
}
|
||||
|
||||
// AddValidator adds a new validator
|
||||
func (v *validation) AddValidator(req *addValReq) {
|
||||
topic := req.topic
|
||||
|
||||
_, ok := v.topicVals[topic]
|
||||
if ok {
|
||||
req.resp <- fmt.Errorf("Duplicate validator for topic %s", topic)
|
||||
return
|
||||
}
|
||||
|
||||
val := &topicVal{
|
||||
topic: topic,
|
||||
validate: req.validate,
|
||||
validateTimeout: 0,
|
||||
validateThrottle: make(chan struct{}, defaultValidateConcurrency),
|
||||
validateInline: req.inline,
|
||||
}
|
||||
|
||||
if req.timeout > 0 {
|
||||
val.validateTimeout = req.timeout
|
||||
}
|
||||
|
||||
if req.throttle > 0 {
|
||||
val.validateThrottle = make(chan struct{}, req.throttle)
|
||||
}
|
||||
|
||||
v.topicVals[topic] = val
|
||||
req.resp <- nil
|
||||
}
|
||||
|
||||
// RemoveValidator removes an existing validator
|
||||
func (v *validation) RemoveValidator(req *rmValReq) {
|
||||
topic := req.topic
|
||||
|
||||
_, ok := v.topicVals[topic]
|
||||
if ok {
|
||||
delete(v.topicVals, topic)
|
||||
req.resp <- nil
|
||||
} else {
|
||||
req.resp <- fmt.Errorf("No validator for topic %s", topic)
|
||||
}
|
||||
}
|
||||
|
||||
// Push pushes a message into the validation pipeline.
|
||||
// It returns true if the message can be forwarded immediately without validation.
|
||||
func (v *validation) Push(src peer.ID, msg *Message) bool {
|
||||
vals := v.getValidators(msg)
|
||||
|
||||
if len(vals) > 0 || msg.Signature != nil {
|
||||
select {
|
||||
case v.validateQ <- &validateReq{vals, src, msg}:
|
||||
default:
|
||||
log.Warningf("message validation throttled; dropping message from %s", src)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// getValidators returns all validators that apply to a given message
|
||||
func (v *validation) getValidators(msg *Message) []*topicVal {
|
||||
var vals []*topicVal
|
||||
|
||||
for _, topic := range msg.GetTopicIDs() {
|
||||
val, ok := v.topicVals[topic]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
vals = append(vals, val)
|
||||
}
|
||||
|
||||
return vals
|
||||
}
|
||||
|
||||
// validateWorker is an active goroutine performing inline validation
|
||||
func (v *validation) validateWorker() {
|
||||
for {
|
||||
select {
|
||||
case req := <-v.validateQ:
|
||||
v.validate(req.vals, req.src, req.msg)
|
||||
case <-v.p.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validate performs validation and only sends the message if all validators succeed
|
||||
// signature validation is performed synchronously, while user validators are invoked
|
||||
// asynchronously, throttled by the global validation throttle.
|
||||
func (v *validation) validate(vals []*topicVal, src peer.ID, msg *Message) {
|
||||
if msg.Signature != nil {
|
||||
if !v.validateSignature(msg) {
|
||||
log.Warningf("message signature validation failed; dropping message from %s", src)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// we can mark the message as seen now that we have verified the signature
|
||||
// and avoid invoking user validators more than once
|
||||
id := msgID(msg.Message)
|
||||
if !v.p.markSeen(id) {
|
||||
return
|
||||
}
|
||||
|
||||
var inline, async []*topicVal
|
||||
for _, val := range vals {
|
||||
if val.validateInline {
|
||||
inline = append(inline, val)
|
||||
} else {
|
||||
async = append(async, val)
|
||||
}
|
||||
}
|
||||
|
||||
// apply inline (synchronous) validators
|
||||
for _, val := range inline {
|
||||
if !val.validateMsg(v.p.ctx, src, msg) {
|
||||
log.Debugf("message validation failed; dropping message from %s", src)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// apply async validators
|
||||
if len(async) > 0 {
|
||||
select {
|
||||
case v.validateThrottle <- struct{}{}:
|
||||
go func() {
|
||||
v.doValidateTopic(async, src, msg)
|
||||
<-v.validateThrottle
|
||||
}()
|
||||
default:
|
||||
log.Warningf("message validation throttled; dropping message from %s", src)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// no async validators, send the message
|
||||
v.p.sendMsg <- &sendReq{
|
||||
from: src,
|
||||
msg: msg,
|
||||
}
|
||||
}
|
||||
|
||||
func (v *validation) validateSignature(msg *Message) bool {
|
||||
err := verifyMessageSignature(msg.Message)
|
||||
if err != nil {
|
||||
log.Debugf("signature verification error: %s", err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (v *validation) doValidateTopic(vals []*topicVal, src peer.ID, msg *Message) {
|
||||
if !v.validateTopic(vals, src, msg) {
|
||||
log.Warningf("message validation failed; dropping message from %s", src)
|
||||
return
|
||||
}
|
||||
|
||||
v.p.sendMsg <- &sendReq{
|
||||
from: src,
|
||||
msg: msg,
|
||||
}
|
||||
}
|
||||
|
||||
func (v *validation) validateTopic(vals []*topicVal, src peer.ID, msg *Message) bool {
|
||||
if len(vals) == 1 {
|
||||
return v.validateSingleTopic(vals[0], src, msg)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(v.p.ctx)
|
||||
defer cancel()
|
||||
|
||||
rch := make(chan bool, len(vals))
|
||||
rcount := 0
|
||||
throttle := false
|
||||
|
||||
loop:
|
||||
for _, val := range vals {
|
||||
rcount++
|
||||
|
||||
select {
|
||||
case val.validateThrottle <- struct{}{}:
|
||||
go func(val *topicVal) {
|
||||
rch <- val.validateMsg(ctx, src, msg)
|
||||
<-val.validateThrottle
|
||||
}(val)
|
||||
|
||||
default:
|
||||
log.Debugf("validation throttled for topic %s", val.topic)
|
||||
throttle = true
|
||||
break loop
|
||||
}
|
||||
}
|
||||
|
||||
if throttle {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := 0; i < rcount; i++ {
|
||||
valid := <-rch
|
||||
if !valid {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// fast path for single topic validation that avoids the extra goroutine
|
||||
func (v *validation) validateSingleTopic(val *topicVal, src peer.ID, msg *Message) bool {
|
||||
select {
|
||||
case val.validateThrottle <- struct{}{}:
|
||||
ctx, cancel := context.WithCancel(v.p.ctx)
|
||||
defer cancel()
|
||||
|
||||
res := val.validateMsg(ctx, src, msg)
|
||||
<-val.validateThrottle
|
||||
|
||||
return res
|
||||
|
||||
default:
|
||||
log.Debugf("validation throttled for topic %s", val.topic)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (val *topicVal) validateMsg(ctx context.Context, src peer.ID, msg *Message) bool {
|
||||
if val.validateTimeout > 0 {
|
||||
var cancel func()
|
||||
ctx, cancel = context.WithTimeout(ctx, val.validateTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
valid := val.validate(ctx, src, msg)
|
||||
if !valid {
|
||||
log.Debugf("validation failed for topic %s", val.topic)
|
||||
}
|
||||
|
||||
return valid
|
||||
}
|
||||
|
||||
/// Options
|
||||
|
||||
// WithValidateThrottle sets the upper bound on the number of active validation
|
||||
// goroutines across all topics. The default is 8192.
|
||||
func WithValidateThrottle(n int) Option {
|
||||
return func(ps *PubSub) error {
|
||||
ps.val.validateThrottle = make(chan struct{}, n)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithValidateWorkers sets the number of synchronous validation worker goroutines.
|
||||
// Defaults to NumCPU.
|
||||
func WithValidateWorkers(n int) Option {
|
||||
return func(ps *PubSub) error {
|
||||
if n > 0 {
|
||||
ps.val.validateWorkers = n
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("number of validation workers must be > 0")
|
||||
}
|
||||
}
|
||||
|
||||
// WithValidatorTimeout is an option that sets a timeout for an (asynchronous) topic validator.
|
||||
// By default there is no timeout in asynchronous validators.
|
||||
func WithValidatorTimeout(timeout time.Duration) ValidatorOpt {
|
||||
return func(addVal *addValReq) error {
|
||||
addVal.timeout = timeout
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithValidatorConcurrency is an option that sets the topic validator throttle.
|
||||
// This controls the number of active validation goroutines for the topic; the default is 1024.
|
||||
func WithValidatorConcurrency(n int) ValidatorOpt {
|
||||
return func(addVal *addValReq) error {
|
||||
addVal.throttle = n
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithValidatorInline is an option that sets the validation disposition to synchronous:
|
||||
// it will be executed inline in validation front-end, without spawning a new goroutine.
|
||||
// This is suitable for simple or cpu-bound validators that do not block.
|
||||
func WithValidatorInline(inline bool) ValidatorOpt {
|
||||
return func(addVal *addValReq) error {
|
||||
addVal.inline = inline
|
||||
return nil
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user