diff --git a/topic.go b/topic.go index 6460782..8de88c3 100644 --- a/topic.go +++ b/topic.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "sync" + "time" pb "github.com/libp2p/go-libp2p-pubsub/pb" @@ -242,7 +243,44 @@ func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error } if pub.ready != nil { - t.p.disc.Bootstrap(ctx, t.topic, pub.ready) + if t.p.disc.discovery != nil { + t.p.disc.Bootstrap(ctx, t.topic, pub.ready) + } else { + // TODO: we could likely do better than polling every 200ms. + // For example, block this goroutine on a channel, + // and check again whenever events tell us that the number of + // peers has increased. + var ticker *time.Ticker + readyLoop: + for { + // Check if ready for publishing. + // Similar to what disc.Bootstrap does. + res := make(chan bool, 1) + select { + case t.p.eval <- func() { + done, _ := pub.ready(t.p.rt, t.topic) + res <- done + }: + if <-res { + break readyLoop + } + case <-t.p.ctx.Done(): + return t.p.ctx.Err() + case <-ctx.Done(): + return ctx.Err() + } + if ticker == nil { + ticker = time.NewTicker(200 * time.Millisecond) + defer ticker.Stop() + } + + select { + case <-ticker.C: + case <-ctx.Done(): + return fmt.Errorf("router is not ready: %w", ctx.Err()) + } + } + } } return t.p.val.PushLocal(&Message{m, t.p.host.ID(), nil}) diff --git a/topic_test.go b/topic_test.go index 2169d35..52927c0 100644 --- a/topic_test.go +++ b/topic_test.go @@ -3,6 +3,7 @@ package pubsub import ( "bytes" "context" + "errors" "fmt" "math/rand" "sync" @@ -780,3 +781,82 @@ func readAllQueuedEvents(ctx context.Context, t *testing.T, evt *TopicEventHandl } return peerState } + +func TestMinTopicSizeNoDiscovery(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + const numHosts = 3 + topicID := "foobar" + hosts := getNetHosts(t, ctx, numHosts) + + sender := getPubsub(ctx, hosts[0]) + receiver1 := getPubsub(ctx, hosts[1]) + receiver2 := getPubsub(ctx, hosts[2]) + + connectAll(t, hosts) + + // Sender creates topic + sendTopic, err := sender.Join(topicID) + if err != nil { + t.Fatal(err) + } + + // Receiver creates and subscribes to the topic + receiveTopic1, err := receiver1.Join(topicID) + if err != nil { + t.Fatal(err) + } + + sub1, err := receiveTopic1.Subscribe() + if err != nil { + t.Fatal(err) + } + + oneMsg := []byte("minimum one") + if err := sendTopic.Publish(ctx, oneMsg, WithReadiness(MinTopicSize(1))); err != nil { + t.Fatal(err) + } + + if msg, err := sub1.Next(ctx); err != nil { + t.Fatal(err) + } else if !bytes.Equal(msg.GetData(), oneMsg) { + t.Fatal("received incorrect message") + } + + twoMsg := []byte("minimum two") + + // Attempting to publish with a minimum topic size of two should fail. + { + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + if err := sendTopic.Publish(ctx, twoMsg, WithReadiness(MinTopicSize(2))); !errors.Is(err, context.DeadlineExceeded) { + t.Fatal(err) + } + } + + // Subscribe the second receiver; the publish should now work. + receiveTopic2, err := receiver2.Join(topicID) + if err != nil { + t.Fatal(err) + } + + sub2, err := receiveTopic2.Subscribe() + if err != nil { + t.Fatal(err) + } + + { + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + if err := sendTopic.Publish(ctx, twoMsg, WithReadiness(MinTopicSize(2))); err != nil { + t.Fatal(err) + } + } + + if msg, err := sub2.Next(ctx); err != nil { + t.Fatal(err) + } else if !bytes.Equal(msg.GetData(), twoMsg) { + t.Fatal("received incorrect message") + } +}