mirror of
https://github.com/logos-messaging/go-libp2p-pubsub.git
synced 2026-01-11 09:13:12 +00:00
per subscription validation throttle and more efficient dispatch logic
This commit is contained in:
parent
d2f6a0050f
commit
982c4de960
197
floodsub.go
197
floodsub.go
@ -61,9 +61,6 @@ type PubSub struct {
|
||||
// sendMsg handles messages that have been validated
|
||||
sendMsg chan sendReq
|
||||
|
||||
// throttleValidate bounds the number of goroutines concurrently validating messages
|
||||
throttleValidate chan struct{}
|
||||
|
||||
peers map[peer.ID]chan *RPC
|
||||
seenMessages *timecache.TimeCache
|
||||
|
||||
@ -90,33 +87,25 @@ type RPC struct {
|
||||
|
||||
type Option func(*PubSub) error
|
||||
|
||||
func WithMaxConcurrency(n int) Option {
|
||||
return func(ps *PubSub) error {
|
||||
ps.throttleValidate = make(chan struct{}, n)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewFloodSub returns a new FloodSub management object
|
||||
func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, error) {
|
||||
ps := &PubSub{
|
||||
host: h,
|
||||
ctx: ctx,
|
||||
incoming: make(chan *RPC, 32),
|
||||
publish: make(chan *Message),
|
||||
newPeers: make(chan inet.Stream),
|
||||
peerDead: make(chan peer.ID),
|
||||
cancelCh: make(chan *Subscription),
|
||||
getPeers: make(chan *listPeerReq),
|
||||
addSub: make(chan *addSubReq),
|
||||
getTopics: make(chan *topicReq),
|
||||
sendMsg: make(chan sendReq),
|
||||
myTopics: make(map[string]map[*Subscription]struct{}),
|
||||
topics: make(map[string]map[peer.ID]struct{}),
|
||||
peers: make(map[peer.ID]chan *RPC),
|
||||
seenMessages: timecache.NewTimeCache(time.Second * 30),
|
||||
counter: uint64(time.Now().UnixNano()),
|
||||
throttleValidate: make(chan struct{}, defaultMaxConcurrency),
|
||||
host: h,
|
||||
ctx: ctx,
|
||||
incoming: make(chan *RPC, 32),
|
||||
publish: make(chan *Message),
|
||||
newPeers: make(chan inet.Stream),
|
||||
peerDead: make(chan peer.ID),
|
||||
cancelCh: make(chan *Subscription),
|
||||
getPeers: make(chan *listPeerReq),
|
||||
addSub: make(chan *addSubReq),
|
||||
getTopics: make(chan *topicReq),
|
||||
sendMsg: make(chan sendReq),
|
||||
myTopics: make(map[string]map[*Subscription]struct{}),
|
||||
topics: make(map[string]map[peer.ID]struct{}),
|
||||
peers: make(map[peer.ID]chan *RPC),
|
||||
seenMessages: timecache.NewTimeCache(time.Second * 30),
|
||||
counter: uint64(time.Now().UnixNano()),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
@ -204,24 +193,9 @@ func (p *PubSub) processLoop(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
case msg := <-p.publish:
|
||||
subs := p.getSubscriptions(msg) // call before goroutine!
|
||||
subs := p.getSubscriptions(msg)
|
||||
p.pushMsg(subs, p.host.ID(), msg)
|
||||
|
||||
select {
|
||||
case p.throttleValidate <- struct{}{}:
|
||||
go func(msg *Message) {
|
||||
defer func() { <-p.throttleValidate }()
|
||||
|
||||
if p.validate(subs, msg) {
|
||||
p.sendMsg <- sendReq{
|
||||
from: p.host.ID(),
|
||||
msg: msg,
|
||||
}
|
||||
|
||||
}
|
||||
}(msg)
|
||||
default:
|
||||
log.Warning("could not acquire validator; dropping message")
|
||||
}
|
||||
case req := <-p.sendMsg:
|
||||
p.maybePublishMessage(req.from, req.msg.Message)
|
||||
|
||||
@ -360,24 +334,11 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error {
|
||||
continue
|
||||
}
|
||||
|
||||
subs := p.getSubscriptions(&Message{pmsg}) // call before goroutine!
|
||||
|
||||
select {
|
||||
case p.throttleValidate <- struct{}{}:
|
||||
go func(pmsg *pb.Message) {
|
||||
defer func() { <-p.throttleValidate }()
|
||||
|
||||
if p.validate(subs, &Message{pmsg}) {
|
||||
p.sendMsg <- sendReq{
|
||||
from: rpc.from,
|
||||
msg: &Message{pmsg},
|
||||
}
|
||||
}
|
||||
}(pmsg)
|
||||
default:
|
||||
log.Warning("could not acquire validator; dropping message")
|
||||
}
|
||||
msg := &Message{pmsg}
|
||||
subs := p.getSubscriptions(msg)
|
||||
p.pushMsg(subs, rpc.from, msg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -386,41 +347,80 @@ func msgID(pmsg *pb.Message) string {
|
||||
return string(pmsg.GetFrom()) + string(pmsg.GetSeqno())
|
||||
}
|
||||
|
||||
// validate is called in a goroutine and calls the validate functions of all subs with msg as parameter.
|
||||
func (p *PubSub) validate(subs []*Subscription, msg *Message) bool {
|
||||
results := make([]chan bool, len(subs))
|
||||
ctxs := make([]context.Context, len(subs))
|
||||
|
||||
for i, sub := range subs {
|
||||
result := make(chan bool)
|
||||
ctx, cancel := context.WithTimeout(p.ctx, sub.validateTimeout)
|
||||
defer cancel()
|
||||
|
||||
ctxs[i] = ctx
|
||||
results[i] = result
|
||||
|
||||
go func(sub *Subscription) {
|
||||
result <- sub.validate == nil || sub.validate(ctx, msg)
|
||||
}(sub)
|
||||
}
|
||||
|
||||
for i, sub := range subs {
|
||||
ctx := ctxs[i]
|
||||
result := results[i]
|
||||
|
||||
select {
|
||||
case valid := <-result:
|
||||
if !valid {
|
||||
log.Debugf("validator for topic %s returned false", sub.topic)
|
||||
return false
|
||||
}
|
||||
case <-ctx.Done():
|
||||
log.Debugf("validator for topic %s timed out. msg: %s", sub.topic, msg)
|
||||
return false
|
||||
// pushMsg pushes a message to a number of subscriptions, performing validation
|
||||
// as necessary
|
||||
func (p *PubSub) pushMsg(subs []*Subscription, src peer.ID, msg *Message) {
|
||||
// we perform validation if _any_ of the subscriptions has a validator
|
||||
// because the message is sent once for all topics
|
||||
needval := false
|
||||
for _, sub := range subs {
|
||||
if sub.validate != nil {
|
||||
needval = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
if !needval {
|
||||
go func() {
|
||||
p.sendMsg <- sendReq{
|
||||
from: src,
|
||||
msg: msg,
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
// validation is asynchronous
|
||||
// XXX vyzo: do we want a global validation throttle here?
|
||||
go p.validate(subs, src, msg)
|
||||
}
|
||||
|
||||
// validate performs validation and only sends the message if all validators succeed
|
||||
func (p *PubSub) validate(subs []*Subscription, src peer.ID, msg *Message) {
|
||||
results := make([]chan bool, 0, len(subs))
|
||||
throttle := false
|
||||
|
||||
loop:
|
||||
for _, sub := range subs {
|
||||
if sub.validate == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
rch := make(chan bool, 1)
|
||||
results = append(results, rch)
|
||||
|
||||
select {
|
||||
case sub.validateThrottle <- struct{}{}:
|
||||
go func(sub *Subscription, msg *Message, rch chan bool) {
|
||||
rch <- sub.validateMsg(p.ctx, msg)
|
||||
<-sub.validateThrottle
|
||||
}(sub, msg, rch)
|
||||
|
||||
default:
|
||||
log.Debugf("validation throttled for topic %s", sub.topic)
|
||||
throttle = true
|
||||
break loop
|
||||
}
|
||||
}
|
||||
|
||||
if throttle {
|
||||
log.Warningf("message validation throttled; dropping message from %s", src)
|
||||
return
|
||||
}
|
||||
|
||||
for _, rch := range results {
|
||||
valid := <-rch
|
||||
if !valid {
|
||||
log.Warningf("message validation failed; dropping message from %s", src)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// all validators were successful, send the message
|
||||
p.sendMsg <- sendReq{
|
||||
from: src,
|
||||
msg: msg,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PubSub) maybePublishMessage(from peer.ID, pmsg *pb.Message) {
|
||||
@ -516,6 +516,13 @@ func WithValidatorTimeout(timeout time.Duration) SubOpt {
|
||||
}
|
||||
}
|
||||
|
||||
func WithMaxConcurrency(n int) SubOpt {
|
||||
return func(sub *Subscription) error {
|
||||
sub.validateThrottle = make(chan struct{}, n)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe returns a new Subscription for the given topic
|
||||
func (p *PubSub) Subscribe(topic string, opts ...SubOpt) (*Subscription, error) {
|
||||
td := pb.TopicDescriptor{Name: &topic}
|
||||
@ -545,6 +552,10 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubO
|
||||
}
|
||||
}
|
||||
|
||||
if sub.validate != nil && sub.validateThrottle == nil {
|
||||
sub.validateThrottle = make(chan struct{}, defaultMaxConcurrency)
|
||||
}
|
||||
|
||||
out := make(chan *Subscription, 1)
|
||||
p.addSub <- &addSubReq{
|
||||
sub: sub,
|
||||
|
||||
@ -533,17 +533,20 @@ func TestValidateOverload(t *testing.T) {
|
||||
for _, tc := range tcs {
|
||||
|
||||
hosts := getNetHosts(t, ctx, 2)
|
||||
psubs := getPubsubs(ctx, hosts, WithMaxConcurrency(tc.maxConcurrency))
|
||||
psubs := getPubsubs(ctx, hosts)
|
||||
|
||||
connect(t, hosts[0], hosts[1])
|
||||
topic := "foobar"
|
||||
|
||||
block := make(chan struct{})
|
||||
|
||||
sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool {
|
||||
<-block
|
||||
return true
|
||||
}))
|
||||
sub, err := psubs[1].Subscribe(topic,
|
||||
WithMaxConcurrency(tc.maxConcurrency),
|
||||
WithValidator(func(ctx context.Context, msg *Message) bool {
|
||||
<-block
|
||||
return true
|
||||
}))
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@ -11,8 +11,9 @@ type Subscription struct {
|
||||
cancelCh chan<- *Subscription
|
||||
err error
|
||||
|
||||
validate Validator
|
||||
validateTimeout time.Duration
|
||||
validate Validator
|
||||
validateTimeout time.Duration
|
||||
validateThrottle chan struct{}
|
||||
}
|
||||
|
||||
func (sub *Subscription) Topic() string {
|
||||
@ -35,3 +36,24 @@ func (sub *Subscription) Next(ctx context.Context) (*Message, error) {
|
||||
func (sub *Subscription) Cancel() {
|
||||
sub.cancelCh <- sub
|
||||
}
|
||||
|
||||
func (sub *Subscription) validateMsg(ctx context.Context, msg *Message) bool {
|
||||
result := make(chan bool, 1)
|
||||
vctx, cancel := context.WithTimeout(ctx, sub.validateTimeout)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
result <- sub.validate(vctx, msg)
|
||||
}()
|
||||
|
||||
select {
|
||||
case valid := <-result:
|
||||
if !valid {
|
||||
log.Debugf("validation failed for topic %s", sub.topic)
|
||||
}
|
||||
return valid
|
||||
case <-vctx.Done():
|
||||
log.Debugf("validation timeout for topic %s", sub.topic)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user