diff --git a/gossip_tracer.go b/gossip_tracer.go index 04e7719..a40b707 100644 --- a/gossip_tracer.go +++ b/gossip_tracer.go @@ -15,7 +15,7 @@ import ( type gossipTracer struct { sync.Mutex - msgID MsgIdFunction + idGen *msgIDGenerator followUpTime time.Duration @@ -29,7 +29,7 @@ type gossipTracer struct { func newGossipTracer() *gossipTracer { return &gossipTracer{ - msgID: DefaultMsgIdFn, + idGen: newMsgIdGenerator(), promises: make(map[string]map[peer.ID]time.Time), peerPromises: make(map[peer.ID]map[string]struct{}), } @@ -40,7 +40,7 @@ func (gt *gossipTracer) Start(gs *GossipSubRouter) { return } - gt.msgID = gs.p.msgID + gt.idGen = gs.p.idGen gt.followUpTime = gs.params.IWantFollowupTime } @@ -117,7 +117,7 @@ func (gt *gossipTracer) GetBrokenPromises() map[peer.ID]int { var _ RawTracer = (*gossipTracer)(nil) func (gt *gossipTracer) fulfillPromise(msg *Message) { - mid := gt.msgID(msg.Message) + mid := gt.idGen.ID(msg) gt.Lock() defer gt.Unlock() diff --git a/gossipsub.go b/gossipsub.go index 7c2da30..0aa8a00 100644 --- a/gossipsub.go +++ b/gossipsub.go @@ -295,7 +295,7 @@ func WithPeerScore(params *PeerScoreParams, thresholds *PeerScoreThresholds) Opt ps.tracer = &pubsubTracer{ raw: []RawTracer{gs.score, gs.gossipTracer}, pid: ps.host.ID(), - msgID: ps.msgID, + idGen: ps.idGen, } } @@ -484,7 +484,7 @@ func (gs *GossipSubRouter) Attach(p *PubSub) { gs.tagTracer.Start(gs) // start using the same msg ID function as PubSub for caching messages. - gs.mcache.SetMsgIdFn(p.msgID) + gs.mcache.SetMsgIdFn(p.idGen.ID) // start the heartbeat go gs.heartbeatTimer() @@ -705,7 +705,7 @@ func (gs *GossipSubRouter) handleIWant(p peer.ID, ctl *pb.ControlMessage) []*pb. continue } - ihave[mid] = msg + ihave[mid] = msg.Message } } @@ -954,7 +954,7 @@ func (gs *GossipSubRouter) connector() { } func (gs *GossipSubRouter) Publish(msg *Message) { - gs.mcache.Put(msg.Message) + gs.mcache.Put(msg) from := msg.ReceivedFrom topic := msg.GetTopic() diff --git a/mcache.go b/mcache.go index e1f02ab..889948e 100644 --- a/mcache.go +++ b/mcache.go @@ -3,8 +3,6 @@ package pubsub import ( "fmt" - pb "github.com/libp2p/go-libp2p-pubsub/pb" - "github.com/libp2p/go-libp2p-core/peer" ) @@ -27,23 +25,25 @@ func NewMessageCache(gossip, history int) *MessageCache { panic(err) } return &MessageCache{ - msgs: make(map[string]*pb.Message), + msgs: make(map[string]*Message), peertx: make(map[string]map[peer.ID]int), history: make([][]CacheEntry, history), gossip: gossip, - msgID: DefaultMsgIdFn, + msgID: func(msg *Message) string { + return DefaultMsgIdFn(msg.Message) + }, } } type MessageCache struct { - msgs map[string]*pb.Message + msgs map[string]*Message peertx map[string]map[peer.ID]int history [][]CacheEntry gossip int - msgID MsgIdFunction + msgID func(*Message) string } -func (mc *MessageCache) SetMsgIdFn(msgID MsgIdFunction) { +func (mc *MessageCache) SetMsgIdFn(msgID func(*Message) string) { mc.msgID = msgID } @@ -52,18 +52,18 @@ type CacheEntry struct { topic string } -func (mc *MessageCache) Put(msg *pb.Message) { +func (mc *MessageCache) Put(msg *Message) { mid := mc.msgID(msg) mc.msgs[mid] = msg mc.history[0] = append(mc.history[0], CacheEntry{mid: mid, topic: msg.GetTopic()}) } -func (mc *MessageCache) Get(mid string) (*pb.Message, bool) { +func (mc *MessageCache) Get(mid string) (*Message, bool) { m, ok := mc.msgs[mid] return m, ok } -func (mc *MessageCache) GetForPeer(mid string, p peer.ID) (*pb.Message, int, bool) { +func (mc *MessageCache) GetForPeer(mid string, p peer.ID) (*Message, int, bool) { m, ok := mc.msgs[mid] if !ok { return nil, 0, false diff --git a/mcache_test.go b/mcache_test.go index e36c6b1..93bcfdc 100644 --- a/mcache_test.go +++ b/mcache_test.go @@ -18,7 +18,7 @@ func TestMessageCache(t *testing.T) { } for i := 0; i < 10; i++ { - mcache.Put(msgs[i]) + mcache.Put(&Message{Message: msgs[i]}) } for i := 0; i < 10; i++ { @@ -28,7 +28,7 @@ func TestMessageCache(t *testing.T) { t.Fatalf("Message %d not in cache", i) } - if m != msgs[i] { + if m.Message != msgs[i] { t.Fatalf("Message %d does not match cache", i) } } @@ -47,7 +47,7 @@ func TestMessageCache(t *testing.T) { mcache.Shift() for i := 10; i < 20; i++ { - mcache.Put(msgs[i]) + mcache.Put(&Message{Message: msgs[i]}) } for i := 0; i < 20; i++ { @@ -57,7 +57,7 @@ func TestMessageCache(t *testing.T) { t.Fatalf("Message %d not in cache", i) } - if m != msgs[i] { + if m.Message != msgs[i] { t.Fatalf("Message %d does not match cache", i) } } @@ -83,22 +83,22 @@ func TestMessageCache(t *testing.T) { mcache.Shift() for i := 20; i < 30; i++ { - mcache.Put(msgs[i]) + mcache.Put(&Message{Message: msgs[i]}) } mcache.Shift() for i := 30; i < 40; i++ { - mcache.Put(msgs[i]) + mcache.Put(&Message{Message: msgs[i]}) } mcache.Shift() for i := 40; i < 50; i++ { - mcache.Put(msgs[i]) + mcache.Put(&Message{Message: msgs[i]}) } mcache.Shift() for i := 50; i < 60; i++ { - mcache.Put(msgs[i]) + mcache.Put(&Message{Message: msgs[i]}) } if len(mcache.msgs) != 50 { @@ -120,7 +120,7 @@ func TestMessageCache(t *testing.T) { t.Fatalf("Message %d not in cache", i) } - if m != msgs[i] { + if m.Message != msgs[i] { t.Fatalf("Message %d does not match cache", i) } } diff --git a/midgen.go b/midgen.go index 0329329..d09c87d 100644 --- a/midgen.go +++ b/midgen.go @@ -1,21 +1,34 @@ package pubsub -import "sync" +import ( + "sync" +) +// msgIDGenerator handles computing IDs for msgs +// It allows setting custom generators(MsgIdFunction) per topic type msgIDGenerator struct { - defGen MsgIdFunction + Default MsgIdFunction topicGens map[string]MsgIdFunction topicGensLk sync.RWMutex } -func (m *msgIDGenerator) Add(topic string, gen MsgIdFunction) { +func newMsgIdGenerator() *msgIDGenerator{ + return &msgIDGenerator{ + Default: DefaultMsgIdFn, + topicGens: make(map[string]MsgIdFunction), + } +} + +// Set sets custom id generator(MsgIdFunction) for topic. +func (m *msgIDGenerator) Set(topic string, gen MsgIdFunction) { m.topicGensLk.Lock() m.topicGens[topic] = gen m.topicGensLk.Unlock() } -func (m *msgIDGenerator) GenID(msg *Message) string { +// ID computes ID for the msg or short-circuits with the cached value. +func (m *msgIDGenerator) ID(msg *Message) string { if msg.ID != "" { return msg.ID } @@ -24,7 +37,7 @@ func (m *msgIDGenerator) GenID(msg *Message) string { gen, ok := m.topicGens[msg.GetTopic()] m.topicGensLk.RUnlock() if !ok { - gen = m.defGen + gen = m.Default } msg.ID = gen(msg.Message) diff --git a/peer_gater.go b/peer_gater.go index e334324..3da2755 100644 --- a/peer_gater.go +++ b/peer_gater.go @@ -182,7 +182,7 @@ func WithPeerGater(params *PeerGaterParams) Option { ps.tracer = &pubsubTracer{ raw: []RawTracer{gs.gate}, pid: ps.host.ID(), - msgID: ps.msgID, + idGen: ps.idGen, } } diff --git a/pubsub.go b/pubsub.go index cba16c5..7caf5ad 100644 --- a/pubsub.go +++ b/pubsub.go @@ -20,7 +20,7 @@ import ( "github.com/libp2p/go-libp2p-core/protocol" logging "github.com/ipfs/go-log" - timecache "github.com/whyrusleeping/timecache" + "github.com/whyrusleeping/timecache" ) // DefaultMaximumMessageSize is 1mb. @@ -147,8 +147,8 @@ type PubSub struct { seenMessagesMx sync.Mutex seenMessages *timecache.TimeCache - // function used to compute the ID for a message - msgID MsgIdFunction + // generator used to compute the ID for a message + idGen *msgIDGenerator // key for signing messages; nil when signing is disabled signKey crypto.PrivKey @@ -273,7 +273,7 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option blacklist: NewMapBlacklist(), blacklistPeer: make(chan peer.ID), seenMessages: timecache.NewTimeCache(TimeCacheDuration), - msgID: DefaultMsgIdFn, + idGen: newMsgIdGenerator(), counter: uint64(time.Now().UnixNano()), } @@ -327,11 +327,7 @@ type MsgIdFunction func(pmsg *pb.Message) string // but it can be customized to e.g. the hash of the message. func WithMessageIdFn(fn MsgIdFunction) Option { return func(p *PubSub) error { - p.msgID = fn - // the tracer Option may already be set. Update its message ID function to make options order-independent. - if p.tracer != nil { - p.tracer.msgID = fn - } + p.idGen.Default = fn return nil } } @@ -456,7 +452,7 @@ func WithEventTracer(tracer EventTracer) Option { if p.tracer != nil { p.tracer.tracer = tracer } else { - p.tracer = &pubsubTracer{tracer: tracer, pid: p.host.ID(), msgID: p.msgID} + p.tracer = &pubsubTracer{tracer: tracer, pid: p.host.ID(), idGen: p.idGen} } return nil } @@ -469,7 +465,7 @@ func WithRawTracer(tracer RawTracer) Option { if p.tracer != nil { p.tracer.raw = append(p.tracer.raw, tracer) } else { - p.tracer = &pubsubTracer{raw: []RawTracer{tracer}, pid: p.host.ID(), msgID: p.msgID} + p.tracer = &pubsubTracer{raw: []RawTracer{tracer}, pid: p.host.ID(), idGen: p.idGen} } return nil } @@ -1097,7 +1093,7 @@ func (p *PubSub) pushMsg(msg *Message) { } // have we already seen and validated this message? - id := p.msgID(msg.Message) + id := p.idGen.ID(msg) if p.seenMessage(id) { p.tracer.DuplicateMessage(msg) return diff --git a/score.go b/score.go index 1ee6141..87753a1 100644 --- a/score.go +++ b/score.go @@ -76,7 +76,7 @@ type peerScore struct { // message delivery tracking deliveries *messageDeliveries - msgID MsgIdFunction + idGen *msgIDGenerator host host.Host // debugging inspection @@ -183,7 +183,7 @@ func newPeerScore(params *PeerScoreParams) *peerScore { peerStats: make(map[peer.ID]*peerStats), peerIPs: make(map[string]map[peer.ID]struct{}), deliveries: &messageDeliveries{records: make(map[string]*deliveryRecord)}, - msgID: DefaultMsgIdFn, + idGen: newMsgIdGenerator(), } } @@ -239,7 +239,7 @@ func (ps *peerScore) Start(gs *GossipSubRouter) { return } - ps.msgID = gs.p.msgID + ps.idGen = gs.p.idGen ps.host = gs.p.host go ps.background(gs.p.ctx) } @@ -689,7 +689,7 @@ func (ps *peerScore) ValidateMessage(msg *Message) { // the pubsub subsystem is beginning validation; create a record to track time in // the validation pipeline with an accurate firstSeen time. - _ = ps.deliveries.getRecord(ps.msgID(msg.Message)) + _ = ps.deliveries.getRecord(ps.idGen.ID(msg)) } func (ps *peerScore) DeliverMessage(msg *Message) { @@ -698,7 +698,7 @@ func (ps *peerScore) DeliverMessage(msg *Message) { ps.markFirstMessageDelivery(msg.ReceivedFrom, msg) - drec := ps.deliveries.getRecord(ps.msgID(msg.Message)) + drec := ps.deliveries.getRecord(ps.idGen.ID(msg)) // defensive check that this is the first delivery trace -- delivery status should be unknown if drec.status != deliveryUnknown { @@ -749,7 +749,7 @@ func (ps *peerScore) RejectMessage(msg *Message, reason string) { return } - drec := ps.deliveries.getRecord(ps.msgID(msg.Message)) + drec := ps.deliveries.getRecord(ps.idGen.ID(msg)) // defensive check that this is the first rejection trace -- delivery status should be unknown if drec.status != deliveryUnknown { @@ -789,7 +789,7 @@ func (ps *peerScore) DuplicateMessage(msg *Message) { ps.Lock() defer ps.Unlock() - drec := ps.deliveries.getRecord(ps.msgID(msg.Message)) + drec := ps.deliveries.getRecord(ps.idGen.ID(msg)) _, ok := drec.peers[msg.ReceivedFrom] if ok { diff --git a/tag_tracer.go b/tag_tracer.go index 65e99be..ae7318d 100644 --- a/tag_tracer.go +++ b/tag_tracer.go @@ -44,9 +44,9 @@ var ( type tagTracer struct { sync.RWMutex - cmgr connmgr.ConnManager - msgID MsgIdFunction - decayer connmgr.Decayer + cmgr connmgr.ConnManager + idGen *msgIDGenerator + decayer connmgr.Decayer decaying map[string]connmgr.DecayingTag direct map[peer.ID]struct{} @@ -62,7 +62,7 @@ func newTagTracer(cmgr connmgr.ConnManager) *tagTracer { } return &tagTracer{ cmgr: cmgr, - msgID: DefaultMsgIdFn, + idGen: newMsgIdGenerator(), decayer: decayer, decaying: make(map[string]connmgr.DecayingTag), nearFirst: make(map[string]map[peer.ID]struct{}), @@ -74,7 +74,7 @@ func (t *tagTracer) Start(gs *GossipSubRouter) { return } - t.msgID = gs.p.msgID + t.idGen = gs.p.idGen t.direct = gs.direct } @@ -162,7 +162,7 @@ func (t *tagTracer) bumpTagsForMessage(p peer.ID, msg *Message) { func (t *tagTracer) nearFirstPeers(msg *Message) []peer.ID { t.Lock() defer t.Unlock() - peersMap, ok := t.nearFirst[t.msgID(msg.Message)] + peersMap, ok := t.nearFirst[t.idGen.ID(msg)] if !ok { return nil } @@ -194,7 +194,7 @@ func (t *tagTracer) DeliverMessage(msg *Message) { // delete the delivery state for this message t.Lock() - delete(t.nearFirst, t.msgID(msg.Message)) + delete(t.nearFirst, t.idGen.ID(msg)) t.Unlock() } @@ -215,7 +215,7 @@ func (t *tagTracer) ValidateMessage(msg *Message) { defer t.Unlock() // create map to start tracking the peers who deliver while we're validating - id := t.msgID(msg.Message) + id := t.idGen.ID(msg) if _, exists := t.nearFirst[id]; exists { return } @@ -226,7 +226,7 @@ func (t *tagTracer) DuplicateMessage(msg *Message) { t.Lock() defer t.Unlock() - id := t.msgID(msg.Message) + id := t.idGen.ID(msg) peers, ok := t.nearFirst[id] if !ok { return @@ -247,7 +247,7 @@ func (t *tagTracer) RejectMessage(msg *Message, reason string) { case RejectValidationIgnored: fallthrough case RejectValidationFailed: - delete(t.nearFirst, t.msgID(msg.Message)) + delete(t.nearFirst, t.idGen.ID(msg)) } } diff --git a/trace.go b/trace.go index 3232542..7efd665 100644 --- a/trace.go +++ b/trace.go @@ -64,7 +64,7 @@ type pubsubTracer struct { tracer EventTracer raw []RawTracer pid peer.ID - msgID MsgIdFunction + idGen *msgIDGenerator } func (t *pubsubTracer) PublishMessage(msg *Message) { @@ -82,7 +82,7 @@ func (t *pubsubTracer) PublishMessage(msg *Message) { PeerID: []byte(t.pid), Timestamp: &now, PublishMessage: &pb.TraceEvent_PublishMessage{ - MessageID: []byte(t.msgID(msg.Message)), + MessageID: []byte(t.idGen.ID(msg)), Topic: msg.Message.Topic, }, } @@ -123,7 +123,7 @@ func (t *pubsubTracer) RejectMessage(msg *Message, reason string) { PeerID: []byte(t.pid), Timestamp: &now, RejectMessage: &pb.TraceEvent_RejectMessage{ - MessageID: []byte(t.msgID(msg.Message)), + MessageID: []byte(t.idGen.ID(msg)), ReceivedFrom: []byte(msg.ReceivedFrom), Reason: &reason, Topic: msg.Topic, @@ -154,7 +154,7 @@ func (t *pubsubTracer) DuplicateMessage(msg *Message) { PeerID: []byte(t.pid), Timestamp: &now, DuplicateMessage: &pb.TraceEvent_DuplicateMessage{ - MessageID: []byte(t.msgID(msg.Message)), + MessageID: []byte(t.idGen.ID(msg)), ReceivedFrom: []byte(msg.ReceivedFrom), Topic: msg.Topic, }, @@ -184,7 +184,7 @@ func (t *pubsubTracer) DeliverMessage(msg *Message) { PeerID: []byte(t.pid), Timestamp: &now, DeliverMessage: &pb.TraceEvent_DeliverMessage{ - MessageID: []byte(t.msgID(msg.Message)), + MessageID: []byte(t.idGen.ID(msg)), Topic: msg.Topic, ReceivedFrom: []byte(msg.ReceivedFrom), }, @@ -344,7 +344,7 @@ func (t *pubsubTracer) traceRPCMeta(rpc *RPC) *pb.TraceEvent_RPCMeta { var msgs []*pb.TraceEvent_MessageMeta for _, m := range rpc.Publish { msgs = append(msgs, &pb.TraceEvent_MessageMeta{ - MessageID: []byte(t.msgID(m)), + MessageID: []byte(t.idGen.ID(&Message{Message: m})), Topic: m.Topic, }) } diff --git a/validation.go b/validation.go index 35d291a..9fa28d2 100644 --- a/validation.go +++ b/validation.go @@ -284,7 +284,7 @@ func (v *validation) validate(vals []*topicVal, src peer.ID, msg *Message, synch // we can mark the message as seen now that we have verified the signature // and avoid invoking user validators more than once - id := v.p.msgID(msg.Message) + id := v.p.idGen.ID(msg) if !v.p.markSeen(id) { v.tracer.DuplicateMessage(msg) return nil