feat(gossipsub): Add MessageBatch (#607)

Adds a MessageBatch type to support batch publishing of messages.

Replaces #602.

Batch publishing lets the system know there are multiple related
messages to be published, so it can prioritize sending distinct
messages before sending additional copies of each one. For example,
with the default API, when you publish two messages A and B, under the
hood A gets sent to D=8 peers before B goes out at all. With the
MessageBatch API we can now send one copy of A _and then_ one copy of B
before sending the remaining copies of either.

When a node is bandwidth-constrained relative to the volume of messages
it is publishing, this improves dissemination time.

For more context see this post:
https://ethresear.ch/t/improving-das-performance-with-gossipsub-batch-publishing/21713
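As a sketch of the new API surface (the helper name and the caller-supplied
ctx/ps/topic/payloads are illustrative, not part of this change), a publisher
queues payloads with Topic.AddToBatch and then hands the batch to
PubSub.PublishBatch:

package example

import (
	"context"

	pubsub "github.com/libp2p/go-libp2p-pubsub"
)

// publishBatch queues related payloads on a MessageBatch via Topic.AddToBatch
// and then hands the whole batch to PubSub.PublishBatch, letting the router
// interleave the outgoing copies.
func publishBatch(ctx context.Context, ps *pubsub.PubSub, topic *pubsub.Topic, payloads [][]byte) error {
	var batch pubsub.MessageBatch
	for _, data := range payloads {
		if err := topic.AddToBatch(ctx, &batch, data); err != nil {
			return err
		}
	}
	// By default the batch is scheduled with RoundRobinMessageIDScheduler, so
	// one copy of each message goes out before additional copies of any of them.
	return ps.PublishBatch(&batch)
}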
Marco Munizaga 2025-05-08 10:23:02 -07:00 committed by GitHub
parent 50ccc5ca90
commit 0c5ee7bbfe
6 changed files with 496 additions and 89 deletions


@@ -5,6 +5,7 @@ import (
"crypto/sha256"
"fmt"
"io"
"iter"
"math/rand"
"sort"
"time"
@@ -522,6 +523,8 @@ type GossipSubRouter struct {
heartbeatTicks uint64
}
var _ BatchPublisher = &GossipSubRouter{}
type connectInfo struct {
p peer.ID
spr *record.Envelope
@@ -1143,81 +1146,105 @@ func (gs *GossipSubRouter) connector() {
}
}
func (gs *GossipSubRouter) Publish(msg *Message) {
gs.mcache.Put(msg)
from := msg.ReceivedFrom
topic := msg.GetTopic()
tosend := make(map[peer.ID]struct{})
// any peers in the topic?
tmap, ok := gs.p.topics[topic]
if !ok {
return
func (gs *GossipSubRouter) PublishBatch(messages []*Message, opts *BatchPublishOptions) {
strategy := opts.Strategy
for _, msg := range messages {
msgID := gs.p.idGen.ID(msg)
for p, rpc := range gs.rpcs(msg) {
strategy.AddRPC(p, msgID, rpc)
}
}
if gs.floodPublish && from == gs.p.host.ID() {
for p := range tmap {
_, direct := gs.direct[p]
if direct || gs.score.Score(p) >= gs.publishThreshold {
tosend[p] = struct{}{}
}
}
} else {
// direct peers
for p := range gs.direct {
_, inTopic := tmap[p]
if inTopic {
tosend[p] = struct{}{}
}
}
for p, rpc := range strategy.All() {
gs.sendRPC(p, rpc, false)
}
}
// floodsub peers
for p := range tmap {
if !gs.feature(GossipSubFeatureMesh, gs.peers[p]) && gs.score.Score(p) >= gs.publishThreshold {
tosend[p] = struct{}{}
}
}
func (gs *GossipSubRouter) Publish(msg *Message) {
for p, rpc := range gs.rpcs(msg) {
gs.sendRPC(p, rpc, false)
}
}
// gossipsub peers
gmap, ok := gs.mesh[topic]
func (gs *GossipSubRouter) rpcs(msg *Message) iter.Seq2[peer.ID, *RPC] {
return func(yield func(peer.ID, *RPC) bool) {
gs.mcache.Put(msg)
from := msg.ReceivedFrom
topic := msg.GetTopic()
tosend := make(map[peer.ID]struct{})
// any peers in the topic?
tmap, ok := gs.p.topics[topic]
if !ok {
// we are not in the mesh for topic, use fanout peers
gmap, ok = gs.fanout[topic]
if !ok || len(gmap) == 0 {
// we don't have any, pick some with score above the publish threshold
peers := gs.getPeers(topic, gs.params.D, func(p peer.ID) bool {
_, direct := gs.direct[p]
return !direct && gs.score.Score(p) >= gs.publishThreshold
})
return
}
if len(peers) > 0 {
gmap = peerListToMap(peers)
gs.fanout[topic] = gmap
if gs.floodPublish && from == gs.p.host.ID() {
for p := range tmap {
_, direct := gs.direct[p]
if direct || gs.score.Score(p) >= gs.publishThreshold {
tosend[p] = struct{}{}
}
}
gs.lastpub[topic] = time.Now().UnixNano()
} else {
// direct peers
for p := range gs.direct {
_, inTopic := tmap[p]
if inTopic {
tosend[p] = struct{}{}
}
}
// floodsub peers
for p := range tmap {
if !gs.feature(GossipSubFeatureMesh, gs.peers[p]) && gs.score.Score(p) >= gs.publishThreshold {
tosend[p] = struct{}{}
}
}
// gossipsub peers
gmap, ok := gs.mesh[topic]
if !ok {
// we are not in the mesh for topic, use fanout peers
gmap, ok = gs.fanout[topic]
if !ok || len(gmap) == 0 {
// we don't have any, pick some with score above the publish threshold
peers := gs.getPeers(topic, gs.params.D, func(p peer.ID) bool {
_, direct := gs.direct[p]
return !direct && gs.score.Score(p) >= gs.publishThreshold
})
if len(peers) > 0 {
gmap = peerListToMap(peers)
gs.fanout[topic] = gmap
}
}
gs.lastpub[topic] = time.Now().UnixNano()
}
csum := computeChecksum(gs.p.idGen.ID(msg))
for p := range gmap {
// Check if it has already received an IDONTWANT for the message.
// If so, don't send it to the peer
if _, ok := gs.unwanted[p][csum]; ok {
continue
}
tosend[p] = struct{}{}
}
}
csum := computeChecksum(gs.p.idGen.ID(msg))
for p := range gmap {
// Check if it has already received an IDONTWANT for the message.
// If so, don't send it to the peer
if _, ok := gs.unwanted[p][csum]; ok {
out := rpcWithMessages(msg.Message)
for pid := range tosend {
if pid == from || pid == peer.ID(msg.GetFrom()) {
continue
}
tosend[p] = struct{}{}
}
}
out := rpcWithMessages(msg.Message)
for pid := range tosend {
if pid == from || pid == peer.ID(msg.GetFrom()) {
continue
if !yield(pid, out) {
return
}
}
gs.sendRPC(pid, out, false)
}
}


@@ -9,9 +9,11 @@ import (
"io"
mrand "math/rand"
"sort"
"strings"
"sync"
"sync/atomic"
"testing"
"testing/quick"
"time"
pb "github.com/libp2p/go-libp2p-pubsub/pb"
@@ -3406,3 +3408,209 @@ func BenchmarkAllocDoDropRPC(b *testing.B) {
gs.doDropRPC(&RPC{}, "peerID", "reason")
}
}
func TestRoundRobinMessageIDScheduler(t *testing.T) {
const maxNumPeers = 256
const maxNumMessages = 1_000
err := quick.Check(func(numPeers uint16, numMessages uint16) bool {
numPeers = numPeers % maxNumPeers
numMessages = numMessages % maxNumMessages
output := make([]pendingRPC, 0, numMessages*numPeers)
var strategy RoundRobinMessageIDScheduler
peers := make([]peer.ID, numPeers)
for i := 0; i < int(numPeers); i++ {
peers[i] = peer.ID(fmt.Sprintf("peer%d", i))
}
getID := func(r pendingRPC) string {
return string(r.rpc.Publish[0].Data)
}
for i := range int(numMessages) {
for j := range int(numPeers) {
strategy.AddRPC(peers[j], fmt.Sprintf("msg%d", i), &RPC{
RPC: pb.RPC{
Publish: []*pb.Message{
{
Data: []byte(fmt.Sprintf("msg%d", i)),
},
},
},
})
}
}
for p, rpc := range strategy.All() {
output = append(output, pendingRPC{
peer: p,
rpc: rpc,
})
}
// Check invariants
// 1. The published rpcs count is the same as the number of messages added
// 2. Before all message IDs are seen, no message ID may be repeated
// 3. The set of message ID + peer ID combinations should be the same as the input
// 1.
expectedCount := int(numMessages) * int(numPeers)
if len(output) != expectedCount {
t.Logf("Expected %d RPCs, got %d", expectedCount, len(output))
return false
}
// 2.
seen := make(map[string]bool)
expected := make(map[string]bool)
for i := 0; i < int(numMessages); i++ {
expected[fmt.Sprintf("msg%d", i)] = true
}
for _, rpc := range output {
if expected[getID(rpc)] {
delete(expected, getID(rpc))
}
if seen[getID(rpc)] && len(expected) > 0 {
t.Logf("Message ID %s repeated before all message IDs are seen", getID(rpc))
return false
}
seen[getID(rpc)] = true
}
// 3.
inputSet := make(map[string]bool)
for i := range int(numMessages) {
for j := range int(numPeers) {
inputSet[fmt.Sprintf("msg%d:peer%d", i, j)] = true
}
}
for _, rpc := range output {
if !inputSet[getID(rpc)+":"+string(rpc.peer)] {
t.Logf("Message ID %s not in input", getID(rpc))
return false
}
}
return true
}, &quick.Config{MaxCount: 32})
if err != nil {
t.Fatal(err)
}
}
func BenchmarkRoundRobinMessageIDScheduler(b *testing.B) {
const numPeers = 1_000
const numMessages = 1_000
var strategy RoundRobinMessageIDScheduler
peers := make([]peer.ID, numPeers)
for i := range int(numPeers) {
peers[i] = peer.ID(fmt.Sprintf("peer%d", i))
}
msgs := make([]string, numMessages)
for i := range numMessages {
msgs[i] = fmt.Sprintf("msg%d", i)
}
emptyRPC := &RPC{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
j := i % len(peers)
msgIdx := i % numMessages
strategy.AddRPC(peers[j], msgs[msgIdx], emptyRPC)
if i%100 == 0 {
for range strategy.All() {
}
}
}
}
func TestMessageBatchPublish(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hosts := getDefaultHosts(t, 20)
msgIDFn := func(msg *pb.Message) string {
hdr := string(msg.Data[0:16])
msgID := strings.SplitN(hdr, " ", 2)
return msgID[0]
}
const numMessages = 100
// +8 to account for the gossiping overhead
psubs := getGossipsubs(ctx, hosts, WithMessageIdFn(msgIDFn), WithPeerOutboundQueueSize(numMessages+8))
var topics []*Topic
var msgs []*Subscription
for _, ps := range psubs {
topic, err := ps.Join("foobar")
if err != nil {
t.Fatal(err)
}
topics = append(topics, topic)
subch, err := topic.Subscribe(WithBufferSize(numMessages + 8))
if err != nil {
t.Fatal(err)
}
msgs = append(msgs, subch)
}
sparseConnect(t, hosts)
// wait for heartbeats to build mesh
time.Sleep(time.Second * 2)
var batch MessageBatch
for i := 0; i < numMessages; i++ {
msg := []byte(fmt.Sprintf("%d it's not a floooooood %d", i, i))
err := topics[0].AddToBatch(ctx, &batch, msg)
if err != nil {
t.Fatal(err)
}
}
err := psubs[0].PublishBatch(&batch)
if err != nil {
t.Fatal(err)
}
for range numMessages {
for _, sub := range msgs {
got, err := sub.Next(ctx)
if err != nil {
t.Fatal(sub.err)
}
id := msgIDFn(got.Message)
expected := []byte(fmt.Sprintf("%s it's not a floooooood %s", id, id))
if !bytes.Equal(expected, got.Data) {
t.Fatal("got wrong message!")
}
}
}
}
func TestPublishDuplicateMessage(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hosts := getDefaultHosts(t, 1)
psubs := getGossipsubs(ctx, hosts, WithMessageIdFn(func(msg *pb.Message) string {
return string(msg.Data)
}))
topic, err := psubs[0].Join("foobar")
if err != nil {
t.Fatal(err)
}
err = topic.Publish(ctx, []byte("hello"))
if err != nil {
t.Fatal(err)
}
err = topic.Publish(ctx, []byte("hello"))
if err != nil {
t.Fatal("Duplicate message should not return an error")
}
}

messagebatch.go (new file)

@@ -0,0 +1,62 @@
package pubsub
import (
"iter"
"github.com/libp2p/go-libp2p/core/peer"
)
// MessageBatch allows a user to batch related messages and then publish them at
// once. This allows the Scheduler to define an order for outgoing RPCs.
// This helps bandwidth constrained peers.
type MessageBatch struct {
messages []*Message
}
type messageBatchAndPublishOptions struct {
messages []*Message
opts *BatchPublishOptions
}
// RPCScheduler schedules outgoing RPCs.
type RPCScheduler interface {
// AddRPC adds an RPC to the scheduler.
AddRPC(peer peer.ID, msgID string, rpc *RPC)
// All returns an ordered iterator of RPCs.
All() iter.Seq2[peer.ID, *RPC]
}
type pendingRPC struct {
peer peer.ID
rpc *RPC
}
// RoundRobinMessageIDScheduler schedules outgoing RPCs in round-robin order of message IDs.
type RoundRobinMessageIDScheduler struct {
rpcs map[string][]pendingRPC
}
func (s *RoundRobinMessageIDScheduler) AddRPC(peer peer.ID, msgID string, rpc *RPC) {
if s.rpcs == nil {
s.rpcs = make(map[string][]pendingRPC)
}
s.rpcs[msgID] = append(s.rpcs[msgID], pendingRPC{peer: peer, rpc: rpc})
}
func (s *RoundRobinMessageIDScheduler) All() iter.Seq2[peer.ID, *RPC] {
return func(yield func(peer.ID, *RPC) bool) {
for len(s.rpcs) > 0 {
for msgID, rpcs := range s.rpcs {
if len(rpcs) == 0 {
delete(s.rpcs, msgID)
continue
}
if !yield(rpcs[0].peer, rpcs[0].rpc) {
return
}
s.rpcs[msgID] = rpcs[1:]
}
}
}
}
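For intuition, a small sketch of the scheduling order (the peer and message
IDs here are made up): with two messages fanned out to two peers each, the
iterator yields one RPC per message ID before yielding any second copies.

package example

import (
	"fmt"

	pubsub "github.com/libp2p/go-libp2p-pubsub"
	"github.com/libp2p/go-libp2p/core/peer"
)

func schedulerOrder() {
	var s pubsub.RoundRobinMessageIDScheduler
	peers := []peer.ID{"peerA", "peerB"}
	for _, msgID := range []string{"msg1", "msg2"} {
		for _, p := range peers {
			// An empty RPC stands in for the RPC carrying the message with this ID.
			s.AddRPC(p, msgID, &pubsub.RPC{})
		}
	}
	// The iterator drains one pending RPC per message ID on each pass, so both
	// msg1 and msg2 are sent once (to peerA) before either repeats (to peerB).
	for p, rpc := range s.All() {
		fmt.Println(p, rpc)
	}
}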


@@ -134,6 +134,9 @@ type PubSub struct {
// sendMsg handles messages that have been validated
sendMsg chan *Message
// sendMessageBatch publishes a batch of messages
sendMessageBatch chan messageBatchAndPublishOptions
// addVal handles validator registration requests
addVal chan *addValReq
@@ -217,6 +220,10 @@ type PubSubRouter interface {
Leave(topic string)
}
type BatchPublisher interface {
PublishBatch(messages []*Message, opts *BatchPublishOptions)
}
type AcceptStatus int
const (
@@ -281,6 +288,7 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option
rmTopic: make(chan *rmTopicReq),
getTopics: make(chan *topicReq),
sendMsg: make(chan *Message, 32),
sendMessageBatch: make(chan messageBatchAndPublishOptions, 1),
addVal: make(chan *addValReq),
rmVal: make(chan *rmValReq),
eval: make(chan func()),
@@ -642,6 +650,9 @@ func (p *PubSub) processLoop(ctx context.Context) {
case msg := <-p.sendMsg:
p.publishMessage(msg)
case batchAndOpts := <-p.sendMessageBatch:
p.publishMessageBatch(batchAndOpts)
case req := <-p.addVal:
p.val.AddValidator(req)
@@ -1221,6 +1232,15 @@ func (p *PubSub) publishMessage(msg *Message) {
}
}
func (p *PubSub) publishMessageBatch(batchAndOpts messageBatchAndPublishOptions) {
for _, msg := range batchAndOpts.messages {
p.tracer.DeliverMessage(msg)
p.notifySubs(msg)
}
// We type checked when pushing the batch to the channel
p.rt.(BatchPublisher).PublishBatch(batchAndOpts.messages, batchAndOpts.opts)
}
type addTopicReq struct {
topic *Topic
resp chan *Topic
@@ -1358,6 +1378,39 @@ func (p *PubSub) Publish(topic string, data []byte, opts ...PubOpt) error {
return t.Publish(context.TODO(), data, opts...)
}
// PublishBatch publishes a batch of messages. This only works for routers that
// implement the BatchPublisher interface.
//
// Users should make sure there is enough space in the Peer's outbound queue to
// ensure messages are not dropped. WithPeerOutboundQueueSize should be set to
// at least the expected number of batched messages per peer plus some slack to
// account for gossip messages.
//
// The default publish strategy is RoundRobinMessageIDScheduler.
func (p *PubSub) PublishBatch(batch *MessageBatch, opts ...BatchPubOpt) error {
if _, ok := p.rt.(BatchPublisher); !ok {
return fmt.Errorf("pubsub router is not a BatchPublisher")
}
publishOptions := &BatchPublishOptions{}
for _, o := range opts {
err := o(publishOptions)
if err != nil {
return err
}
}
setDefaultBatchPublishOptions(publishOptions)
p.sendMessageBatch <- messageBatchAndPublishOptions{
messages: batch.messages,
opts: publishOptions,
}
// Clear the batch's messages in case a user reuses the same batch object
batch.messages = nil
return nil
}
func (p *PubSub) nextSeqno() []byte {
seqno := make([]byte, 8)
counter := atomic.AddUint64(&p.counter, 1)


@@ -219,14 +219,53 @@ type PublishOptions struct {
validatorData any
}
type BatchPublishOptions struct {
Strategy RPCScheduler
}
type PubOpt func(pub *PublishOptions) error
type BatchPubOpt func(pub *BatchPublishOptions) error
func setDefaultBatchPublishOptions(opts *BatchPublishOptions) {
if opts.Strategy == nil {
opts.Strategy = &RoundRobinMessageIDScheduler{}
}
}
// Publish publishes data to topic.
func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error {
msg, err := t.validate(ctx, data, opts...)
if err != nil {
if errors.Is(err, dupeErr{}) {
// If it was a duplicate, we return nil to indicate success.
// Semantically the message was published by us or someone else.
return nil
}
return err
}
return t.p.val.sendMsgBlocking(msg)
}
func (t *Topic) AddToBatch(ctx context.Context, batch *MessageBatch, data []byte, opts ...PubOpt) error {
msg, err := t.validate(ctx, data, opts...)
if err != nil {
if errors.Is(err, dupeErr{}) {
// If it was a duplicate, we return nil to indicate success.
// Semantically the message was published by us or someone else.
// We won't add it to the batch. Since it's already been published.
return nil
}
return err
}
batch.messages = append(batch.messages, msg)
return nil
}
func (t *Topic) validate(ctx context.Context, data []byte, opts ...PubOpt) (*Message, error) {
t.mux.RLock()
defer t.mux.RUnlock()
if t.closed {
return ErrTopicClosed
return nil, ErrTopicClosed
}
pid := t.p.signID
@@ -236,17 +275,17 @@ func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error
for _, opt := range opts {
err := opt(pub)
if err != nil {
return err
return nil, err
}
}
if pub.customKey != nil && !pub.local {
key, pid = pub.customKey()
if key == nil {
return ErrNilSignKey
return nil, ErrNilSignKey
}
if len(pid) == 0 {
return ErrEmptyPeerID
return nil, ErrEmptyPeerID
}
}
@@ -264,7 +303,7 @@ func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error
m.From = []byte(pid)
err := signMessage(pid, key, m)
if err != nil {
return err
return nil, err
}
}
@@ -291,9 +330,9 @@ func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error
break readyLoop
}
case <-t.p.ctx.Done():
return t.p.ctx.Err()
return nil, t.p.ctx.Err()
case <-ctx.Done():
return ctx.Err()
return nil, ctx.Err()
}
if ticker == nil {
ticker = time.NewTicker(200 * time.Millisecond)
@@ -303,13 +342,18 @@ func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error
select {
case <-ticker.C:
case <-ctx.Done():
return fmt.Errorf("router is not ready: %w", ctx.Err())
return nil, fmt.Errorf("router is not ready: %w", ctx.Err())
}
}
}
}
return t.p.val.PushLocal(&Message{m, "", t.p.host.ID(), pub.validatorData, pub.local})
msg := &Message{m, "", t.p.host.ID(), pub.validatorData, pub.local}
err := t.p.val.ValidateLocal(msg)
if err != nil {
return nil, err
}
return msg, nil
}
// WithReadiness returns a publishing option for only publishing when the router is ready.


@@ -26,6 +26,12 @@ func (e ValidationError) Error() string {
return e.Reason
}
type dupeErr struct{}
func (dupeErr) Error() string {
return "duplicate message"
}
// Validator is a function that validates a message with a binary decision: accept or reject.
type Validator func(context.Context, peer.ID, *Message) bool
@@ -226,10 +232,9 @@ func (v *validation) RemoveValidator(req *rmValReq) {
}
}
// PushLocal synchronously pushes a locally published message and performs applicable
// validations.
// Returns an error if validation fails
func (v *validation) PushLocal(msg *Message) error {
// ValidateLocal synchronously validates a locally published message and
// performs applicable validations. Returns an error if validation fails.
func (v *validation) ValidateLocal(msg *Message) error {
v.p.tracer.PublishMessage(msg)
err := v.p.checkSigningPolicy(msg)
@@ -238,7 +243,9 @@ }
}
vals := v.getValidators(msg)
return v.validate(vals, msg.ReceivedFrom, msg, true)
return v.validate(vals, msg.ReceivedFrom, msg, true, func(msg *Message) error {
return nil
})
}
// Push pushes a message into the validation pipeline.
@@ -282,15 +289,26 @@ func (v *validation) validateWorker() {
for {
select {
case req := <-v.validateQ:
v.validate(req.vals, req.src, req.msg, false)
_ = v.validate(req.vals, req.src, req.msg, false, v.sendMsgBlocking)
case <-v.p.ctx.Done():
return
}
}
}
// validate performs validation and only sends the message if all validators succeed
func (v *validation) validate(vals []*validatorImpl, src peer.ID, msg *Message, synchronous bool) error {
func (v *validation) sendMsgBlocking(msg *Message) error {
select {
case v.p.sendMsg <- msg:
return nil
case <-v.p.ctx.Done():
return v.p.ctx.Err()
}
}
// validate performs validation and only calls onValid if all validators succeed.
// If synchronous is true, onValid will be called before this function returns
// if the message is new and accepted.
func (v *validation) validate(vals []*validatorImpl, src peer.ID, msg *Message, synchronous bool, onValid func(*Message) error) error {
// If signature verification is enabled, but signing is disabled,
// the Signature is required to be nil upon receiving the message in PubSub.pushMsg.
if msg.Signature != nil {
@@ -306,7 +324,7 @@ func (v *validation) validate(vals []*validatorImpl, src peer.ID, msg *Message,
id := v.p.idGen.ID(msg)
if !v.p.markSeen(id) {
v.tracer.DuplicateMessage(msg)
return nil
return dupeErr{}
} else {
v.tracer.ValidateMessage(msg)
}
@@ -345,7 +363,7 @@ loop:
select {
case v.validateThrottle <- struct{}{}:
go func() {
v.doValidateTopic(async, src, msg, result)
v.doValidateTopic(async, src, msg, result, onValid)
<-v.validateThrottle
}()
default:
@@ -360,13 +378,8 @@ loop:
return ValidationError{Reason: RejectValidationIgnored}
}
// no async validators, accepted message, send it!
select {
case v.p.sendMsg <- msg:
return nil
case <-v.p.ctx.Done():
return v.p.ctx.Err()
}
// no async validators, accepted message
return onValid(msg)
}
func (v *validation) validateSignature(msg *Message) bool {
@@ -379,7 +392,7 @@ func (v *validation) validateSignature(msg *Message) bool {
return true
}
func (v *validation) doValidateTopic(vals []*validatorImpl, src peer.ID, msg *Message, r ValidationResult) {
func (v *validation) doValidateTopic(vals []*validatorImpl, src peer.ID, msg *Message, r ValidationResult, onValid func(*Message) error) {
result := v.validateTopic(vals, src, msg)
if result == ValidationAccept && r != ValidationAccept {
@@ -388,7 +401,7 @@ switch result {
switch result {
case ValidationAccept:
v.p.sendMsg <- msg
_ = onValid(msg)
case ValidationReject:
log.Debugf("message validation failed; dropping message from %s", src)
v.tracer.RejectMessage(msg, RejectValidationFailed)