mirror of
https://github.com/logos-messaging/go-libp2p-pubsub.git
synced 2026-04-19 08:43:22 +00:00
implement per topic validators
This commit is contained in:
parent
fceb00d234
commit
bbdec3fda2
204
floodsub.go
204
floodsub.go
@ -62,6 +62,12 @@ type PubSub struct {
|
|||||||
// sendMsg handles messages that have been validated
|
// sendMsg handles messages that have been validated
|
||||||
sendMsg chan *sendReq
|
sendMsg chan *sendReq
|
||||||
|
|
||||||
|
// addVal handles validator registration requests
|
||||||
|
addVal chan *addValReq
|
||||||
|
|
||||||
|
// topicVals tracks per topic validators
|
||||||
|
topicVals map[string]*topicVal
|
||||||
|
|
||||||
// validateThrottle limits the number of active validation goroutines
|
// validateThrottle limits the number of active validation goroutines
|
||||||
validateThrottle chan struct{}
|
validateThrottle chan struct{}
|
||||||
|
|
||||||
@ -105,10 +111,12 @@ func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, err
|
|||||||
addSub: make(chan *addSubReq),
|
addSub: make(chan *addSubReq),
|
||||||
getTopics: make(chan *topicReq),
|
getTopics: make(chan *topicReq),
|
||||||
sendMsg: make(chan *sendReq, 32),
|
sendMsg: make(chan *sendReq, 32),
|
||||||
|
addVal: make(chan *addValReq),
|
||||||
validateThrottle: make(chan struct{}, defaultValidateThrottle),
|
validateThrottle: make(chan struct{}, defaultValidateThrottle),
|
||||||
myTopics: make(map[string]map[*Subscription]struct{}),
|
myTopics: make(map[string]map[*Subscription]struct{}),
|
||||||
topics: make(map[string]map[peer.ID]struct{}),
|
topics: make(map[string]map[peer.ID]struct{}),
|
||||||
peers: make(map[peer.ID]chan *RPC),
|
peers: make(map[peer.ID]chan *RPC),
|
||||||
|
topicVals: make(map[string]*topicVal),
|
||||||
seenMessages: timecache.NewTimeCache(time.Second * 30),
|
seenMessages: timecache.NewTimeCache(time.Second * 30),
|
||||||
counter: uint64(time.Now().UnixNano()),
|
counter: uint64(time.Now().UnixNano()),
|
||||||
}
|
}
|
||||||
@ -205,12 +213,15 @@ func (p *PubSub) processLoop(ctx context.Context) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
case msg := <-p.publish:
|
case msg := <-p.publish:
|
||||||
subs := p.getSubscriptions(msg)
|
vals := p.getValidators(msg)
|
||||||
p.pushMsg(subs, p.host.ID(), msg)
|
p.pushMsg(vals, p.host.ID(), msg)
|
||||||
|
|
||||||
case req := <-p.sendMsg:
|
case req := <-p.sendMsg:
|
||||||
p.maybePublishMessage(req.from, req.msg.Message)
|
p.maybePublishMessage(req.from, req.msg.Message)
|
||||||
|
|
||||||
|
case req := <-p.addVal:
|
||||||
|
p.addValidator(req)
|
||||||
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
log.Info("pubsub processloop shutting down")
|
log.Info("pubsub processloop shutting down")
|
||||||
return
|
return
|
||||||
@ -347,8 +358,8 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
msg := &Message{pmsg}
|
msg := &Message{pmsg}
|
||||||
subs := p.getSubscriptions(msg)
|
vals := p.getValidators(msg)
|
||||||
p.pushMsg(subs, rpc.from, msg)
|
p.pushMsg(vals, rpc.from, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -359,20 +370,9 @@ func msgID(pmsg *pb.Message) string {
|
|||||||
return string(pmsg.GetFrom()) + string(pmsg.GetSeqno())
|
return string(pmsg.GetFrom()) + string(pmsg.GetSeqno())
|
||||||
}
|
}
|
||||||
|
|
||||||
// pushMsg pushes a message to a number of subscriptions, performing validation
|
// pushMsg pushes a message performing validation as necessary
|
||||||
// as necessary
|
func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) {
|
||||||
func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) {
|
if len(vals) > 0 {
|
||||||
// we perform validation if _any_ of the subscriptions has a validator
|
|
||||||
// because the message is sent once for all topics
|
|
||||||
needval := false
|
|
||||||
for _, sub := range subs {
|
|
||||||
if sub.validate != nil {
|
|
||||||
needval = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if needval {
|
|
||||||
// validation is asynchronous and globally throttled with the throttleValidate semaphore.
|
// validation is asynchronous and globally throttled with the throttleValidate semaphore.
|
||||||
// the purpose of the global throttle is to bound the goncurrency possible from incoming
|
// the purpose of the global throttle is to bound the goncurrency possible from incoming
|
||||||
// network traffic; each subscription also has an individual throttle to preclude
|
// network traffic; each subscription also has an individual throttle to preclude
|
||||||
@ -380,7 +380,7 @@ func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) {
|
|||||||
select {
|
select {
|
||||||
case p.validateThrottle <- struct{}{}:
|
case p.validateThrottle <- struct{}{}:
|
||||||
go func() {
|
go func() {
|
||||||
p.validate(subs, src, msg)
|
p.validate(vals, src, msg)
|
||||||
<-p.validateThrottle
|
<-p.validateThrottle
|
||||||
}()
|
}()
|
||||||
default:
|
default:
|
||||||
@ -393,31 +393,27 @@ func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// validate performs validation and only sends the message if all validators succeed
|
// validate performs validation and only sends the message if all validators succeed
|
||||||
func (p *PubSub) validate(subs []*Subscription, src peer.ID, msg *Message) {
|
func (p *PubSub) validate(vals []*topicVal, src peer.ID, msg *Message) {
|
||||||
ctx, cancel := context.WithCancel(p.ctx)
|
ctx, cancel := context.WithCancel(p.ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
rch := make(chan bool, len(subs))
|
rch := make(chan bool, len(vals))
|
||||||
rcount := 0
|
rcount := 0
|
||||||
throttle := false
|
throttle := false
|
||||||
|
|
||||||
loop:
|
loop:
|
||||||
for _, sub := range subs {
|
for _, val := range vals {
|
||||||
if sub.validate == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
rcount++
|
rcount++
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case sub.validateThrottle <- struct{}{}:
|
case val.validateThrottle <- struct{}{}:
|
||||||
go func(sub *Subscription) {
|
go func(val *topicVal) {
|
||||||
rch <- sub.validateMsg(ctx, msg)
|
rch <- val.validateMsg(ctx, msg)
|
||||||
<-sub.validateThrottle
|
<-val.validateThrottle
|
||||||
}(sub)
|
}(val)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
log.Debugf("validation throttled for topic %s", sub.topic)
|
log.Debugf("validation throttled for topic %s", val.topic)
|
||||||
throttle = true
|
throttle = true
|
||||||
break loop
|
break loop
|
||||||
}
|
}
|
||||||
@ -494,22 +490,20 @@ func (p *PubSub) publishMessage(from peer.ID, msg *pb.Message) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getSubscriptions returns all subscriptions the would receive the given message.
|
// getValidators returns all validators that apply to a given message
|
||||||
func (p *PubSub) getSubscriptions(msg *Message) []*Subscription {
|
func (p *PubSub) getValidators(msg *Message) []*topicVal {
|
||||||
var subs []*Subscription
|
var vals []*topicVal
|
||||||
|
|
||||||
for _, topic := range msg.GetTopicIDs() {
|
for _, topic := range msg.GetTopicIDs() {
|
||||||
tSubs, ok := p.myTopics[topic]
|
val, ok := p.topicVals[topic]
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for sub := range tSubs {
|
vals = append(vals, val)
|
||||||
subs = append(subs, sub)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return subs
|
return vals
|
||||||
}
|
}
|
||||||
|
|
||||||
type addSubReq struct {
|
type addSubReq struct {
|
||||||
@ -517,31 +511,7 @@ type addSubReq struct {
|
|||||||
resp chan *Subscription
|
resp chan *Subscription
|
||||||
}
|
}
|
||||||
|
|
||||||
type SubOpt func(*Subscription) error
|
type SubOpt func(sub *Subscription) error
|
||||||
type Validator func(context.Context, *Message) bool
|
|
||||||
|
|
||||||
// WithValidator is an option that can be supplied to Subscribe. The argument is a function that returns whether or not a given message should be propagated further.
|
|
||||||
func WithValidator(validate Validator) SubOpt {
|
|
||||||
return func(sub *Subscription) error {
|
|
||||||
sub.validate = validate
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithValidatorTimeout is an option that can be supplied to Subscribe. The argument is a duration after which long-running validators are canceled.
|
|
||||||
func WithValidatorTimeout(timeout time.Duration) SubOpt {
|
|
||||||
return func(sub *Subscription) error {
|
|
||||||
sub.validateTimeout = timeout
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithValidatorConcurrency(n int) SubOpt {
|
|
||||||
return func(sub *Subscription) error {
|
|
||||||
sub.validateThrottle = make(chan struct{}, n)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Subscribe returns a new Subscription for the given topic
|
// Subscribe returns a new Subscription for the given topic
|
||||||
func (p *PubSub) Subscribe(topic string, opts ...SubOpt) (*Subscription, error) {
|
func (p *PubSub) Subscribe(topic string, opts ...SubOpt) (*Subscription, error) {
|
||||||
@ -561,8 +531,7 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubO
|
|||||||
}
|
}
|
||||||
|
|
||||||
sub := &Subscription{
|
sub := &Subscription{
|
||||||
topic: td.GetName(),
|
topic: td.GetName(),
|
||||||
validateTimeout: defaultValidateTimeout,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
@ -572,10 +541,6 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubO
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sub.validate != nil && sub.validateThrottle == nil {
|
|
||||||
sub.validateThrottle = make(chan struct{}, defaultValidateConcurrency)
|
|
||||||
}
|
|
||||||
|
|
||||||
out := make(chan *Subscription, 1)
|
out := make(chan *Subscription, 1)
|
||||||
p.addSub <- &addSubReq{
|
p.addSub <- &addSubReq{
|
||||||
sub: sub,
|
sub: sub,
|
||||||
@ -633,3 +598,100 @@ func (p *PubSub) ListPeers(topic string) []peer.ID {
|
|||||||
}
|
}
|
||||||
return <-out
|
return <-out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// per topic validators
|
||||||
|
type addValReq struct {
|
||||||
|
topic string
|
||||||
|
validate Validator
|
||||||
|
timeout time.Duration
|
||||||
|
throttle int
|
||||||
|
resp chan error
|
||||||
|
}
|
||||||
|
|
||||||
|
type topicVal struct {
|
||||||
|
topic string
|
||||||
|
validate Validator
|
||||||
|
validateTimeout time.Duration
|
||||||
|
validateThrottle chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validator is a function that validates a message
|
||||||
|
type Validator func(context.Context, *Message) bool
|
||||||
|
|
||||||
|
// ValidatorOpt is an option for RegisterTopicValidator
|
||||||
|
type ValidatorOpt func(addVal *addValReq) error
|
||||||
|
|
||||||
|
// WithValidatorTimeout is an option that sets the topic validator timeout
|
||||||
|
func WithValidatorTimeout(timeout time.Duration) ValidatorOpt {
|
||||||
|
return func(addVal *addValReq) error {
|
||||||
|
addVal.timeout = timeout
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithValidatorConcurrency is an option that sets topic validator throttle
|
||||||
|
func WithValidatorConcurrency(n int) ValidatorOpt {
|
||||||
|
return func(addVal *addValReq) error {
|
||||||
|
addVal.throttle = n
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterTopicValidator registers a validator for topic
|
||||||
|
func (p *PubSub) RegisterTopicValidator(topic string, val Validator, opts ...ValidatorOpt) error {
|
||||||
|
addVal := &addValReq{
|
||||||
|
topic: topic,
|
||||||
|
validate: val,
|
||||||
|
resp: make(chan error, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
err := opt(addVal)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.addVal <- addVal
|
||||||
|
return <-addVal.resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *PubSub) addValidator(req *addValReq) {
|
||||||
|
topic := req.topic
|
||||||
|
|
||||||
|
_, ok := ps.topicVals[topic]
|
||||||
|
if ok {
|
||||||
|
req.resp <- fmt.Errorf("Duplicate validator for topic %s", topic)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
val := &topicVal{
|
||||||
|
topic: topic,
|
||||||
|
validate: req.validate,
|
||||||
|
validateTimeout: defaultValidateTimeout,
|
||||||
|
validateThrottle: make(chan struct{}, defaultValidateConcurrency),
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.timeout > 0 {
|
||||||
|
val.validateTimeout = req.timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.throttle > 0 {
|
||||||
|
val.validateThrottle = make(chan struct{}, req.throttle)
|
||||||
|
}
|
||||||
|
|
||||||
|
ps.topicVals[topic] = val
|
||||||
|
req.resp <- nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (val *topicVal) validateMsg(ctx context.Context, msg *Message) bool {
|
||||||
|
vctx, cancel := context.WithTimeout(ctx, val.validateTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
valid := val.validate(vctx, msg)
|
||||||
|
if !valid {
|
||||||
|
log.Debugf("validation failed for topic %s", val.topic)
|
||||||
|
}
|
||||||
|
|
||||||
|
return valid
|
||||||
|
}
|
||||||
|
|||||||
@ -351,9 +351,14 @@ func TestValidate(t *testing.T) {
|
|||||||
connect(t, hosts[0], hosts[1])
|
connect(t, hosts[0], hosts[1])
|
||||||
topic := "foobar"
|
topic := "foobar"
|
||||||
|
|
||||||
sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool {
|
err := psubs[1].RegisterTopicValidator(topic, func(ctx context.Context, msg *Message) bool {
|
||||||
return !bytes.Contains(msg.Data, []byte("illegal"))
|
return !bytes.Contains(msg.Data, []byte("illegal"))
|
||||||
}))
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := psubs[1].Subscribe(topic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -442,17 +447,22 @@ func TestValidateOverload(t *testing.T) {
|
|||||||
|
|
||||||
block := make(chan struct{})
|
block := make(chan struct{})
|
||||||
|
|
||||||
sub, err := psubs[1].Subscribe(topic,
|
err := psubs[1].RegisterTopicValidator(topic,
|
||||||
WithValidatorConcurrency(tc.maxConcurrency),
|
func(ctx context.Context, msg *Message) bool {
|
||||||
WithValidator(func(ctx context.Context, msg *Message) bool {
|
|
||||||
<-block
|
<-block
|
||||||
return true
|
return true
|
||||||
}))
|
},
|
||||||
|
WithValidatorConcurrency(tc.maxConcurrency))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sub, err := psubs[1].Subscribe(topic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
time.Sleep(time.Millisecond * 50)
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
|
||||||
if len(tc.msgs) != tc.maxConcurrency+1 {
|
if len(tc.msgs) != tc.maxConcurrency+1 {
|
||||||
|
|||||||
@ -2,7 +2,6 @@ package floodsub
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Subscription struct {
|
type Subscription struct {
|
||||||
@ -10,10 +9,6 @@ type Subscription struct {
|
|||||||
ch chan *Message
|
ch chan *Message
|
||||||
cancelCh chan<- *Subscription
|
cancelCh chan<- *Subscription
|
||||||
err error
|
err error
|
||||||
|
|
||||||
validate Validator
|
|
||||||
validateTimeout time.Duration
|
|
||||||
validateThrottle chan struct{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sub *Subscription) Topic() string {
|
func (sub *Subscription) Topic() string {
|
||||||
@ -36,15 +31,3 @@ func (sub *Subscription) Next(ctx context.Context) (*Message, error) {
|
|||||||
func (sub *Subscription) Cancel() {
|
func (sub *Subscription) Cancel() {
|
||||||
sub.cancelCh <- sub
|
sub.cancelCh <- sub
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sub *Subscription) validateMsg(ctx context.Context, msg *Message) bool {
|
|
||||||
vctx, cancel := context.WithTimeout(ctx, sub.validateTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
valid := sub.validate(vctx, msg)
|
|
||||||
if !valid {
|
|
||||||
log.Debugf("validation failed for topic %s", sub.topic)
|
|
||||||
}
|
|
||||||
|
|
||||||
return valid
|
|
||||||
}
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user