make maximum concurrency configurable, split loop
This commit is contained in:
parent
fe09d1eea3
commit
88274db0bb
37
floodsub.go
37
floodsub.go
|
@ -19,7 +19,7 @@ import (
|
|||
|
||||
const (
|
||||
ID = protocol.ID("/floodsub/1.0.0")
|
||||
maxConcurrency = 10
|
||||
defaultMaxConcurrency = 10
|
||||
defaultValidateTimeout = 150 * time.Millisecond
|
||||
)
|
||||
|
||||
|
@ -88,8 +88,17 @@ type RPC struct {
|
|||
from peer.ID
|
||||
}
|
||||
|
||||
type Option func(*PubSub) error
|
||||
|
||||
func WithMaxConcurrency(n int) Option {
|
||||
return func(ps *PubSub) error {
|
||||
ps.throttleValidate = make(chan struct{}, n)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewFloodSub returns a new FloodSub management object
|
||||
func NewFloodSub(ctx context.Context, h host.Host) *PubSub {
|
||||
func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, error) {
|
||||
ps := &PubSub{
|
||||
host: h,
|
||||
ctx: ctx,
|
||||
|
@ -110,12 +119,19 @@ func NewFloodSub(ctx context.Context, h host.Host) *PubSub {
|
|||
throttleValidate: make(chan struct{}, maxConcurrency),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
err := opt(ps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
h.SetStreamHandler(ID, ps.handleNewStream)
|
||||
h.Network().Notify((*PubSubNotif)(ps))
|
||||
|
||||
go ps.processLoop(ctx)
|
||||
|
||||
return ps
|
||||
return ps, nil
|
||||
}
|
||||
|
||||
// processLoop handles all inputs arriving on the channels
|
||||
|
@ -372,14 +388,25 @@ func msgID(pmsg *pb.Message) string {
|
|||
|
||||
// validate is called in a goroutine and calls the validate functions of all subs with msg as parameter.
|
||||
func (p *PubSub) validate(subs []*Subscription, msg *Message) bool {
|
||||
for _, sub := range subs {
|
||||
results := make([]chan bool, len(subs))
|
||||
ctxs := make([]context.Context, len(subs))
|
||||
|
||||
for i, sub := range subs {
|
||||
result := make(chan bool)
|
||||
ctx, cancel := context.WithTimeout(p.ctx, sub.validateTimeout)
|
||||
defer cancel()
|
||||
|
||||
result := make(chan bool)
|
||||
ctxs[i] = ctx
|
||||
results[i] = result
|
||||
|
||||
go func(sub *Subscription) {
|
||||
result <- sub.validate == nil || sub.validate(ctx, msg)
|
||||
}(sub)
|
||||
}
|
||||
|
||||
for i, sub := range subs {
|
||||
ctx := ctxs[i]
|
||||
result := results[i]
|
||||
|
||||
select {
|
||||
case valid := <-result:
|
||||
|
|
163
floodsub_test.go
163
floodsub_test.go
|
@ -81,10 +81,14 @@ func connectAll(t *testing.T, hosts []host.Host) {
|
|||
}
|
||||
}
|
||||
|
||||
func getPubsubs(ctx context.Context, hs []host.Host) []*PubSub {
|
||||
func getPubsubs(ctx context.Context, hs []host.Host, opts ...Option) []*PubSub {
|
||||
var psubs []*PubSub
|
||||
for _, h := range hs {
|
||||
psubs = append(psubs, NewFloodSub(ctx, h))
|
||||
ps, err := NewFloodSub(ctx, h, opts...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
psubs = append(psubs, ps)
|
||||
}
|
||||
return psubs
|
||||
}
|
||||
|
@ -290,11 +294,14 @@ func TestSelfReceive(t *testing.T) {
|
|||
|
||||
host := getNetHosts(t, ctx, 1)[0]
|
||||
|
||||
psub := NewFloodSub(ctx, host)
|
||||
psub, err := NewFloodSub(ctx, host)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
msg := []byte("hello world")
|
||||
|
||||
err := psub.Publish("foobar", msg)
|
||||
err = psub.Publish("foobar", msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -487,82 +494,103 @@ func TestValidateOverload(t *testing.T) {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
hosts := getNetHosts(t, ctx, 2)
|
||||
psubs := getPubsubs(ctx, hosts)
|
||||
|
||||
connect(t, hosts[0], hosts[1])
|
||||
topic := "foobar"
|
||||
|
||||
block := make(chan struct{})
|
||||
|
||||
sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool {
|
||||
<-block
|
||||
return true
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
|
||||
msgs := []struct {
|
||||
type msg struct {
|
||||
msg []byte
|
||||
validates bool
|
||||
}
|
||||
|
||||
tcs := []struct {
|
||||
msgs []msg
|
||||
|
||||
maxConcurrency int
|
||||
}{
|
||||
{msg: []byte("this is a legal message"), validates: true},
|
||||
{msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true},
|
||||
{msg: []byte("there also is nothing controversial about this message"), validates: true},
|
||||
{msg: []byte("also fine"), validates: true},
|
||||
{msg: []byte("still, all good"), validates: true},
|
||||
{msg: []byte("this is getting boring"), validates: true},
|
||||
{msg: []byte("foo"), validates: true},
|
||||
{msg: []byte("foobar"), validates: true},
|
||||
{msg: []byte("foofoo"), validates: true},
|
||||
{msg: []byte("barfoo"), validates: true},
|
||||
{msg: []byte("barbar"), validates: false},
|
||||
{
|
||||
maxConcurrency: 10,
|
||||
msgs: []msg{
|
||||
{msg: []byte("this is a legal message"), validates: true},
|
||||
{msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true},
|
||||
{msg: []byte("there also is nothing controversial about this message"), validates: true},
|
||||
{msg: []byte("also fine"), validates: true},
|
||||
{msg: []byte("still, all good"), validates: true},
|
||||
{msg: []byte("this is getting boring"), validates: true},
|
||||
{msg: []byte("foo"), validates: true},
|
||||
{msg: []byte("foobar"), validates: true},
|
||||
{msg: []byte("foofoo"), validates: true},
|
||||
{msg: []byte("barfoo"), validates: true},
|
||||
{msg: []byte("oh no!"), validates: false},
|
||||
},
|
||||
},
|
||||
{
|
||||
maxConcurrency: 2,
|
||||
msgs: []msg{
|
||||
{msg: []byte("this is a legal message"), validates: true},
|
||||
{msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true},
|
||||
{msg: []byte("oh no!"), validates: false},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if len(msgs) != maxConcurrency+1 {
|
||||
t.Fatalf("expected number of messages sent to be maxConcurrency+1. Got %d, expected %d", len(msgs), maxConcurrency+1)
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
|
||||
p := psubs[0]
|
||||
hosts := getNetHosts(t, ctx, 2)
|
||||
psubs := getPubsubs(ctx, hosts, WithMaxConcurrency(tc.maxConcurrency))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
for _, tc := range msgs {
|
||||
select {
|
||||
case msg := <-sub.ch:
|
||||
if !tc.validates {
|
||||
t.Log(msg)
|
||||
t.Error("expected message validation to drop the message because all validator goroutines are taken")
|
||||
}
|
||||
case <-time.After(333 * time.Millisecond):
|
||||
if tc.validates {
|
||||
t.Error("expected message validation to accept the message")
|
||||
}
|
||||
}
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
connect(t, hosts[0], hosts[1])
|
||||
topic := "foobar"
|
||||
|
||||
for i, tc := range msgs {
|
||||
err := p.Publish(topic, tc.msg)
|
||||
block := make(chan struct{})
|
||||
|
||||
sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool {
|
||||
<-block
|
||||
return true
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// wait a bit to let pubsub's internal state machine start validating the message
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
|
||||
// unblock validator goroutines after we sent one too many
|
||||
if i == len(msgs)-1 {
|
||||
close(block)
|
||||
if len(tc.msgs) != tc.maxConcurrency+1 {
|
||||
t.Fatalf("expected number of messages sent to be defaultMaxConcurrency+1. Got %d, expected %d", len(tc.msgs), tc.maxConcurrency+1)
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
p := psubs[0]
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
for _, tmsg := range tc.msgs {
|
||||
select {
|
||||
case msg := <-sub.ch:
|
||||
if !tmsg.validates {
|
||||
t.Log(msg)
|
||||
t.Error("expected message validation to drop the message because all validator goroutines are taken")
|
||||
}
|
||||
case <-time.After(333 * time.Millisecond):
|
||||
if tmsg.validates {
|
||||
t.Error("expected message validation to accept the message")
|
||||
}
|
||||
}
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
for i, tmsg := range tc.msgs {
|
||||
err := p.Publish(topic, tmsg.msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// wait a bit to let pubsub's internal state machine start validating the message
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// unblock validator goroutines after we sent one too many
|
||||
if i == len(tc.msgs)-1 {
|
||||
close(block)
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) {
|
||||
|
@ -646,7 +674,10 @@ func TestSubReporting(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
host := getNetHosts(t, ctx, 1)[0]
|
||||
psub := NewFloodSub(ctx, host)
|
||||
psub, err := NewFloodSub(ctx, host)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
fooSub, err := psub.Subscribe("foo")
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in New Issue