From c353348592c5528c0b998b6b17c7aadc9c0d249c Mon Sep 17 00:00:00 2001 From: vyzo Date: Thu, 23 Apr 2020 20:05:23 +0300 Subject: [PATCH] add network size argument to randomsub so that sqrt propagation math actually works --- randomsub.go | 11 ++++++++--- randomsub_test.go | 16 ++++++++-------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/randomsub.go b/randomsub.go index 497f5fe..7726b9e 100644 --- a/randomsub.go +++ b/randomsub.go @@ -18,19 +18,21 @@ var ( ) // NewRandomSub returns a new PubSub object using RandomSubRouter as the router. -func NewRandomSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, error) { +func NewRandomSub(ctx context.Context, h host.Host, size int, opts ...Option) (*PubSub, error) { rt := &RandomSubRouter{ + size: size, peers: make(map[peer.ID]protocol.ID), } return NewPubSub(ctx, h, rt, opts...) } // RandomSubRouter is a router that implements a random propagation strategy. -// For each message, it selects the square root of peers, with a min of RandomSubD, +// For each message, it selects the square root of the network size peers, with a min of RandomSubD, // and forwards the message to them. type RandomSubRouter struct { p *PubSub peers map[peer.ID]protocol.ID + size int tracer *pubsubTracer } @@ -122,10 +124,13 @@ func (rs *RandomSubRouter) Publish(msg *Message) { if len(rspeers) > RandomSubD { target := RandomSubD - sqrt := int(math.Ceil(math.Sqrt(float64(len(rspeers))))) + sqrt := int(math.Ceil(math.Sqrt(float64(rs.size)))) if sqrt > target { target = sqrt } + if target > len(rspeers) { + target = len(rspeers) + } xpeers := peerMapToList(rspeers) shufflePeers(xpeers) xpeers = xpeers[:RandomSubD] diff --git a/randomsub_test.go b/randomsub_test.go index 764bdfa..9873138 100644 --- a/randomsub_test.go +++ b/randomsub_test.go @@ -9,18 +9,18 @@ import ( "github.com/libp2p/go-libp2p-core/host" ) -func getRandomsub(ctx context.Context, h host.Host, opts ...Option) *PubSub { - ps, err := NewRandomSub(ctx, h, opts...) +func getRandomsub(ctx context.Context, h host.Host, size int, opts ...Option) *PubSub { + ps, err := NewRandomSub(ctx, h, size, opts...) if err != nil { panic(err) } return ps } -func getRandomsubs(ctx context.Context, hs []host.Host, opts ...Option) []*PubSub { +func getRandomsubs(ctx context.Context, hs []host.Host, size int, opts ...Option) []*PubSub { var psubs []*PubSub for _, h := range hs { - psubs = append(psubs, getRandomsub(ctx, h, opts...)) + psubs = append(psubs, getRandomsub(ctx, h, size, opts...)) } return psubs } @@ -30,7 +30,7 @@ func TestRandomsubSmall(t *testing.T) { defer cancel() hosts := getNetHosts(t, ctx, 10) - psubs := getRandomsubs(ctx, hosts) + psubs := getRandomsubs(ctx, hosts, 10) connectAll(t, hosts) @@ -60,7 +60,7 @@ func TestRandomsubBig(t *testing.T) { defer cancel() hosts := getNetHosts(t, ctx, 50) - psubs := getRandomsubs(ctx, hosts) + psubs := getRandomsubs(ctx, hosts, 50) connectSome(t, hosts, 12) @@ -91,7 +91,7 @@ func TestRandomsubMixed(t *testing.T) { hosts := getNetHosts(t, ctx, 40) fsubs := getPubsubs(ctx, hosts[:10]) - rsubs := getRandomsubs(ctx, hosts[10:]) + rsubs := getRandomsubs(ctx, hosts[10:], 30) psubs := append(fsubs, rsubs...) connectSome(t, hosts, 12) @@ -123,7 +123,7 @@ func TestRandomsubEnoughPeers(t *testing.T) { hosts := getNetHosts(t, ctx, 40) fsubs := getPubsubs(ctx, hosts[:10]) - rsubs := getRandomsubs(ctx, hosts[10:]) + rsubs := getRandomsubs(ctx, hosts[10:], 30) psubs := append(fsubs, rsubs...) connectSome(t, hosts, 12)