From fedbccc0c69d091d1fcbcfe5709a3f03a2aecb9c Mon Sep 17 00:00:00 2001
From: Marco Munizaga
Date: Wed, 25 Jun 2025 12:38:21 -0700
Subject: [PATCH] fix(BatchPublishing): Make topic.AddToBatch threadsafe (#622)

topic.Publish is already thread safe. topic.AddToBatch should strive to
follow similar semantics.

Looking at how this would integrate with Prysm, they use separate
goroutines per message they'd like to batch.
---
 gossipsub_test.go | 117 +++++++++++++++++++++++++++-------------------
 messagebatch.go   |  16 +++++++
 pubsub.go         |   4 +-
 topic.go          |   2 +-
 4 files changed, 86 insertions(+), 53 deletions(-)

diff --git a/gossipsub_test.go b/gossipsub_test.go
index 7aa5188..9f450d8 100644
--- a/gossipsub_test.go
+++ b/gossipsub_test.go
@@ -3682,66 +3682,85 @@ func BenchmarkRoundRobinMessageIDScheduler(b *testing.B) {
 }
 
 func TestMessageBatchPublish(t *testing.T) {
-	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
-	defer cancel()
-	hosts := getDefaultHosts(t, 20)
+	concurrentAdds := []bool{false, true}
+	for _, concurrentAdd := range concurrentAdds {
+		t.Run(fmt.Sprintf("WithConcurrentAdd=%v", concurrentAdd), func(t *testing.T) {
+			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+			defer cancel()
+			hosts := getDefaultHosts(t, 20)
 
-	msgIDFn := func(msg *pb.Message) string {
-		hdr := string(msg.Data[0:16])
-		msgID := strings.SplitN(hdr, " ", 2)
-		return msgID[0]
-	}
-	const numMessages = 100
-	// +8 to account for the gossiping overhead
-	psubs := getGossipsubs(ctx, hosts, WithMessageIdFn(msgIDFn), WithPeerOutboundQueueSize(numMessages+8), WithValidateQueueSize(numMessages+8))
+			msgIDFn := func(msg *pb.Message) string {
+				hdr := string(msg.Data[0:16])
+				msgID := strings.SplitN(hdr, " ", 2)
+				return msgID[0]
+			}
+			const numMessages = 100
+			// +8 to account for the gossiping overhead
+			psubs := getGossipsubs(ctx, hosts, WithMessageIdFn(msgIDFn), WithPeerOutboundQueueSize(numMessages+8), WithValidateQueueSize(numMessages+8))
 
-	var topics []*Topic
-	var msgs []*Subscription
-	for _, ps := range psubs {
-		topic, err := ps.Join("foobar")
-		if err != nil {
-			t.Fatal(err)
-		}
-		topics = append(topics, topic)
+			var topics []*Topic
+			var msgs []*Subscription
+			for _, ps := range psubs {
+				topic, err := ps.Join("foobar")
+				if err != nil {
+					t.Fatal(err)
+				}
+				topics = append(topics, topic)
 
-		subch, err := topic.Subscribe(WithBufferSize(numMessages + 8))
-		if err != nil {
-			t.Fatal(err)
-		}
+				subch, err := topic.Subscribe(WithBufferSize(numMessages + 8))
+				if err != nil {
+					t.Fatal(err)
+				}
 
-		msgs = append(msgs, subch)
-	}
+				msgs = append(msgs, subch)
+			}
 
-	sparseConnect(t, hosts)
+			sparseConnect(t, hosts)
 
-	// wait for heartbeats to build mesh
-	time.Sleep(time.Second * 2)
+			// wait for heartbeats to build mesh
+			time.Sleep(time.Second * 2)
 
-	var batch MessageBatch
-	for i := 0; i < numMessages; i++ {
-		msg := []byte(fmt.Sprintf("%d it's not a floooooood %d", i, i))
-		err := topics[0].AddToBatch(ctx, &batch, msg)
-		if err != nil {
-			t.Fatal(err)
-		}
-	}
-	err := psubs[0].PublishBatch(&batch)
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	for range numMessages {
-		for _, sub := range msgs {
-			got, err := sub.Next(ctx)
+			var batch MessageBatch
+			var wg sync.WaitGroup
+			for i := 0; i < numMessages; i++ {
+				msg := []byte(fmt.Sprintf("%d it's not a floooooood %d", i, i))
+				if concurrentAdd {
+					wg.Add(1)
+					go func() {
+						defer wg.Done()
+						err := topics[0].AddToBatch(ctx, &batch, msg)
+						if err != nil {
+							t.Log(err)
+							t.Fail()
+						}
+					}()
+				} else {
+					err := topics[0].AddToBatch(ctx, &batch, msg)
+					if err != nil {
+						t.Fatal(err)
+					}
+				}
+			}
+			wg.Wait()
+			err := psubs[0].PublishBatch(&batch)
 			if err != nil {
 				t.Fatal(err)
 			}
-			id := msgIDFn(got.Message)
-			expected := []byte(fmt.Sprintf("%s it's not a floooooood %s", id, id))
-			if !bytes.Equal(expected, got.Data) {
-				t.Fatal("got wrong message!")
+
+			for range numMessages {
+				for _, sub := range msgs {
+					got, err := sub.Next(ctx)
+					if err != nil {
+						t.Fatal(err)
+					}
+					id := msgIDFn(got.Message)
+					expected := []byte(fmt.Sprintf("%s it's not a floooooood %s", id, id))
+					if !bytes.Equal(expected, got.Data) {
+						t.Fatal("got wrong message!")
+					}
+				}
 			}
-		}
+		})
 	}
 }
 
diff --git a/messagebatch.go b/messagebatch.go
index 8178645..55941d0 100644
--- a/messagebatch.go
+++ b/messagebatch.go
@@ -2,6 +2,7 @@ package pubsub
 
 import (
 	"iter"
+	"sync"
 
 	"github.com/libp2p/go-libp2p/core/peer"
 )
@@ -10,9 +11,24 @@
 // once. This allows the Scheduler to define an order for outgoing RPCs.
 // This helps bandwidth constrained peers.
 type MessageBatch struct {
+	mu       sync.Mutex
 	messages []*Message
 }
 
+func (mb *MessageBatch) add(msg *Message) {
+	mb.mu.Lock()
+	defer mb.mu.Unlock()
+	mb.messages = append(mb.messages, msg)
+}
+
+func (mb *MessageBatch) take() []*Message {
+	mb.mu.Lock()
+	defer mb.mu.Unlock()
+	messages := mb.messages
+	mb.messages = nil
+	return messages
+}
+
 type messageBatchAndPublishOptions struct {
 	messages []*Message
 	opts     *BatchPublishOptions
diff --git a/pubsub.go b/pubsub.go
index 91017d1..3af9888 100644
--- a/pubsub.go
+++ b/pubsub.go
@@ -1600,12 +1600,10 @@ func (p *PubSub) PublishBatch(batch *MessageBatch, opts ...BatchPubOpt) error {
 	setDefaultBatchPublishOptions(publishOptions)
 
 	p.sendMessageBatch <- messageBatchAndPublishOptions{
-		messages: batch.messages,
+		messages: batch.take(),
 		opts:     publishOptions,
 	}
 
-	// Clear the batch's messages in case a user reuses the same batch object
-	batch.messages = nil
 	return nil
 }
 
diff --git a/topic.go b/topic.go
index b164e32..c438ebc 100644
--- a/topic.go
+++ b/topic.go
@@ -257,7 +257,7 @@ func (t *Topic) AddToBatch(ctx context.Context, batch *MessageBatch, data []byte
 		}
 		return err
 	}
-	batch.messages = append(batch.messages, msg)
+	batch.add(msg)
 	return nil
 }
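
Usage sketch (illustrative only, not part of the patch): with AddToBatch now safe for concurrent use, a caller can fan out one goroutine per message against a shared MessageBatch and flush it with a single PublishBatch call, the Prysm-style integration the commit message describes. The helper below is hypothetical; the topic handle, payloads, and error handling are assumptions, and it presumes the batching API from this series is available in the go-libp2p-pubsub module you build against.

package batchexample

import (
	"context"
	"sync"

	pubsub "github.com/libp2p/go-libp2p-pubsub"
)

// publishAsBatch is a hypothetical helper, not part of this patch. It adds
// every payload to a shared MessageBatch from its own goroutine (safe after
// this change, since AddToBatch locks the batch) and then publishes the
// accumulated batch once.
func publishAsBatch(ctx context.Context, ps *pubsub.PubSub, topic *pubsub.Topic, payloads [][]byte) error {
	var batch pubsub.MessageBatch
	var wg sync.WaitGroup
	errs := make(chan error, len(payloads))

	for _, data := range payloads {
		wg.Add(1)
		go func(data []byte) {
			defer wg.Done()
			// Concurrent calls on the same batch are the point of this change.
			if err := topic.AddToBatch(ctx, &batch, data); err != nil {
				errs <- err
			}
		}(data)
	}
	wg.Wait()
	close(errs)
	if err, ok := <-errs; ok {
		return err // report the first add failure
	}

	// PublishBatch drains the batch (via take() under the same mutex), so the
	// same MessageBatch value can be reused for a later batch.
	return ps.PublishBatch(&batch)
}

Having PublishBatch call take() under the batch's mutex, rather than reading and clearing batch.messages directly, is what keeps the published snapshot consistent even if a stray AddToBatch races with the publish.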