diff --git a/floodsub.go b/floodsub.go index 5230ad5..00637e3 100644 --- a/floodsub.go +++ b/floodsub.go @@ -14,18 +14,26 @@ const ( FloodSubID = protocol.ID("/floodsub/1.0.0") ) -// NewFloodSub returns a new PubSub object using the FloodSubRouter -func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, error) { - rt := &FloodSubRouter{} +// New returns a new floodsub-enabled PubSub objecting using the protocols specified in ps +func New(ctx context.Context, h host.Host, ps []protocol.ID, opts ...Option) (*PubSub, error) { + rt := &FloodSubRouter{ + protocols: ps, + } return NewPubSub(ctx, h, rt, opts...) } +// NewFloodSub returns a new PubSub object using the FloodSubRouter +func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, error) { + return New(ctx, h, []protocol.ID{FloodSubID}, opts...) +} + type FloodSubRouter struct { - p *PubSub + p *PubSub + protocols []protocol.ID } func (fs *FloodSubRouter) Protocols() []protocol.ID { - return []protocol.ID{FloodSubID} + return fs.protocols } func (fs *FloodSubRouter) Attach(p *PubSub) { diff --git a/floodsub_test.go b/floodsub_test.go index 28262dd..baa7b40 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -15,6 +15,7 @@ import ( swarmt "github.com/libp2p/go-libp2p-swarm/testing" //bhost "github.com/libp2p/go-libp2p/p2p/host/basic" bhost "github.com/libp2p/go-libp2p-blankhost" + "github.com/libp2p/go-libp2p-protocol" ) func checkMessageRouting(t *testing.T, topic string, pubs []*PubSub, subs []*Subscription) { @@ -609,6 +610,94 @@ func assertHasTopics(t *testing.T, ps *PubSub, exptopics ...string) { } } +func TestFloodSubPluggableProtocol(t *testing.T) { + t.Run("multi-procol router acts like a hub", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getNetHosts(t, ctx, 3) + + psubA := mustCreatePubSub(ctx, t, hosts[0], "/esh/floodsub", "/lsr/floodsub") + psubB := mustCreatePubSub(ctx, t, hosts[1], "/esh/floodsub") + psubC := mustCreatePubSub(ctx, t, hosts[2], "/lsr/floodsub") + + subA := mustSubscribe(t, psubA, "foobar") + defer subA.Cancel() + + subB := mustSubscribe(t, psubB, "foobar") + defer subB.Cancel() + + subC := mustSubscribe(t, psubC, "foobar") + defer subC.Cancel() + + // B --> A, C --> A + connect(t, hosts[1], hosts[0]) + connect(t, hosts[2], hosts[0]) + + time.Sleep(time.Millisecond * 100) + + psubC.Publish("foobar", []byte("bar")) + + assertReceive(t, subA, []byte("bar")) + assertReceive(t, subB, []byte("bar")) + assertReceive(t, subC, []byte("bar")) + }) + + t.Run("won't talk to routers with no protocol overlap", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getNetHosts(t, ctx, 2) + + psubA := mustCreatePubSub(ctx, t, hosts[0], "/esh/floodsub") + psubB := mustCreatePubSub(ctx, t, hosts[1], "/lsr/floodsub") + + subA := mustSubscribe(t, psubA, "foobar") + defer subA.Cancel() + + subB := mustSubscribe(t, psubB, "foobar") + defer subB.Cancel() + + connect(t, hosts[1], hosts[0]) + + time.Sleep(time.Millisecond * 100) + + psubA.Publish("foobar", []byte("bar")) + + assertReceive(t, subA, []byte("bar")) + + pass := false + select { + case <-subB.ch: + t.Fatal("different protocols: should not have received message") + case <-time.After(time.Second * 1): + pass = true + } + + if !pass { + t.Fatal("should have timed out waiting for message") + } + }) +} + +func mustCreatePubSub(ctx context.Context, t *testing.T, h host.Host, ps ...protocol.ID) *PubSub { + psub, err := New(ctx, h, ps) + if err != nil { + t.Fatal(err) + } + + return psub +} + +func mustSubscribe(t *testing.T, ps *PubSub, topic string) *Subscription { + sub, err := ps.Subscribe(topic) + if err != nil { + t.Fatal(err) + } + + return sub +} + func TestSubReporting(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel()