make maximum concurrency configurable, split loop

This commit is contained in:
keks 2017-12-16 13:12:23 +01:00 committed by vyzo
parent fe09d1eea3
commit 88274db0bb
2 changed files with 129 additions and 71 deletions

View File

@ -19,7 +19,7 @@ import (
const (
ID = protocol.ID("/floodsub/1.0.0")
maxConcurrency = 10
defaultMaxConcurrency = 10
defaultValidateTimeout = 150 * time.Millisecond
)
@ -88,8 +88,17 @@ type RPC struct {
from peer.ID
}
type Option func(*PubSub) error
func WithMaxConcurrency(n int) Option {
return func(ps *PubSub) error {
ps.throttleValidate = make(chan struct{}, n)
return nil
}
}
// NewFloodSub returns a new FloodSub management object
func NewFloodSub(ctx context.Context, h host.Host) *PubSub {
func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, error) {
ps := &PubSub{
host: h,
ctx: ctx,
@ -110,12 +119,19 @@ func NewFloodSub(ctx context.Context, h host.Host) *PubSub {
throttleValidate: make(chan struct{}, maxConcurrency),
}
for _, opt := range opts {
err := opt(ps)
if err != nil {
return nil, err
}
}
h.SetStreamHandler(ID, ps.handleNewStream)
h.Network().Notify((*PubSubNotif)(ps))
go ps.processLoop(ctx)
return ps
return ps, nil
}
// processLoop handles all inputs arriving on the channels
@ -372,14 +388,25 @@ func msgID(pmsg *pb.Message) string {
// validate is called in a goroutine and calls the validate functions of all subs with msg as parameter.
func (p *PubSub) validate(subs []*Subscription, msg *Message) bool {
for _, sub := range subs {
results := make([]chan bool, len(subs))
ctxs := make([]context.Context, len(subs))
for i, sub := range subs {
result := make(chan bool)
ctx, cancel := context.WithTimeout(p.ctx, sub.validateTimeout)
defer cancel()
result := make(chan bool)
ctxs[i] = ctx
results[i] = result
go func(sub *Subscription) {
result <- sub.validate == nil || sub.validate(ctx, msg)
}(sub)
}
for i, sub := range subs {
ctx := ctxs[i]
result := results[i]
select {
case valid := <-result:

View File

@ -81,10 +81,14 @@ func connectAll(t *testing.T, hosts []host.Host) {
}
}
func getPubsubs(ctx context.Context, hs []host.Host) []*PubSub {
func getPubsubs(ctx context.Context, hs []host.Host, opts ...Option) []*PubSub {
var psubs []*PubSub
for _, h := range hs {
psubs = append(psubs, NewFloodSub(ctx, h))
ps, err := NewFloodSub(ctx, h, opts...)
if err != nil {
panic(err)
}
psubs = append(psubs, ps)
}
return psubs
}
@ -290,11 +294,14 @@ func TestSelfReceive(t *testing.T) {
host := getNetHosts(t, ctx, 1)[0]
psub := NewFloodSub(ctx, host)
psub, err := NewFloodSub(ctx, host)
if err != nil {
t.Fatal(err)
}
msg := []byte("hello world")
err := psub.Publish("foobar", msg)
err = psub.Publish("foobar", msg)
if err != nil {
t.Fatal(err)
}
@ -487,82 +494,103 @@ func TestValidateOverload(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hosts := getNetHosts(t, ctx, 2)
psubs := getPubsubs(ctx, hosts)
connect(t, hosts[0], hosts[1])
topic := "foobar"
block := make(chan struct{})
sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool {
<-block
return true
}))
if err != nil {
t.Fatal(err)
}
time.Sleep(time.Millisecond * 50)
msgs := []struct {
type msg struct {
msg []byte
validates bool
}
tcs := []struct {
msgs []msg
maxConcurrency int
}{
{msg: []byte("this is a legal message"), validates: true},
{msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true},
{msg: []byte("there also is nothing controversial about this message"), validates: true},
{msg: []byte("also fine"), validates: true},
{msg: []byte("still, all good"), validates: true},
{msg: []byte("this is getting boring"), validates: true},
{msg: []byte("foo"), validates: true},
{msg: []byte("foobar"), validates: true},
{msg: []byte("foofoo"), validates: true},
{msg: []byte("barfoo"), validates: true},
{msg: []byte("barbar"), validates: false},
{
maxConcurrency: 10,
msgs: []msg{
{msg: []byte("this is a legal message"), validates: true},
{msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true},
{msg: []byte("there also is nothing controversial about this message"), validates: true},
{msg: []byte("also fine"), validates: true},
{msg: []byte("still, all good"), validates: true},
{msg: []byte("this is getting boring"), validates: true},
{msg: []byte("foo"), validates: true},
{msg: []byte("foobar"), validates: true},
{msg: []byte("foofoo"), validates: true},
{msg: []byte("barfoo"), validates: true},
{msg: []byte("oh no!"), validates: false},
},
},
{
maxConcurrency: 2,
msgs: []msg{
{msg: []byte("this is a legal message"), validates: true},
{msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true},
{msg: []byte("oh no!"), validates: false},
},
},
}
if len(msgs) != maxConcurrency+1 {
t.Fatalf("expected number of messages sent to be maxConcurrency+1. Got %d, expected %d", len(msgs), maxConcurrency+1)
}
for _, tc := range tcs {
p := psubs[0]
hosts := getNetHosts(t, ctx, 2)
psubs := getPubsubs(ctx, hosts, WithMaxConcurrency(tc.maxConcurrency))
var wg sync.WaitGroup
wg.Add(1)
go func() {
for _, tc := range msgs {
select {
case msg := <-sub.ch:
if !tc.validates {
t.Log(msg)
t.Error("expected message validation to drop the message because all validator goroutines are taken")
}
case <-time.After(333 * time.Millisecond):
if tc.validates {
t.Error("expected message validation to accept the message")
}
}
}
wg.Done()
}()
connect(t, hosts[0], hosts[1])
topic := "foobar"
for i, tc := range msgs {
err := p.Publish(topic, tc.msg)
block := make(chan struct{})
sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool {
<-block
return true
}))
if err != nil {
t.Fatal(err)
}
// wait a bit to let pubsub's internal state machine start validating the message
time.Sleep(10 * time.Millisecond)
time.Sleep(time.Millisecond * 50)
// unblock validator goroutines after we sent one too many
if i == len(msgs)-1 {
close(block)
if len(tc.msgs) != tc.maxConcurrency+1 {
t.Fatalf("expected number of messages sent to be defaultMaxConcurrency+1. Got %d, expected %d", len(tc.msgs), tc.maxConcurrency+1)
}
}
wg.Wait()
p := psubs[0]
var wg sync.WaitGroup
wg.Add(1)
go func() {
for _, tmsg := range tc.msgs {
select {
case msg := <-sub.ch:
if !tmsg.validates {
t.Log(msg)
t.Error("expected message validation to drop the message because all validator goroutines are taken")
}
case <-time.After(333 * time.Millisecond):
if tmsg.validates {
t.Error("expected message validation to accept the message")
}
}
}
wg.Done()
}()
for i, tmsg := range tc.msgs {
err := p.Publish(topic, tmsg.msg)
if err != nil {
t.Fatal(err)
}
// wait a bit to let pubsub's internal state machine start validating the message
time.Sleep(10 * time.Millisecond)
// unblock validator goroutines after we sent one too many
if i == len(tc.msgs)-1 {
close(block)
}
}
wg.Wait()
}
}
func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) {
@ -646,7 +674,10 @@ func TestSubReporting(t *testing.T) {
defer cancel()
host := getNetHosts(t, ctx, 1)[0]
psub := NewFloodSub(ctx, host)
psub, err := NewFloodSub(ctx, host)
if err != nil {
t.Fatal(err)
}
fooSub, err := psub.Subscribe("foo")
if err != nil {