mirror of
https://github.com/logos-messaging/go-libp2p-pubsub.git
synced 2026-01-02 04:43:10 +00:00
feat(gossipsub): Add MessageBatch (#607)
to support batch publishing messages Replaces #602. Batch publishing lets the system know there are multiple related messages to be published so it can prioritize sending different messages before sending copies of messages. For example, with the default API, when you publish two messages A and B, under the hood A gets sent to D=8 peers first, before B gets sent out. With this MessageBatch api we can now send one copy of A _and then_ one copy of B before sending multiple copies. When a node has bandwidth constraints relative to the messages it is publishing this improves dissemination time. For more context see this post: https://ethresear.ch/t/improving-das-performance-with-gossipsub-batch-publishing/21713
This commit is contained in:
parent
50ccc5ca90
commit
0c5ee7bbfe
149
gossipsub.go
149
gossipsub.go
@ -5,6 +5,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"time"
|
||||
@ -522,6 +523,8 @@ type GossipSubRouter struct {
|
||||
heartbeatTicks uint64
|
||||
}
|
||||
|
||||
var _ BatchPublisher = &GossipSubRouter{}
|
||||
|
||||
type connectInfo struct {
|
||||
p peer.ID
|
||||
spr *record.Envelope
|
||||
@ -1143,81 +1146,105 @@ func (gs *GossipSubRouter) connector() {
|
||||
}
|
||||
}
|
||||
|
||||
func (gs *GossipSubRouter) Publish(msg *Message) {
|
||||
gs.mcache.Put(msg)
|
||||
|
||||
from := msg.ReceivedFrom
|
||||
topic := msg.GetTopic()
|
||||
|
||||
tosend := make(map[peer.ID]struct{})
|
||||
|
||||
// any peers in the topic?
|
||||
tmap, ok := gs.p.topics[topic]
|
||||
if !ok {
|
||||
return
|
||||
func (gs *GossipSubRouter) PublishBatch(messages []*Message, opts *BatchPublishOptions) {
|
||||
strategy := opts.Strategy
|
||||
for _, msg := range messages {
|
||||
msgID := gs.p.idGen.ID(msg)
|
||||
for p, rpc := range gs.rpcs(msg) {
|
||||
strategy.AddRPC(p, msgID, rpc)
|
||||
}
|
||||
}
|
||||
|
||||
if gs.floodPublish && from == gs.p.host.ID() {
|
||||
for p := range tmap {
|
||||
_, direct := gs.direct[p]
|
||||
if direct || gs.score.Score(p) >= gs.publishThreshold {
|
||||
tosend[p] = struct{}{}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// direct peers
|
||||
for p := range gs.direct {
|
||||
_, inTopic := tmap[p]
|
||||
if inTopic {
|
||||
tosend[p] = struct{}{}
|
||||
}
|
||||
}
|
||||
for p, rpc := range strategy.All() {
|
||||
gs.sendRPC(p, rpc, false)
|
||||
}
|
||||
}
|
||||
|
||||
// floodsub peers
|
||||
for p := range tmap {
|
||||
if !gs.feature(GossipSubFeatureMesh, gs.peers[p]) && gs.score.Score(p) >= gs.publishThreshold {
|
||||
tosend[p] = struct{}{}
|
||||
}
|
||||
}
|
||||
func (gs *GossipSubRouter) Publish(msg *Message) {
|
||||
for p, rpc := range gs.rpcs(msg) {
|
||||
gs.sendRPC(p, rpc, false)
|
||||
}
|
||||
}
|
||||
|
||||
// gossipsub peers
|
||||
gmap, ok := gs.mesh[topic]
|
||||
func (gs *GossipSubRouter) rpcs(msg *Message) iter.Seq2[peer.ID, *RPC] {
|
||||
return func(yield func(peer.ID, *RPC) bool) {
|
||||
gs.mcache.Put(msg)
|
||||
|
||||
from := msg.ReceivedFrom
|
||||
topic := msg.GetTopic()
|
||||
|
||||
tosend := make(map[peer.ID]struct{})
|
||||
|
||||
// any peers in the topic?
|
||||
tmap, ok := gs.p.topics[topic]
|
||||
if !ok {
|
||||
// we are not in the mesh for topic, use fanout peers
|
||||
gmap, ok = gs.fanout[topic]
|
||||
if !ok || len(gmap) == 0 {
|
||||
// we don't have any, pick some with score above the publish threshold
|
||||
peers := gs.getPeers(topic, gs.params.D, func(p peer.ID) bool {
|
||||
_, direct := gs.direct[p]
|
||||
return !direct && gs.score.Score(p) >= gs.publishThreshold
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if len(peers) > 0 {
|
||||
gmap = peerListToMap(peers)
|
||||
gs.fanout[topic] = gmap
|
||||
if gs.floodPublish && from == gs.p.host.ID() {
|
||||
for p := range tmap {
|
||||
_, direct := gs.direct[p]
|
||||
if direct || gs.score.Score(p) >= gs.publishThreshold {
|
||||
tosend[p] = struct{}{}
|
||||
}
|
||||
}
|
||||
gs.lastpub[topic] = time.Now().UnixNano()
|
||||
} else {
|
||||
// direct peers
|
||||
for p := range gs.direct {
|
||||
_, inTopic := tmap[p]
|
||||
if inTopic {
|
||||
tosend[p] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// floodsub peers
|
||||
for p := range tmap {
|
||||
if !gs.feature(GossipSubFeatureMesh, gs.peers[p]) && gs.score.Score(p) >= gs.publishThreshold {
|
||||
tosend[p] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// gossipsub peers
|
||||
gmap, ok := gs.mesh[topic]
|
||||
if !ok {
|
||||
// we are not in the mesh for topic, use fanout peers
|
||||
gmap, ok = gs.fanout[topic]
|
||||
if !ok || len(gmap) == 0 {
|
||||
// we don't have any, pick some with score above the publish threshold
|
||||
peers := gs.getPeers(topic, gs.params.D, func(p peer.ID) bool {
|
||||
_, direct := gs.direct[p]
|
||||
return !direct && gs.score.Score(p) >= gs.publishThreshold
|
||||
})
|
||||
|
||||
if len(peers) > 0 {
|
||||
gmap = peerListToMap(peers)
|
||||
gs.fanout[topic] = gmap
|
||||
}
|
||||
}
|
||||
gs.lastpub[topic] = time.Now().UnixNano()
|
||||
}
|
||||
|
||||
csum := computeChecksum(gs.p.idGen.ID(msg))
|
||||
for p := range gmap {
|
||||
// Check if it has already received an IDONTWANT for the message.
|
||||
// If so, don't send it to the peer
|
||||
if _, ok := gs.unwanted[p][csum]; ok {
|
||||
continue
|
||||
}
|
||||
tosend[p] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
csum := computeChecksum(gs.p.idGen.ID(msg))
|
||||
for p := range gmap {
|
||||
// Check if it has already received an IDONTWANT for the message.
|
||||
// If so, don't send it to the peer
|
||||
if _, ok := gs.unwanted[p][csum]; ok {
|
||||
out := rpcWithMessages(msg.Message)
|
||||
for pid := range tosend {
|
||||
if pid == from || pid == peer.ID(msg.GetFrom()) {
|
||||
continue
|
||||
}
|
||||
tosend[p] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
out := rpcWithMessages(msg.Message)
|
||||
for pid := range tosend {
|
||||
if pid == from || pid == peer.ID(msg.GetFrom()) {
|
||||
continue
|
||||
if !yield(pid, out) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
gs.sendRPC(pid, out, false)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -9,9 +9,11 @@ import (
|
||||
"io"
|
||||
mrand "math/rand"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
"time"
|
||||
|
||||
pb "github.com/libp2p/go-libp2p-pubsub/pb"
|
||||
@ -3406,3 +3408,209 @@ func BenchmarkAllocDoDropRPC(b *testing.B) {
|
||||
gs.doDropRPC(&RPC{}, "peerID", "reason")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinMessageIDScheduler(t *testing.T) {
|
||||
const maxNumPeers = 256
|
||||
const maxNumMessages = 1_000
|
||||
|
||||
err := quick.Check(func(numPeers uint16, numMessages uint16) bool {
|
||||
numPeers = numPeers % maxNumPeers
|
||||
numMessages = numMessages % maxNumMessages
|
||||
|
||||
output := make([]pendingRPC, 0, numMessages*numPeers)
|
||||
|
||||
var strategy RoundRobinMessageIDScheduler
|
||||
|
||||
peers := make([]peer.ID, numPeers)
|
||||
for i := 0; i < int(numPeers); i++ {
|
||||
peers[i] = peer.ID(fmt.Sprintf("peer%d", i))
|
||||
}
|
||||
|
||||
getID := func(r pendingRPC) string {
|
||||
return string(r.rpc.Publish[0].Data)
|
||||
}
|
||||
|
||||
for i := range int(numMessages) {
|
||||
for j := range int(numPeers) {
|
||||
strategy.AddRPC(peers[j], fmt.Sprintf("msg%d", i), &RPC{
|
||||
RPC: pb.RPC{
|
||||
Publish: []*pb.Message{
|
||||
{
|
||||
Data: []byte(fmt.Sprintf("msg%d", i)),
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for p, rpc := range strategy.All() {
|
||||
output = append(output, pendingRPC{
|
||||
peer: p,
|
||||
rpc: rpc,
|
||||
})
|
||||
}
|
||||
|
||||
// Check invariants
|
||||
// 1. The published rpcs count is the same as the number of messages added
|
||||
// 2. Before all message IDs are seen, no message ID may be repeated
|
||||
// 3. The set of message ID + peer ID combinations should be the same as the input
|
||||
|
||||
// 1.
|
||||
expectedCount := int(numMessages) * int(numPeers)
|
||||
if len(output) != expectedCount {
|
||||
t.Logf("Expected %d RPCs, got %d", expectedCount, len(output))
|
||||
return false
|
||||
}
|
||||
|
||||
// 2.
|
||||
seen := make(map[string]bool)
|
||||
expected := make(map[string]bool)
|
||||
for i := 0; i < int(numMessages); i++ {
|
||||
expected[fmt.Sprintf("msg%d", i)] = true
|
||||
}
|
||||
|
||||
for _, rpc := range output {
|
||||
if expected[getID(rpc)] {
|
||||
delete(expected, getID(rpc))
|
||||
}
|
||||
if seen[getID(rpc)] && len(expected) > 0 {
|
||||
t.Logf("Message ID %s repeated before all message IDs are seen", getID(rpc))
|
||||
return false
|
||||
}
|
||||
seen[getID(rpc)] = true
|
||||
}
|
||||
|
||||
// 3.
|
||||
inputSet := make(map[string]bool)
|
||||
for i := range int(numMessages) {
|
||||
for j := range int(numPeers) {
|
||||
inputSet[fmt.Sprintf("msg%d:peer%d", i, j)] = true
|
||||
}
|
||||
}
|
||||
for _, rpc := range output {
|
||||
if !inputSet[getID(rpc)+":"+string(rpc.peer)] {
|
||||
t.Logf("Message ID %s not in input", getID(rpc))
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}, &quick.Config{MaxCount: 32})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRoundRobinMessageIDScheduler(b *testing.B) {
|
||||
const numPeers = 1_000
|
||||
const numMessages = 1_000
|
||||
var strategy RoundRobinMessageIDScheduler
|
||||
|
||||
peers := make([]peer.ID, numPeers)
|
||||
for i := range int(numPeers) {
|
||||
peers[i] = peer.ID(fmt.Sprintf("peer%d", i))
|
||||
}
|
||||
msgs := make([]string, numMessages)
|
||||
for i := range numMessages {
|
||||
msgs[i] = fmt.Sprintf("msg%d", i)
|
||||
}
|
||||
|
||||
emptyRPC := &RPC{}
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
j := i % len(peers)
|
||||
msgIdx := i % numMessages
|
||||
strategy.AddRPC(peers[j], msgs[msgIdx], emptyRPC)
|
||||
if i%100 == 0 {
|
||||
for range strategy.All() {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageBatchPublish(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
hosts := getDefaultHosts(t, 20)
|
||||
|
||||
msgIDFn := func(msg *pb.Message) string {
|
||||
hdr := string(msg.Data[0:16])
|
||||
msgID := strings.SplitN(hdr, " ", 2)
|
||||
return msgID[0]
|
||||
}
|
||||
const numMessages = 100
|
||||
// +8 to account for the gossiping overhead
|
||||
psubs := getGossipsubs(ctx, hosts, WithMessageIdFn(msgIDFn), WithPeerOutboundQueueSize(numMessages+8))
|
||||
|
||||
var topics []*Topic
|
||||
var msgs []*Subscription
|
||||
for _, ps := range psubs {
|
||||
topic, err := ps.Join("foobar")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
topics = append(topics, topic)
|
||||
|
||||
subch, err := topic.Subscribe(WithBufferSize(numMessages + 8))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
msgs = append(msgs, subch)
|
||||
}
|
||||
|
||||
sparseConnect(t, hosts)
|
||||
|
||||
// wait for heartbeats to build mesh
|
||||
time.Sleep(time.Second * 2)
|
||||
|
||||
var batch MessageBatch
|
||||
for i := 0; i < numMessages; i++ {
|
||||
msg := []byte(fmt.Sprintf("%d it's not a floooooood %d", i, i))
|
||||
err := topics[0].AddToBatch(ctx, &batch, msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
err := psubs[0].PublishBatch(&batch)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for range numMessages {
|
||||
for _, sub := range msgs {
|
||||
got, err := sub.Next(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(sub.err)
|
||||
}
|
||||
id := msgIDFn(got.Message)
|
||||
expected := []byte(fmt.Sprintf("%s it's not a floooooood %s", id, id))
|
||||
if !bytes.Equal(expected, got.Data) {
|
||||
t.Fatal("got wrong message!")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishDuplicateMessage(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
hosts := getDefaultHosts(t, 1)
|
||||
psubs := getGossipsubs(ctx, hosts, WithMessageIdFn(func(msg *pb.Message) string {
|
||||
return string(msg.Data)
|
||||
}))
|
||||
topic, err := psubs[0].Join("foobar")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = topic.Publish(ctx, []byte("hello"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = topic.Publish(ctx, []byte("hello"))
|
||||
if err != nil {
|
||||
t.Fatal("Duplicate message should not return an error")
|
||||
}
|
||||
}
|
||||
|
||||
62
messagebatch.go
Normal file
62
messagebatch.go
Normal file
@ -0,0 +1,62 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"iter"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
)
|
||||
|
||||
// MessageBatch allows a user to batch related messages and then publish them at
|
||||
// once. This allows the Scheduler to define an order for outgoing RPCs.
|
||||
// This helps bandwidth constrained peers.
|
||||
type MessageBatch struct {
|
||||
messages []*Message
|
||||
}
|
||||
|
||||
type messageBatchAndPublishOptions struct {
|
||||
messages []*Message
|
||||
opts *BatchPublishOptions
|
||||
}
|
||||
|
||||
// RPCScheduler schedules outgoing RPCs.
|
||||
type RPCScheduler interface {
|
||||
// AddRPC adds an RPC to the scheduler.
|
||||
AddRPC(peer peer.ID, msgID string, rpc *RPC)
|
||||
// All returns an ordered iterator of RPCs.
|
||||
All() iter.Seq2[peer.ID, *RPC]
|
||||
}
|
||||
|
||||
type pendingRPC struct {
|
||||
peer peer.ID
|
||||
rpc *RPC
|
||||
}
|
||||
|
||||
// RoundRobinMessageIDScheduler schedules outgoing RPCs in round-robin order of message IDs.
|
||||
type RoundRobinMessageIDScheduler struct {
|
||||
rpcs map[string][]pendingRPC
|
||||
}
|
||||
|
||||
func (s *RoundRobinMessageIDScheduler) AddRPC(peer peer.ID, msgID string, rpc *RPC) {
|
||||
if s.rpcs == nil {
|
||||
s.rpcs = make(map[string][]pendingRPC)
|
||||
}
|
||||
s.rpcs[msgID] = append(s.rpcs[msgID], pendingRPC{peer: peer, rpc: rpc})
|
||||
}
|
||||
|
||||
func (s *RoundRobinMessageIDScheduler) All() iter.Seq2[peer.ID, *RPC] {
|
||||
return func(yield func(peer.ID, *RPC) bool) {
|
||||
for len(s.rpcs) > 0 {
|
||||
for msgID, rpcs := range s.rpcs {
|
||||
if len(rpcs) == 0 {
|
||||
delete(s.rpcs, msgID)
|
||||
continue
|
||||
}
|
||||
if !yield(rpcs[0].peer, rpcs[0].rpc) {
|
||||
return
|
||||
}
|
||||
|
||||
s.rpcs[msgID] = rpcs[1:]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
53
pubsub.go
53
pubsub.go
@ -134,6 +134,9 @@ type PubSub struct {
|
||||
// sendMsg handles messages that have been validated
|
||||
sendMsg chan *Message
|
||||
|
||||
// sendMessageBatch publishes a batch of messages
|
||||
sendMessageBatch chan messageBatchAndPublishOptions
|
||||
|
||||
// addVal handles validator registration requests
|
||||
addVal chan *addValReq
|
||||
|
||||
@ -217,6 +220,10 @@ type PubSubRouter interface {
|
||||
Leave(topic string)
|
||||
}
|
||||
|
||||
type BatchPublisher interface {
|
||||
PublishBatch(messages []*Message, opts *BatchPublishOptions)
|
||||
}
|
||||
|
||||
type AcceptStatus int
|
||||
|
||||
const (
|
||||
@ -281,6 +288,7 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option
|
||||
rmTopic: make(chan *rmTopicReq),
|
||||
getTopics: make(chan *topicReq),
|
||||
sendMsg: make(chan *Message, 32),
|
||||
sendMessageBatch: make(chan messageBatchAndPublishOptions, 1),
|
||||
addVal: make(chan *addValReq),
|
||||
rmVal: make(chan *rmValReq),
|
||||
eval: make(chan func()),
|
||||
@ -642,6 +650,9 @@ func (p *PubSub) processLoop(ctx context.Context) {
|
||||
case msg := <-p.sendMsg:
|
||||
p.publishMessage(msg)
|
||||
|
||||
case batchAndOpts := <-p.sendMessageBatch:
|
||||
p.publishMessageBatch(batchAndOpts)
|
||||
|
||||
case req := <-p.addVal:
|
||||
p.val.AddValidator(req)
|
||||
|
||||
@ -1221,6 +1232,15 @@ func (p *PubSub) publishMessage(msg *Message) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PubSub) publishMessageBatch(batchAndOpts messageBatchAndPublishOptions) {
|
||||
for _, msg := range batchAndOpts.messages {
|
||||
p.tracer.DeliverMessage(msg)
|
||||
p.notifySubs(msg)
|
||||
}
|
||||
// We type checked when pushing the batch to the channel
|
||||
p.rt.(BatchPublisher).PublishBatch(batchAndOpts.messages, batchAndOpts.opts)
|
||||
}
|
||||
|
||||
type addTopicReq struct {
|
||||
topic *Topic
|
||||
resp chan *Topic
|
||||
@ -1358,6 +1378,39 @@ func (p *PubSub) Publish(topic string, data []byte, opts ...PubOpt) error {
|
||||
return t.Publish(context.TODO(), data, opts...)
|
||||
}
|
||||
|
||||
// PublishBatch publishes a batch of messages. This only works for routers that
|
||||
// implement the BatchPublisher interface.
|
||||
//
|
||||
// Users should make sure there is enough space in the Peer's outbound queue to
|
||||
// ensure messages are not dropped. WithPeerOutboundQueueSize should be set to
|
||||
// at least the expected number of batched messages per peer plus some slack to
|
||||
// account for gossip messages.
|
||||
//
|
||||
// The default publish strategy is RoundRobinMessageIDScheduler.
|
||||
func (p *PubSub) PublishBatch(batch *MessageBatch, opts ...BatchPubOpt) error {
|
||||
if _, ok := p.rt.(BatchPublisher); !ok {
|
||||
return fmt.Errorf("pubsub router is not a BatchPublisher")
|
||||
}
|
||||
|
||||
publishOptions := &BatchPublishOptions{}
|
||||
for _, o := range opts {
|
||||
err := o(publishOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
setDefaultBatchPublishOptions(publishOptions)
|
||||
|
||||
p.sendMessageBatch <- messageBatchAndPublishOptions{
|
||||
messages: batch.messages,
|
||||
opts: publishOptions,
|
||||
}
|
||||
|
||||
// Clear the batch's messages in case a user reuses the same batch object
|
||||
batch.messages = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PubSub) nextSeqno() []byte {
|
||||
seqno := make([]byte, 8)
|
||||
counter := atomic.AddUint64(&p.counter, 1)
|
||||
|
||||
62
topic.go
62
topic.go
@ -219,14 +219,53 @@ type PublishOptions struct {
|
||||
validatorData any
|
||||
}
|
||||
|
||||
type BatchPublishOptions struct {
|
||||
Strategy RPCScheduler
|
||||
}
|
||||
|
||||
type PubOpt func(pub *PublishOptions) error
|
||||
type BatchPubOpt func(pub *BatchPublishOptions) error
|
||||
|
||||
func setDefaultBatchPublishOptions(opts *BatchPublishOptions) {
|
||||
if opts.Strategy == nil {
|
||||
opts.Strategy = &RoundRobinMessageIDScheduler{}
|
||||
}
|
||||
}
|
||||
|
||||
// Publish publishes data to topic.
|
||||
func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error {
|
||||
msg, err := t.validate(ctx, data, opts...)
|
||||
if err != nil {
|
||||
if errors.Is(err, dupeErr{}) {
|
||||
// If it was a duplicate, we return nil to indicate success.
|
||||
// Semantically the message was published by us or someone else.
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return t.p.val.sendMsgBlocking(msg)
|
||||
}
|
||||
|
||||
func (t *Topic) AddToBatch(ctx context.Context, batch *MessageBatch, data []byte, opts ...PubOpt) error {
|
||||
msg, err := t.validate(ctx, data, opts...)
|
||||
if err != nil {
|
||||
if errors.Is(err, dupeErr{}) {
|
||||
// If it was a duplicate, we return nil to indicate success.
|
||||
// Semantically the message was published by us or someone else.
|
||||
// We won't add it to the batch. Since it's already been published.
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
batch.messages = append(batch.messages, msg)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Topic) validate(ctx context.Context, data []byte, opts ...PubOpt) (*Message, error) {
|
||||
t.mux.RLock()
|
||||
defer t.mux.RUnlock()
|
||||
if t.closed {
|
||||
return ErrTopicClosed
|
||||
return nil, ErrTopicClosed
|
||||
}
|
||||
|
||||
pid := t.p.signID
|
||||
@ -236,17 +275,17 @@ func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error
|
||||
for _, opt := range opts {
|
||||
err := opt(pub)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if pub.customKey != nil && !pub.local {
|
||||
key, pid = pub.customKey()
|
||||
if key == nil {
|
||||
return ErrNilSignKey
|
||||
return nil, ErrNilSignKey
|
||||
}
|
||||
if len(pid) == 0 {
|
||||
return ErrEmptyPeerID
|
||||
return nil, ErrEmptyPeerID
|
||||
}
|
||||
}
|
||||
|
||||
@ -264,7 +303,7 @@ func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error
|
||||
m.From = []byte(pid)
|
||||
err := signMessage(pid, key, m)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@ -291,9 +330,9 @@ func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error
|
||||
break readyLoop
|
||||
}
|
||||
case <-t.p.ctx.Done():
|
||||
return t.p.ctx.Err()
|
||||
return nil, t.p.ctx.Err()
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
if ticker == nil {
|
||||
ticker = time.NewTicker(200 * time.Millisecond)
|
||||
@ -303,13 +342,18 @@ func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("router is not ready: %w", ctx.Err())
|
||||
return nil, fmt.Errorf("router is not ready: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return t.p.val.PushLocal(&Message{m, "", t.p.host.ID(), pub.validatorData, pub.local})
|
||||
msg := &Message{m, "", t.p.host.ID(), pub.validatorData, pub.local}
|
||||
err := t.p.val.ValidateLocal(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// WithReadiness returns a publishing option for only publishing when the router is ready.
|
||||
|
||||
@ -26,6 +26,12 @@ func (e ValidationError) Error() string {
|
||||
return e.Reason
|
||||
}
|
||||
|
||||
type dupeErr struct{}
|
||||
|
||||
func (dupeErr) Error() string {
|
||||
return "duplicate message"
|
||||
}
|
||||
|
||||
// Validator is a function that validates a message with a binary decision: accept or reject.
|
||||
type Validator func(context.Context, peer.ID, *Message) bool
|
||||
|
||||
@ -226,10 +232,9 @@ func (v *validation) RemoveValidator(req *rmValReq) {
|
||||
}
|
||||
}
|
||||
|
||||
// PushLocal synchronously pushes a locally published message and performs applicable
|
||||
// validations.
|
||||
// Returns an error if validation fails
|
||||
func (v *validation) PushLocal(msg *Message) error {
|
||||
// ValidateLocal synchronously validates a locally published message and
|
||||
// performs applicable validations. Returns an error if validation fails.
|
||||
func (v *validation) ValidateLocal(msg *Message) error {
|
||||
v.p.tracer.PublishMessage(msg)
|
||||
|
||||
err := v.p.checkSigningPolicy(msg)
|
||||
@ -238,7 +243,9 @@ func (v *validation) PushLocal(msg *Message) error {
|
||||
}
|
||||
|
||||
vals := v.getValidators(msg)
|
||||
return v.validate(vals, msg.ReceivedFrom, msg, true)
|
||||
return v.validate(vals, msg.ReceivedFrom, msg, true, func(msg *Message) error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Push pushes a message into the validation pipeline.
|
||||
@ -282,15 +289,26 @@ func (v *validation) validateWorker() {
|
||||
for {
|
||||
select {
|
||||
case req := <-v.validateQ:
|
||||
v.validate(req.vals, req.src, req.msg, false)
|
||||
_ = v.validate(req.vals, req.src, req.msg, false, v.sendMsgBlocking)
|
||||
case <-v.p.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validate performs validation and only sends the message if all validators succeed
|
||||
func (v *validation) validate(vals []*validatorImpl, src peer.ID, msg *Message, synchronous bool) error {
|
||||
func (v *validation) sendMsgBlocking(msg *Message) error {
|
||||
select {
|
||||
case v.p.sendMsg <- msg:
|
||||
return nil
|
||||
case <-v.p.ctx.Done():
|
||||
return v.p.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// validate performs validation and only calls onValid if all validators succeed.
|
||||
// If synchronous is true, onValid will be called before this function returns
|
||||
// if the message is new and accepted.
|
||||
func (v *validation) validate(vals []*validatorImpl, src peer.ID, msg *Message, synchronous bool, onValid func(*Message) error) error {
|
||||
// If signature verification is enabled, but signing is disabled,
|
||||
// the Signature is required to be nil upon receiving the message in PubSub.pushMsg.
|
||||
if msg.Signature != nil {
|
||||
@ -306,7 +324,7 @@ func (v *validation) validate(vals []*validatorImpl, src peer.ID, msg *Message,
|
||||
id := v.p.idGen.ID(msg)
|
||||
if !v.p.markSeen(id) {
|
||||
v.tracer.DuplicateMessage(msg)
|
||||
return nil
|
||||
return dupeErr{}
|
||||
} else {
|
||||
v.tracer.ValidateMessage(msg)
|
||||
}
|
||||
@ -345,7 +363,7 @@ loop:
|
||||
select {
|
||||
case v.validateThrottle <- struct{}{}:
|
||||
go func() {
|
||||
v.doValidateTopic(async, src, msg, result)
|
||||
v.doValidateTopic(async, src, msg, result, onValid)
|
||||
<-v.validateThrottle
|
||||
}()
|
||||
default:
|
||||
@ -360,13 +378,8 @@ loop:
|
||||
return ValidationError{Reason: RejectValidationIgnored}
|
||||
}
|
||||
|
||||
// no async validators, accepted message, send it!
|
||||
select {
|
||||
case v.p.sendMsg <- msg:
|
||||
return nil
|
||||
case <-v.p.ctx.Done():
|
||||
return v.p.ctx.Err()
|
||||
}
|
||||
// no async validators, accepted message
|
||||
return onValid(msg)
|
||||
}
|
||||
|
||||
func (v *validation) validateSignature(msg *Message) bool {
|
||||
@ -379,7 +392,7 @@ func (v *validation) validateSignature(msg *Message) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (v *validation) doValidateTopic(vals []*validatorImpl, src peer.ID, msg *Message, r ValidationResult) {
|
||||
func (v *validation) doValidateTopic(vals []*validatorImpl, src peer.ID, msg *Message, r ValidationResult, onValid func(*Message) error) {
|
||||
result := v.validateTopic(vals, src, msg)
|
||||
|
||||
if result == ValidationAccept && r != ValidationAccept {
|
||||
@ -388,7 +401,7 @@ func (v *validation) doValidateTopic(vals []*validatorImpl, src peer.ID, msg *Me
|
||||
|
||||
switch result {
|
||||
case ValidationAccept:
|
||||
v.p.sendMsg <- msg
|
||||
_ = onValid(msg)
|
||||
case ValidationReject:
|
||||
log.Debugf("message validation failed; dropping message from %s", src)
|
||||
v.tracer.RejectMessage(msg, RejectValidationFailed)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user