add network size argument to randomsub

so that sqrt propagation math actually works
This commit is contained in:
vyzo 2020-04-23 20:05:23 +03:00
parent d25506dd2b
commit c353348592
2 changed files with 16 additions and 11 deletions

View File

@ -18,19 +18,21 @@ var (
) )
// NewRandomSub returns a new PubSub object using RandomSubRouter as the router. // NewRandomSub returns a new PubSub object using RandomSubRouter as the router.
func NewRandomSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, error) { func NewRandomSub(ctx context.Context, h host.Host, size int, opts ...Option) (*PubSub, error) {
rt := &RandomSubRouter{ rt := &RandomSubRouter{
size: size,
peers: make(map[peer.ID]protocol.ID), peers: make(map[peer.ID]protocol.ID),
} }
return NewPubSub(ctx, h, rt, opts...) return NewPubSub(ctx, h, rt, opts...)
} }
// RandomSubRouter is a router that implements a random propagation strategy. // RandomSubRouter is a router that implements a random propagation strategy.
// For each message, it selects the square root of peers, with a min of RandomSubD, // For each message, it selects the square root of the network size peers, with a min of RandomSubD,
// and forwards the message to them. // and forwards the message to them.
type RandomSubRouter struct { type RandomSubRouter struct {
p *PubSub p *PubSub
peers map[peer.ID]protocol.ID peers map[peer.ID]protocol.ID
size int
tracer *pubsubTracer tracer *pubsubTracer
} }
@ -122,10 +124,13 @@ func (rs *RandomSubRouter) Publish(msg *Message) {
if len(rspeers) > RandomSubD { if len(rspeers) > RandomSubD {
target := RandomSubD target := RandomSubD
sqrt := int(math.Ceil(math.Sqrt(float64(len(rspeers))))) sqrt := int(math.Ceil(math.Sqrt(float64(rs.size))))
if sqrt > target { if sqrt > target {
target = sqrt target = sqrt
} }
if target > len(rspeers) {
target = len(rspeers)
}
xpeers := peerMapToList(rspeers) xpeers := peerMapToList(rspeers)
shufflePeers(xpeers) shufflePeers(xpeers)
xpeers = xpeers[:RandomSubD] xpeers = xpeers[:RandomSubD]

View File

@ -9,18 +9,18 @@ import (
"github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/host"
) )
func getRandomsub(ctx context.Context, h host.Host, opts ...Option) *PubSub { func getRandomsub(ctx context.Context, h host.Host, size int, opts ...Option) *PubSub {
ps, err := NewRandomSub(ctx, h, opts...) ps, err := NewRandomSub(ctx, h, size, opts...)
if err != nil { if err != nil {
panic(err) panic(err)
} }
return ps return ps
} }
func getRandomsubs(ctx context.Context, hs []host.Host, opts ...Option) []*PubSub { func getRandomsubs(ctx context.Context, hs []host.Host, size int, opts ...Option) []*PubSub {
var psubs []*PubSub var psubs []*PubSub
for _, h := range hs { for _, h := range hs {
psubs = append(psubs, getRandomsub(ctx, h, opts...)) psubs = append(psubs, getRandomsub(ctx, h, size, opts...))
} }
return psubs return psubs
} }
@ -30,7 +30,7 @@ func TestRandomsubSmall(t *testing.T) {
defer cancel() defer cancel()
hosts := getNetHosts(t, ctx, 10) hosts := getNetHosts(t, ctx, 10)
psubs := getRandomsubs(ctx, hosts) psubs := getRandomsubs(ctx, hosts, 10)
connectAll(t, hosts) connectAll(t, hosts)
@ -60,7 +60,7 @@ func TestRandomsubBig(t *testing.T) {
defer cancel() defer cancel()
hosts := getNetHosts(t, ctx, 50) hosts := getNetHosts(t, ctx, 50)
psubs := getRandomsubs(ctx, hosts) psubs := getRandomsubs(ctx, hosts, 50)
connectSome(t, hosts, 12) connectSome(t, hosts, 12)
@ -91,7 +91,7 @@ func TestRandomsubMixed(t *testing.T) {
hosts := getNetHosts(t, ctx, 40) hosts := getNetHosts(t, ctx, 40)
fsubs := getPubsubs(ctx, hosts[:10]) fsubs := getPubsubs(ctx, hosts[:10])
rsubs := getRandomsubs(ctx, hosts[10:]) rsubs := getRandomsubs(ctx, hosts[10:], 30)
psubs := append(fsubs, rsubs...) psubs := append(fsubs, rsubs...)
connectSome(t, hosts, 12) connectSome(t, hosts, 12)
@ -123,7 +123,7 @@ func TestRandomsubEnoughPeers(t *testing.T) {
hosts := getNetHosts(t, ctx, 40) hosts := getNetHosts(t, ctx, 40)
fsubs := getPubsubs(ctx, hosts[:10]) fsubs := getPubsubs(ctx, hosts[:10])
rsubs := getRandomsubs(ctx, hosts[10:]) rsubs := getRandomsubs(ctx, hosts[10:], 30)
psubs := append(fsubs, rsubs...) psubs := append(fsubs, rsubs...)
connectSome(t, hosts, 12) connectSome(t, hosts, 12)