diff --git a/floodsub_test.go b/floodsub_test.go index 6da00f9..e9905c4 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -4,17 +4,24 @@ import ( "bytes" "context" "fmt" + "io" "math/rand" "sort" + "sync" "testing" "time" - bhost "github.com/libp2p/go-libp2p-blankhost" + pb "github.com/libp2p/go-libp2p-pubsub/pb" + "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/protocol" + bhost "github.com/libp2p/go-libp2p-blankhost" swarmt "github.com/libp2p/go-libp2p-swarm/testing" + + ggio "github.com/gogo/protobuf/io" ) func checkMessageRouting(t *testing.T, topic string, pubs []*PubSub, subs []*Subscription) { @@ -969,3 +976,69 @@ func TestConfigurableMaxMessageSize(t *testing.T) { } } + +func TestAnnounceRetry(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getNetHosts(t, ctx, 2) + ps := getPubsub(ctx, hosts[0]) + watcher := &announceWatcher{} + hosts[1].SetStreamHandler(FloodSubID, watcher.handleStream) + + _, err := ps.Subscribe("test") + if err != nil { + t.Fatal(err) + } + + // connect the watcher to the pubsub + connect(t, hosts[0], hosts[1]) + + // wait a bit for the first subscription to be emitted and trigger announce retry + time.Sleep(100 * time.Millisecond) + go ps.announceRetry(hosts[1].ID(), "test", true) + + // wait a bit for the subscription to propagate and ensure it was received twice + time.Sleep(time.Second + 100*time.Millisecond) + count := watcher.countSubs() + if count != 2 { + t.Fatalf("expected 2 subscription messages, but got %d", count) + } +} + +type announceWatcher struct { + mx sync.Mutex + subs int +} + +func (aw *announceWatcher) handleStream(s network.Stream) { + defer s.Close() + + r := ggio.NewDelimitedReader(s, 1<<20) + + var rpc pb.RPC + for { + rpc.Reset() + err := r.ReadMsg(&rpc) + if err != nil { + if err != io.EOF { + s.Reset() + } + return + } + + for _, sub := range rpc.GetSubscriptions() { + if sub.GetSubscribe() && sub.GetTopicid() == "test" { + aw.mx.Lock() + aw.subs++ + aw.mx.Unlock() + } + } + } +} + +func (aw *announceWatcher) countSubs() int { + aw.mx.Lock() + defer aw.mx.Unlock() + return aw.subs +}