diff --git a/gossipsub.go b/gossipsub.go
index ecd4eda..b5a605a 100644
--- a/gossipsub.go
+++ b/gossipsub.go
@@ -5,6 +5,7 @@ import (
 	"crypto/sha256"
 	"fmt"
 	"io"
+	"iter"
 	"math/rand"
 	"sort"
 	"time"
@@ -522,6 +523,8 @@ type GossipSubRouter struct {
 	heartbeatTicks uint64
 }
 
+var _ BatchPublisher = &GossipSubRouter{}
+
 type connectInfo struct {
 	p   peer.ID
 	spr *record.Envelope
@@ -1143,81 +1146,105 @@ func (gs *GossipSubRouter) connector() {
 	}
 }
 
-func (gs *GossipSubRouter) Publish(msg *Message) {
-	gs.mcache.Put(msg)
-
-	from := msg.ReceivedFrom
-	topic := msg.GetTopic()
-
-	tosend := make(map[peer.ID]struct{})
-
-	// any peers in the topic?
-	tmap, ok := gs.p.topics[topic]
-	if !ok {
-		return
-	}
-
-	if gs.floodPublish && from == gs.p.host.ID() {
-		for p := range tmap {
-			_, direct := gs.direct[p]
-			if direct || gs.score.Score(p) >= gs.publishThreshold {
-				tosend[p] = struct{}{}
-			}
-		}
-	} else {
-		// direct peers
-		for p := range gs.direct {
-			_, inTopic := tmap[p]
-			if inTopic {
-				tosend[p] = struct{}{}
-			}
-		}
-
-		// floodsub peers
-		for p := range tmap {
-			if !gs.feature(GossipSubFeatureMesh, gs.peers[p]) && gs.score.Score(p) >= gs.publishThreshold {
-				tosend[p] = struct{}{}
-			}
-		}
-
-		// gossipsub peers
-		gmap, ok := gs.mesh[topic]
-		if !ok {
-			// we are not in the mesh for topic, use fanout peers
-			gmap, ok = gs.fanout[topic]
-			if !ok || len(gmap) == 0 {
-				// we don't have any, pick some with score above the publish threshold
-				peers := gs.getPeers(topic, gs.params.D, func(p peer.ID) bool {
-					_, direct := gs.direct[p]
-					return !direct && gs.score.Score(p) >= gs.publishThreshold
-				})
-
-				if len(peers) > 0 {
-					gmap = peerListToMap(peers)
-					gs.fanout[topic] = gmap
-				}
-			}
-			gs.lastpub[topic] = time.Now().UnixNano()
-		}
-
-		csum := computeChecksum(gs.p.idGen.ID(msg))
-		for p := range gmap {
-			// Check if it has already received an IDONTWANT for the message.
-			// If so, don't send it to the peer
-			if _, ok := gs.unwanted[p][csum]; ok {
-				continue
-			}
-			tosend[p] = struct{}{}
-		}
-	}
-
-	out := rpcWithMessages(msg.Message)
-	for pid := range tosend {
-		if pid == from || pid == peer.ID(msg.GetFrom()) {
-			continue
-		}
-
-		gs.sendRPC(pid, out, false)
-	}
-}
+func (gs *GossipSubRouter) PublishBatch(messages []*Message, opts *BatchPublishOptions) {
+	strategy := opts.Strategy
+	for _, msg := range messages {
+		msgID := gs.p.idGen.ID(msg)
+		for p, rpc := range gs.rpcs(msg) {
+			strategy.AddRPC(p, msgID, rpc)
+		}
+	}
+
+	for p, rpc := range strategy.All() {
+		gs.sendRPC(p, rpc, false)
+	}
+}
+
+func (gs *GossipSubRouter) Publish(msg *Message) {
+	for p, rpc := range gs.rpcs(msg) {
+		gs.sendRPC(p, rpc, false)
+	}
+}
+
+func (gs *GossipSubRouter) rpcs(msg *Message) iter.Seq2[peer.ID, *RPC] {
+	return func(yield func(peer.ID, *RPC) bool) {
+		gs.mcache.Put(msg)
+
+		from := msg.ReceivedFrom
+		topic := msg.GetTopic()
+
+		tosend := make(map[peer.ID]struct{})
+
+		// any peers in the topic?
+		tmap, ok := gs.p.topics[topic]
+		if !ok {
+			return
+		}
+
+		if gs.floodPublish && from == gs.p.host.ID() {
+			for p := range tmap {
+				_, direct := gs.direct[p]
+				if direct || gs.score.Score(p) >= gs.publishThreshold {
+					tosend[p] = struct{}{}
+				}
+			}
+		} else {
+			// direct peers
+			for p := range gs.direct {
+				_, inTopic := tmap[p]
+				if inTopic {
+					tosend[p] = struct{}{}
+				}
+			}
+
+			// floodsub peers
+			for p := range tmap {
+				if !gs.feature(GossipSubFeatureMesh, gs.peers[p]) && gs.score.Score(p) >= gs.publishThreshold {
+					tosend[p] = struct{}{}
+				}
+			}
+
+			// gossipsub peers
+			gmap, ok := gs.mesh[topic]
+			if !ok {
+				// we are not in the mesh for topic, use fanout peers
+				gmap, ok = gs.fanout[topic]
+				if !ok || len(gmap) == 0 {
+					// we don't have any, pick some with score above the publish threshold
+					peers := gs.getPeers(topic, gs.params.D, func(p peer.ID) bool {
+						_, direct := gs.direct[p]
+						return !direct && gs.score.Score(p) >= gs.publishThreshold
+					})
+
+					if len(peers) > 0 {
+						gmap = peerListToMap(peers)
+						gs.fanout[topic] = gmap
+					}
+				}
+				gs.lastpub[topic] = time.Now().UnixNano()
+			}
+
+			csum := computeChecksum(gs.p.idGen.ID(msg))
+			for p := range gmap {
+				// Check if it has already received an IDONTWANT for the message.
+				// If so, don't send it to the peer
+				if _, ok := gs.unwanted[p][csum]; ok {
+					continue
+				}
+				tosend[p] = struct{}{}
+			}
+		}
+
+		out := rpcWithMessages(msg.Message)
+		for pid := range tosend {
+			if pid == from || pid == peer.ID(msg.GetFrom()) {
+				continue
+			}
+			if !yield(pid, out) {
+				return
+			}
+		}
+	}
+}
diff --git a/gossipsub_test.go b/gossipsub_test.go
index abb347f..72188be 100644
--- a/gossipsub_test.go
+++ b/gossipsub_test.go
@@ -9,9 +9,11 @@ import (
 	"io"
 	mrand "math/rand"
 	"sort"
+	"strings"
 	"sync"
 	"sync/atomic"
 	"testing"
+	"testing/quick"
 	"time"
 
 	pb "github.com/libp2p/go-libp2p-pubsub/pb"
@@ -3406,3 +3408,209 @@ func BenchmarkAllocDoDropRPC(b *testing.B) {
 		gs.doDropRPC(&RPC{}, "peerID", "reason")
 	}
 }
+
+func TestRoundRobinMessageIDScheduler(t *testing.T) {
+	const maxNumPeers = 256
+	const maxNumMessages = 1_000
+
+	err := quick.Check(func(numPeers uint16, numMessages uint16) bool {
+		numPeers = numPeers % maxNumPeers
+		numMessages = numMessages % maxNumMessages
+
+		output := make([]pendingRPC, 0, numMessages*numPeers)
+
+		var strategy RoundRobinMessageIDScheduler
+
+		peers := make([]peer.ID, numPeers)
+		for i := 0; i < int(numPeers); i++ {
+			peers[i] = peer.ID(fmt.Sprintf("peer%d", i))
+		}
+
+		getID := func(r pendingRPC) string {
+			return string(r.rpc.Publish[0].Data)
+		}
+
+		for i := range int(numMessages) {
+			for j := range int(numPeers) {
+				strategy.AddRPC(peers[j], fmt.Sprintf("msg%d", i), &RPC{
+					RPC: pb.RPC{
+						Publish: []*pb.Message{
+							{
+								Data: []byte(fmt.Sprintf("msg%d", i)),
+							},
+						},
+					},
+				})
+			}
+		}
+
+		for p, rpc := range strategy.All() {
+			output = append(output, pendingRPC{
+				peer: p,
+				rpc:  rpc,
+			})
+		}
+
+		// Check invariants
+		// 1. The published RPC count is the same as the number of RPCs added
+		// 2. Before all message IDs are seen, no message ID may be repeated
+		// 3. The set of message ID + peer ID combinations should be the same as the input
+
+		// 1.
+		expectedCount := int(numMessages) * int(numPeers)
+		if len(output) != expectedCount {
+			t.Logf("Expected %d RPCs, got %d", expectedCount, len(output))
+			return false
+		}
+
+		// 2.
+		seen := make(map[string]bool)
+		expected := make(map[string]bool)
+		for i := 0; i < int(numMessages); i++ {
+			expected[fmt.Sprintf("msg%d", i)] = true
+		}
+
+		for _, rpc := range output {
+			if expected[getID(rpc)] {
+				delete(expected, getID(rpc))
+			}
+			if seen[getID(rpc)] && len(expected) > 0 {
+				t.Logf("Message ID %s repeated before all message IDs are seen", getID(rpc))
+				return false
+			}
+			seen[getID(rpc)] = true
+		}
+
+		// 3.
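+		// Every (message ID, peer ID) pair emitted by the scheduler must have
+		// been added via AddRPC.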
+		inputSet := make(map[string]bool)
+		for i := range int(numMessages) {
+			for j := range int(numPeers) {
+				inputSet[fmt.Sprintf("msg%d:peer%d", i, j)] = true
+			}
+		}
+		for _, rpc := range output {
+			if !inputSet[getID(rpc)+":"+string(rpc.peer)] {
+				t.Logf("Message ID %s not in input", getID(rpc))
+				return false
+			}
+		}
+		return true
+	}, &quick.Config{MaxCount: 32})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+func BenchmarkRoundRobinMessageIDScheduler(b *testing.B) {
+	const numPeers = 1_000
+	const numMessages = 1_000
+	var strategy RoundRobinMessageIDScheduler
+
+	peers := make([]peer.ID, numPeers)
+	for i := range int(numPeers) {
+		peers[i] = peer.ID(fmt.Sprintf("peer%d", i))
+	}
+	msgs := make([]string, numMessages)
+	for i := range numMessages {
+		msgs[i] = fmt.Sprintf("msg%d", i)
+	}
+
+	emptyRPC := &RPC{}
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		j := i % len(peers)
+		msgIdx := i % numMessages
+		strategy.AddRPC(peers[j], msgs[msgIdx], emptyRPC)
+		if i%100 == 0 {
+			for range strategy.All() {
+			}
+		}
+	}
+}
+
+func TestMessageBatchPublish(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	hosts := getDefaultHosts(t, 20)
+
+	msgIDFn := func(msg *pb.Message) string {
+		hdr := string(msg.Data[0:16])
+		msgID := strings.SplitN(hdr, " ", 2)
+		return msgID[0]
+	}
+	const numMessages = 100
+	// +8 to account for the gossiping overhead
+	psubs := getGossipsubs(ctx, hosts, WithMessageIdFn(msgIDFn), WithPeerOutboundQueueSize(numMessages+8))
+
+	var topics []*Topic
+	var msgs []*Subscription
+	for _, ps := range psubs {
+		topic, err := ps.Join("foobar")
+		if err != nil {
+			t.Fatal(err)
+		}
+		topics = append(topics, topic)
+
+		subch, err := topic.Subscribe(WithBufferSize(numMessages + 8))
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		msgs = append(msgs, subch)
+	}
+
+	sparseConnect(t, hosts)
+
+	// wait for heartbeats to build mesh
+	time.Sleep(time.Second * 2)
+
+	var batch MessageBatch
+	for i := 0; i < numMessages; i++ {
+		msg := []byte(fmt.Sprintf("%d it's not a floooooood %d", i, i))
+		err := topics[0].AddToBatch(ctx, &batch, msg)
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+	err := psubs[0].PublishBatch(&batch)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	for range numMessages {
+		for _, sub := range msgs {
+			got, err := sub.Next(ctx)
+			if err != nil {
+				t.Fatal(err)
+			}
+			id := msgIDFn(got.Message)
+			expected := []byte(fmt.Sprintf("%s it's not a floooooood %s", id, id))
+			if !bytes.Equal(expected, got.Data) {
+				t.Fatal("got wrong message!")
+			}
+		}
+	}
+}
+
+func TestPublishDuplicateMessage(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	hosts := getDefaultHosts(t, 1)
+	psubs := getGossipsubs(ctx, hosts, WithMessageIdFn(func(msg *pb.Message) string {
+		return string(msg.Data)
+	}))
+	topic, err := psubs[0].Join("foobar")
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = topic.Publish(ctx, []byte("hello"))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = topic.Publish(ctx, []byte("hello"))
+	if err != nil {
+		t.Fatal("Duplicate message should not return an error")
+	}
+}
diff --git a/messagebatch.go b/messagebatch.go
new file mode 100644
index 0000000..8178645
--- /dev/null
+++ b/messagebatch.go
@@ -0,0 +1,62 @@
+package pubsub
+
+import (
+	"iter"
+
+	"github.com/libp2p/go-libp2p/core/peer"
+)
+
+// MessageBatch allows a user to batch related messages and then publish them
+// at once. This lets the RPCScheduler define an order for the outgoing RPCs,
+// which helps bandwidth-constrained peers.
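+//
+// Messages are added to a batch with Topic.AddToBatch and handed to the
+// router in one call with PubSub.PublishBatch.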
+type MessageBatch struct {
+	messages []*Message
+}
+
+type messageBatchAndPublishOptions struct {
+	messages []*Message
+	opts     *BatchPublishOptions
+}
+
+// RPCScheduler schedules outgoing RPCs.
+type RPCScheduler interface {
+	// AddRPC adds an RPC to the scheduler.
+	AddRPC(peer peer.ID, msgID string, rpc *RPC)
+	// All returns an ordered iterator of RPCs.
+	All() iter.Seq2[peer.ID, *RPC]
+}
+
+type pendingRPC struct {
+	peer peer.ID
+	rpc  *RPC
+}
+
+// RoundRobinMessageIDScheduler schedules outgoing RPCs in round-robin order of message IDs.
+type RoundRobinMessageIDScheduler struct {
+	rpcs map[string][]pendingRPC
+}
+
+func (s *RoundRobinMessageIDScheduler) AddRPC(peer peer.ID, msgID string, rpc *RPC) {
+	if s.rpcs == nil {
+		s.rpcs = make(map[string][]pendingRPC)
+	}
+	s.rpcs[msgID] = append(s.rpcs[msgID], pendingRPC{peer: peer, rpc: rpc})
+}
+
+func (s *RoundRobinMessageIDScheduler) All() iter.Seq2[peer.ID, *RPC] {
+	return func(yield func(peer.ID, *RPC) bool) {
+		for len(s.rpcs) > 0 {
+			for msgID, rpcs := range s.rpcs {
+				if len(rpcs) == 0 {
+					delete(s.rpcs, msgID)
+					continue
+				}
+				if !yield(rpcs[0].peer, rpcs[0].rpc) {
+					return
+				}
+
+				s.rpcs[msgID] = rpcs[1:]
+			}
+		}
+	}
+}
diff --git a/pubsub.go b/pubsub.go
index 5c27c3e..fae115a 100644
--- a/pubsub.go
+++ b/pubsub.go
@@ -134,6 +134,9 @@ type PubSub struct {
 	// sendMsg handles messages that have been validated
 	sendMsg chan *Message
 
+	// sendMessageBatch handles batches of messages to be published
+	sendMessageBatch chan messageBatchAndPublishOptions
+
 	// addVal handles validator registration requests
 	addVal chan *addValReq
 
@@ -217,6 +220,10 @@ type PubSubRouter interface {
 	Leave(topic string)
 }
 
+type BatchPublisher interface {
+	PublishBatch(messages []*Message, opts *BatchPublishOptions)
+}
+
 type AcceptStatus int
 
 const (
@@ -281,6 +288,7 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option
 		rmTopic:          make(chan *rmTopicReq),
 		getTopics:        make(chan *topicReq),
 		sendMsg:          make(chan *Message, 32),
+		sendMessageBatch: make(chan messageBatchAndPublishOptions, 1),
 		addVal:           make(chan *addValReq),
 		rmVal:            make(chan *rmValReq),
 		eval:             make(chan func()),
@@ -642,6 +650,9 @@ func (p *PubSub) processLoop(ctx context.Context) {
 		case msg := <-p.sendMsg:
 			p.publishMessage(msg)
 
+		case batchAndOpts := <-p.sendMessageBatch:
+			p.publishMessageBatch(batchAndOpts)
+
 		case req := <-p.addVal:
 			p.val.AddValidator(req)
 
@@ -1221,6 +1232,15 @@ func (p *PubSub) publishMessage(msg *Message) {
 	}
 }
 
+func (p *PubSub) publishMessageBatch(batchAndOpts messageBatchAndPublishOptions) {
+	for _, msg := range batchAndOpts.messages {
+		p.tracer.DeliverMessage(msg)
+		p.notifySubs(msg)
+	}
+	// We type checked when pushing the batch to the channel
+	p.rt.(BatchPublisher).PublishBatch(batchAndOpts.messages, batchAndOpts.opts)
+}
+
 type addTopicReq struct {
 	topic *Topic
 	resp  chan *Topic
@@ -1358,6 +1378,39 @@ func (p *PubSub) Publish(topic string, data []byte, opts ...PubOpt) error {
 	return t.Publish(context.TODO(), data, opts...)
 }
 
+// PublishBatch publishes a batch of messages. This only works for routers that
+// implement the BatchPublisher interface.
+//
+// Users should make sure there is enough space in the peer's outbound queue so
+// that messages are not dropped. WithPeerOutboundQueueSize should be set to at
+// least the expected number of batched messages per peer, plus some slack to
+// account for gossip messages.
+//
+// The default publish strategy is RoundRobinMessageIDScheduler.
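+//
+// A minimal usage sketch (topic, ps, ctx and payloads are assumed to be set up
+// by the caller):
+//
+//	var batch MessageBatch
+//	for _, data := range payloads {
+//		if err := topic.AddToBatch(ctx, &batch, data); err != nil {
+//			return err
+//		}
+//	}
+//	if err := ps.PublishBatch(&batch); err != nil {
+//		return err
+//	}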
+func (p *PubSub) PublishBatch(batch *MessageBatch, opts ...BatchPubOpt) error {
+	if _, ok := p.rt.(BatchPublisher); !ok {
+		return fmt.Errorf("pubsub router is not a BatchPublisher")
+	}
+
+	publishOptions := &BatchPublishOptions{}
+	for _, o := range opts {
+		err := o(publishOptions)
+		if err != nil {
+			return err
+		}
+	}
+	setDefaultBatchPublishOptions(publishOptions)
+
+	p.sendMessageBatch <- messageBatchAndPublishOptions{
+		messages: batch.messages,
+		opts:     publishOptions,
+	}
+
+	// Clear the batch's messages in case a user reuses the same batch object
+	batch.messages = nil
+	return nil
+}
+
 func (p *PubSub) nextSeqno() []byte {
 	seqno := make([]byte, 8)
 	counter := atomic.AddUint64(&p.counter, 1)
diff --git a/topic.go b/topic.go
index f9b7ccc..a6ad979 100644
--- a/topic.go
+++ b/topic.go
@@ -219,14 +219,53 @@ type PublishOptions struct {
 	validatorData any
 }
 
+type BatchPublishOptions struct {
+	Strategy RPCScheduler
+}
+
 type PubOpt func(pub *PublishOptions) error
 
+type BatchPubOpt func(pub *BatchPublishOptions) error
+
+func setDefaultBatchPublishOptions(opts *BatchPublishOptions) {
+	if opts.Strategy == nil {
+		opts.Strategy = &RoundRobinMessageIDScheduler{}
+	}
+}
+
 // Publish publishes data to topic.
 func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error {
+	msg, err := t.validate(ctx, data, opts...)
+	if err != nil {
+		if errors.Is(err, dupeErr{}) {
+			// If it was a duplicate, we return nil to indicate success.
+			// Semantically the message was published by us or someone else.
+			return nil
+		}
+		return err
+	}
+	return t.p.val.sendMsgBlocking(msg)
+}
+
+func (t *Topic) AddToBatch(ctx context.Context, batch *MessageBatch, data []byte, opts ...PubOpt) error {
+	msg, err := t.validate(ctx, data, opts...)
+	if err != nil {
+		if errors.Is(err, dupeErr{}) {
+			// If it was a duplicate, we return nil to indicate success.
+			// Semantically the message was published by us or someone else.
+			// We don't add it to the batch, since it has already been published.
+			return nil
+		}
+		return err
+	}
+	batch.messages = append(batch.messages, msg)
+	return nil
+}
+
+func (t *Topic) validate(ctx context.Context, data []byte, opts ...PubOpt) (*Message, error) {
 	t.mux.RLock()
 	defer t.mux.RUnlock()
 	if t.closed {
-		return ErrTopicClosed
+		return nil, ErrTopicClosed
 	}
 
 	pid := t.p.signID
@@ -236,17 +275,17 @@ func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error
 	for _, opt := range opts {
 		err := opt(pub)
 		if err != nil {
-			return err
+			return nil, err
 		}
 	}
 
 	if pub.customKey != nil && !pub.local {
 		key, pid = pub.customKey()
 		if key == nil {
-			return ErrNilSignKey
+			return nil, ErrNilSignKey
 		}
 		if len(pid) == 0 {
-			return ErrEmptyPeerID
+			return nil, ErrEmptyPeerID
 		}
 	}
 
@@ -264,7 +303,7 @@ func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error
 		m.From = []byte(pid)
 		err := signMessage(pid, key, m)
 		if err != nil {
-			return err
+			return nil, err
 		}
 	}
 
@@ -291,9 +330,9 @@ func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error
 					break readyLoop
 				}
 			case <-t.p.ctx.Done():
-				return t.p.ctx.Err()
+				return nil, t.p.ctx.Err()
 			case <-ctx.Done():
-				return ctx.Err()
+				return nil, ctx.Err()
 			}
 			if ticker == nil {
 				ticker = time.NewTicker(200 * time.Millisecond)
@@ -303,13 +342,18 @@ func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error
 			select {
 			case <-ticker.C:
 			case <-ctx.Done():
-				return fmt.Errorf("router is not ready: %w", ctx.Err())
+				return nil, fmt.Errorf("router is not ready: %w", ctx.Err())
 			}
 		}
 	}
 
-	return t.p.val.PushLocal(&Message{m, "", t.p.host.ID(), pub.validatorData, pub.local})
+	msg := &Message{m, "", t.p.host.ID(), pub.validatorData, pub.local}
+	err := t.p.val.ValidateLocal(msg)
+	if err != nil {
+		return nil, err
+	}
+	return msg, nil
 }
 
 // WithReadiness returns a publishing option for only publishing when the router is ready.
diff --git a/validation.go b/validation.go
index 1044d5d..6433a41 100644
--- a/validation.go
+++ b/validation.go
@@ -26,6 +26,12 @@ func (e ValidationError) Error() string {
 	return e.Reason
 }
 
+type dupeErr struct{}
+
+func (dupeErr) Error() string {
+	return "duplicate message"
+}
+
 // Validator is a function that validates a message with a binary decision: accept or reject.
 type Validator func(context.Context, peer.ID, *Message) bool
 
@@ -226,10 +232,9 @@ func (v *validation) RemoveValidator(req *rmValReq) {
 	}
 }
 
-// PushLocal synchronously pushes a locally published message and performs applicable
-// validations.
-// Returns an error if validation fails
-func (v *validation) PushLocal(msg *Message) error {
+// ValidateLocal synchronously runs all applicable validations for a locally
+// published message. It returns an error if validation fails.
+func (v *validation) ValidateLocal(msg *Message) error {
 	v.p.tracer.PublishMessage(msg)
 
 	err := v.p.checkSigningPolicy(msg)
@@ -238,7 +243,9 @@
 	}
 
 	vals := v.getValidators(msg)
-	return v.validate(vals, msg.ReceivedFrom, msg, true)
+	return v.validate(vals, msg.ReceivedFrom, msg, true, func(msg *Message) error {
+		return nil
+	})
 }
 
 // Push pushes a message into the validation pipeline.
@@ -282,15 +289,26 @@ func (v *validation) validateWorker() {
 	for {
 		select {
 		case req := <-v.validateQ:
-			v.validate(req.vals, req.src, req.msg, false)
+			_ = v.validate(req.vals, req.src, req.msg, false, v.sendMsgBlocking)
 		case <-v.p.ctx.Done():
 			return
 		}
 	}
 }
 
-// validate performs validation and only sends the message if all validators succeed
-func (v *validation) validate(vals []*validatorImpl, src peer.ID, msg *Message, synchronous bool) error {
+func (v *validation) sendMsgBlocking(msg *Message) error {
+	select {
+	case v.p.sendMsg <- msg:
+		return nil
+	case <-v.p.ctx.Done():
+		return v.p.ctx.Err()
+	}
+}
+
+// validate performs validation and only calls onValid if all validators succeed.
+// If synchronous is true, onValid will be called before this function returns
+// if the message is new and accepted.
+func (v *validation) validate(vals []*validatorImpl, src peer.ID, msg *Message, synchronous bool, onValid func(*Message) error) error {
 	// If signature verification is enabled, but signing is disabled,
 	// the Signature is required to be nil upon receiving the message in PubSub.pushMsg.
 	if msg.Signature != nil {
@@ -306,7 +324,7 @@ func (v *validation) validate(vals []*validatorImpl, src peer.ID, msg *Message,
 	id := v.p.idGen.ID(msg)
 	if !v.p.markSeen(id) {
 		v.tracer.DuplicateMessage(msg)
-		return nil
+		return dupeErr{}
 	} else {
 		v.tracer.ValidateMessage(msg)
 	}
@@ -345,7 +363,7 @@ loop:
 		select {
 		case v.validateThrottle <- struct{}{}:
 			go func() {
-				v.doValidateTopic(async, src, msg, result)
+				v.doValidateTopic(async, src, msg, result, onValid)
 				<-v.validateThrottle
 			}()
 		default:
@@ -360,13 +378,8 @@ loop:
 		return ValidationError{Reason: RejectValidationIgnored}
 	}
 
-	// no async validators, accepted message, send it!
-	select {
-	case v.p.sendMsg <- msg:
-		return nil
-	case <-v.p.ctx.Done():
-		return v.p.ctx.Err()
-	}
+	// no async validators, accepted message
+	return onValid(msg)
 }
 
 func (v *validation) validateSignature(msg *Message) bool {
@@ -379,7 +392,7 @@ func (v *validation) validateSignature(msg *Message) bool {
 	return true
 }
 
-func (v *validation) doValidateTopic(vals []*validatorImpl, src peer.ID, msg *Message, r ValidationResult) {
+func (v *validation) doValidateTopic(vals []*validatorImpl, src peer.ID, msg *Message, r ValidationResult, onValid func(*Message) error) {
 	result := v.validateTopic(vals, src, msg)
 
 	if result == ValidationAccept && r != ValidationAccept {
@@ -388,7 +401,7 @@ func (v *validation) doValidateTopic(vals []*validatorImpl, src peer.ID, msg *Me
 	switch result {
 	case ValidationAccept:
-		v.p.sendMsg <- msg
+		_ = onValid(msg)
 	case ValidationReject:
 		log.Debugf("message validation failed; dropping message from %s", src)
 		v.tracer.RejectMessage(msg, RejectValidationFailed)
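
The scheduler is pluggable through the new BatchPubOpt hook: BatchPublishOptions exposes the Strategy field and PubSub.PublishBatch accepts option functions. The helper below is an illustrative sketch only; WithBatchStrategy is a hypothetical name, not part of this change, and it is assumed to live in the pubsub package alongside the options above.

	// WithBatchStrategy selects a custom RPCScheduler for a single PublishBatch call.
	// Hypothetical helper, shown only to illustrate the BatchPubOpt extension point.
	func WithBatchStrategy(s RPCScheduler) BatchPubOpt {
		return func(o *BatchPublishOptions) error {
			o.Strategy = s
			return nil
		}
	}

	// err := ps.PublishBatch(&batch, WithBatchStrategy(&RoundRobinMessageIDScheduler{}))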