From 88274db0bba733a0e8067d7e4933e6a4b1fbfe53 Mon Sep 17 00:00:00 2001 From: keks Date: Sat, 16 Dec 2017 13:12:23 +0100 Subject: [PATCH] make maximum concurrency configurable, split loop --- floodsub.go | 37 +++++++++-- floodsub_test.go | 163 ++++++++++++++++++++++++++++------------------- 2 files changed, 129 insertions(+), 71 deletions(-) diff --git a/floodsub.go b/floodsub.go index f7c2eb0..70a207c 100644 --- a/floodsub.go +++ b/floodsub.go @@ -19,7 +19,7 @@ import ( const ( ID = protocol.ID("/floodsub/1.0.0") - maxConcurrency = 10 + defaultMaxConcurrency = 10 defaultValidateTimeout = 150 * time.Millisecond ) @@ -88,8 +88,17 @@ type RPC struct { from peer.ID } +type Option func(*PubSub) error + +func WithMaxConcurrency(n int) Option { + return func(ps *PubSub) error { + ps.throttleValidate = make(chan struct{}, n) + return nil + } +} + // NewFloodSub returns a new FloodSub management object -func NewFloodSub(ctx context.Context, h host.Host) *PubSub { +func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, error) { ps := &PubSub{ host: h, ctx: ctx, @@ -110,12 +119,19 @@ func NewFloodSub(ctx context.Context, h host.Host) *PubSub { throttleValidate: make(chan struct{}, maxConcurrency), } + for _, opt := range opts { + err := opt(ps) + if err != nil { + return nil, err + } + } + h.SetStreamHandler(ID, ps.handleNewStream) h.Network().Notify((*PubSubNotif)(ps)) go ps.processLoop(ctx) - return ps + return ps, nil } // processLoop handles all inputs arriving on the channels @@ -372,14 +388,25 @@ func msgID(pmsg *pb.Message) string { // validate is called in a goroutine and calls the validate functions of all subs with msg as parameter. func (p *PubSub) validate(subs []*Subscription, msg *Message) bool { - for _, sub := range subs { + results := make([]chan bool, len(subs)) + ctxs := make([]context.Context, len(subs)) + + for i, sub := range subs { + result := make(chan bool) ctx, cancel := context.WithTimeout(p.ctx, sub.validateTimeout) defer cancel() - result := make(chan bool) + ctxs[i] = ctx + results[i] = result + go func(sub *Subscription) { result <- sub.validate == nil || sub.validate(ctx, msg) }(sub) + } + + for i, sub := range subs { + ctx := ctxs[i] + result := results[i] select { case valid := <-result: diff --git a/floodsub_test.go b/floodsub_test.go index 35a2400..68357ab 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -81,10 +81,14 @@ func connectAll(t *testing.T, hosts []host.Host) { } } -func getPubsubs(ctx context.Context, hs []host.Host) []*PubSub { +func getPubsubs(ctx context.Context, hs []host.Host, opts ...Option) []*PubSub { var psubs []*PubSub for _, h := range hs { - psubs = append(psubs, NewFloodSub(ctx, h)) + ps, err := NewFloodSub(ctx, h, opts...) + if err != nil { + panic(err) + } + psubs = append(psubs, ps) } return psubs } @@ -290,11 +294,14 @@ func TestSelfReceive(t *testing.T) { host := getNetHosts(t, ctx, 1)[0] - psub := NewFloodSub(ctx, host) + psub, err := NewFloodSub(ctx, host) + if err != nil { + t.Fatal(err) + } msg := []byte("hello world") - err := psub.Publish("foobar", msg) + err = psub.Publish("foobar", msg) if err != nil { t.Fatal(err) } @@ -487,82 +494,103 @@ func TestValidateOverload(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hosts := getNetHosts(t, ctx, 2) - psubs := getPubsubs(ctx, hosts) - - connect(t, hosts[0], hosts[1]) - topic := "foobar" - - block := make(chan struct{}) - - sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool { - <-block - return true - })) - if err != nil { - t.Fatal(err) - } - - time.Sleep(time.Millisecond * 50) - - msgs := []struct { + type msg struct { msg []byte validates bool + } + + tcs := []struct { + msgs []msg + + maxConcurrency int }{ - {msg: []byte("this is a legal message"), validates: true}, - {msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true}, - {msg: []byte("there also is nothing controversial about this message"), validates: true}, - {msg: []byte("also fine"), validates: true}, - {msg: []byte("still, all good"), validates: true}, - {msg: []byte("this is getting boring"), validates: true}, - {msg: []byte("foo"), validates: true}, - {msg: []byte("foobar"), validates: true}, - {msg: []byte("foofoo"), validates: true}, - {msg: []byte("barfoo"), validates: true}, - {msg: []byte("barbar"), validates: false}, + { + maxConcurrency: 10, + msgs: []msg{ + {msg: []byte("this is a legal message"), validates: true}, + {msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true}, + {msg: []byte("there also is nothing controversial about this message"), validates: true}, + {msg: []byte("also fine"), validates: true}, + {msg: []byte("still, all good"), validates: true}, + {msg: []byte("this is getting boring"), validates: true}, + {msg: []byte("foo"), validates: true}, + {msg: []byte("foobar"), validates: true}, + {msg: []byte("foofoo"), validates: true}, + {msg: []byte("barfoo"), validates: true}, + {msg: []byte("oh no!"), validates: false}, + }, + }, + { + maxConcurrency: 2, + msgs: []msg{ + {msg: []byte("this is a legal message"), validates: true}, + {msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true}, + {msg: []byte("oh no!"), validates: false}, + }, + }, } - if len(msgs) != maxConcurrency+1 { - t.Fatalf("expected number of messages sent to be maxConcurrency+1. Got %d, expected %d", len(msgs), maxConcurrency+1) - } + for _, tc := range tcs { - p := psubs[0] + hosts := getNetHosts(t, ctx, 2) + psubs := getPubsubs(ctx, hosts, WithMaxConcurrency(tc.maxConcurrency)) - var wg sync.WaitGroup - wg.Add(1) - go func() { - for _, tc := range msgs { - select { - case msg := <-sub.ch: - if !tc.validates { - t.Log(msg) - t.Error("expected message validation to drop the message because all validator goroutines are taken") - } - case <-time.After(333 * time.Millisecond): - if tc.validates { - t.Error("expected message validation to accept the message") - } - } - } - wg.Done() - }() + connect(t, hosts[0], hosts[1]) + topic := "foobar" - for i, tc := range msgs { - err := p.Publish(topic, tc.msg) + block := make(chan struct{}) + + sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool { + <-block + return true + })) if err != nil { t.Fatal(err) } - // wait a bit to let pubsub's internal state machine start validating the message - time.Sleep(10 * time.Millisecond) + time.Sleep(time.Millisecond * 50) - // unblock validator goroutines after we sent one too many - if i == len(msgs)-1 { - close(block) + if len(tc.msgs) != tc.maxConcurrency+1 { + t.Fatalf("expected number of messages sent to be defaultMaxConcurrency+1. Got %d, expected %d", len(tc.msgs), tc.maxConcurrency+1) } - } - wg.Wait() + p := psubs[0] + + var wg sync.WaitGroup + wg.Add(1) + go func() { + for _, tmsg := range tc.msgs { + select { + case msg := <-sub.ch: + if !tmsg.validates { + t.Log(msg) + t.Error("expected message validation to drop the message because all validator goroutines are taken") + } + case <-time.After(333 * time.Millisecond): + if tmsg.validates { + t.Error("expected message validation to accept the message") + } + } + } + wg.Done() + }() + + for i, tmsg := range tc.msgs { + err := p.Publish(topic, tmsg.msg) + if err != nil { + t.Fatal(err) + } + + // wait a bit to let pubsub's internal state machine start validating the message + time.Sleep(10 * time.Millisecond) + + // unblock validator goroutines after we sent one too many + if i == len(tc.msgs)-1 { + close(block) + } + } + wg.Wait() + } } func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) { @@ -646,7 +674,10 @@ func TestSubReporting(t *testing.T) { defer cancel() host := getNetHosts(t, ctx, 1)[0] - psub := NewFloodSub(ctx, host) + psub, err := NewFloodSub(ctx, host) + if err != nil { + t.Fatal(err) + } fooSub, err := psub.Subscribe("foo") if err != nil {