diff --git a/gossipsub.go b/gossipsub.go index 6327dfa..31c5724 100644 --- a/gossipsub.go +++ b/gossipsub.go @@ -75,6 +75,8 @@ func (gs *GossipSubRouter) Protocols() []protocol.ID { func (gs *GossipSubRouter) Attach(p *PubSub) { gs.p = p gs.tracer = p.tracer + // start using the same msg ID function as PubSub for caching messages. + gs.mcache.SetMsgIdFn(p.msgID) go gs.heartbeatTimer() } diff --git a/mcache.go b/mcache.go index 9f5e5dc..e085297 100644 --- a/mcache.go +++ b/mcache.go @@ -28,6 +28,7 @@ func NewMessageCache(gossip, history int) *MessageCache { msgs: make(map[string]*pb.Message), history: make([][]CacheEntry, history), gossip: gossip, + msgID: DefaultMsgIdFn, } } @@ -35,6 +36,11 @@ type MessageCache struct { msgs map[string]*pb.Message history [][]CacheEntry gossip int + msgID MsgIdFunction +} + +func (mc *MessageCache) SetMsgIdFn(msgID MsgIdFunction) { + mc.msgID = msgID } type CacheEntry struct { @@ -43,7 +49,7 @@ type CacheEntry struct { } func (mc *MessageCache) Put(msg *pb.Message) { - mid := msgID(msg) + mid := mc.msgID(msg) mc.msgs[mid] = msg mc.history[0] = append(mc.history[0], CacheEntry{mid: mid, topics: msg.GetTopicIDs()}) } diff --git a/mcache_test.go b/mcache_test.go index 5616c21..e8a50f3 100644 --- a/mcache_test.go +++ b/mcache_test.go @@ -10,6 +10,7 @@ import ( func TestMessageCache(t *testing.T) { mcache := NewMessageCache(3, 5) + msgID := DefaultMsgIdFn msgs := make([]*pb.Message, 60) for i := range msgs { diff --git a/pubsub.go b/pubsub.go index 530c47a..b8bd579 100644 --- a/pubsub.go +++ b/pubsub.go @@ -117,6 +117,9 @@ type PubSub struct { seenMessagesMx sync.Mutex seenMessages *timecache.TimeCache + // function used to compute the ID for a message + msgID MsgIdFunction + // key for signing messages; nil when signing is disabled (default for now) signKey crypto.PrivKey // source ID for signed messages; corresponds to signKey @@ -208,6 +211,7 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option blacklist: NewMapBlacklist(), blacklistPeer: make(chan peer.ID), seenMessages: timecache.NewTimeCache(TimeCacheDuration), + msgID: DefaultMsgIdFn, counter: uint64(time.Now().UnixNano()), } @@ -240,6 +244,24 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option return ps, nil } +// MsgIdFunction returns a unique ID for the passed Message, and PubSub can be customized to use any +// implementation of this function by configuring it with the Option from WithMessageIdFn. +type MsgIdFunction func(pmsg *pb.Message) string + +// WithMessageIdFn is an option to customize the way a message ID is computed for a pubsub message. +// The default ID function is DefaultMsgIdFn (concatenate source and seq nr.), +// but it can be customized to e.g. the hash of the message. +func WithMessageIdFn(fn MsgIdFunction) Option { + return func(p *PubSub) error { + p.msgID = fn + // the tracer Option may already be set. Update its message ID function to make options order-independent. + if p.tracer != nil { + p.tracer.msgID = fn + } + return nil + } +} + // WithPeerOutboundQueueSize is an option to set the buffer size for outbound messages to a peer // We start dropping messages to a peer if the outbound queue if full func WithPeerOutboundQueueSize(size int) Option { @@ -326,7 +348,7 @@ func WithDiscovery(d discovery.Discovery, opts ...DiscoverOpt) Option { // WithEventTracer provides a tracer for the pubsub system func WithEventTracer(tracer EventTracer) Option { return func(p *PubSub) error { - p.tracer = &pubsubTracer{tracer: tracer, pid: p.host.ID()} + p.tracer = &pubsubTracer{tracer: tracer, pid: p.host.ID(), msgID: p.msgID} return nil } } @@ -730,8 +752,8 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) { p.rt.HandleRPC(rpc) } -// msgID returns a unique ID of the passed Message -func msgID(pmsg *pb.Message) string { +// DefaultMsgIdFn returns a unique ID of the passed Message +func DefaultMsgIdFn(pmsg *pb.Message) string { return string(pmsg.GetFrom()) + string(pmsg.GetSeqno()) } @@ -760,7 +782,7 @@ func (p *PubSub) pushMsg(msg *Message) { } // have we already seen and validated this message? - id := msgID(msg.Message) + id := p.msgID(msg.Message) if p.seenMessage(id) { p.tracer.DuplicateMessage(msg) return diff --git a/trace.go b/trace.go index 293f166..1f98053 100644 --- a/trace.go +++ b/trace.go @@ -18,6 +18,7 @@ type EventTracer interface { type pubsubTracer struct { tracer EventTracer pid peer.ID + msgID MsgIdFunction } func (t *pubsubTracer) PublishMessage(msg *Message) { @@ -31,7 +32,7 @@ func (t *pubsubTracer) PublishMessage(msg *Message) { PeerID: []byte(t.pid), Timestamp: &now, PublishMessage: &pb.TraceEvent_PublishMessage{ - MessageID: []byte(msgID(msg.Message)), + MessageID: []byte(t.msgID(msg.Message)), Topics: msg.Message.TopicIDs, }, } @@ -50,7 +51,7 @@ func (t *pubsubTracer) RejectMessage(msg *Message, reason string) { PeerID: []byte(t.pid), Timestamp: &now, RejectMessage: &pb.TraceEvent_RejectMessage{ - MessageID: []byte(msgID(msg.Message)), + MessageID: []byte(t.msgID(msg.Message)), ReceivedFrom: []byte(msg.ReceivedFrom), Reason: &reason, }, @@ -70,7 +71,7 @@ func (t *pubsubTracer) DuplicateMessage(msg *Message) { PeerID: []byte(t.pid), Timestamp: &now, DuplicateMessage: &pb.TraceEvent_DuplicateMessage{ - MessageID: []byte(msgID(msg.Message)), + MessageID: []byte(t.msgID(msg.Message)), ReceivedFrom: []byte(msg.ReceivedFrom), }, } @@ -89,7 +90,7 @@ func (t *pubsubTracer) DeliverMessage(msg *Message) { PeerID: []byte(t.pid), Timestamp: &now, DeliverMessage: &pb.TraceEvent_DeliverMessage{ - MessageID: []byte(msgID(msg.Message)), + MessageID: []byte(t.msgID(msg.Message)), }, } @@ -146,7 +147,7 @@ func (t *pubsubTracer) RecvRPC(rpc *RPC) { Timestamp: &now, RecvRPC: &pb.TraceEvent_RecvRPC{ ReceivedFrom: []byte(rpc.from), - Meta: traceRPCMeta(rpc), + Meta: t.traceRPCMeta(rpc), }, } @@ -165,7 +166,7 @@ func (t *pubsubTracer) SendRPC(rpc *RPC, p peer.ID) { Timestamp: &now, SendRPC: &pb.TraceEvent_SendRPC{ SendTo: []byte(rpc.from), - Meta: traceRPCMeta(rpc), + Meta: t.traceRPCMeta(rpc), }, } @@ -184,20 +185,20 @@ func (t *pubsubTracer) DropRPC(rpc *RPC, p peer.ID) { Timestamp: &now, DropRPC: &pb.TraceEvent_DropRPC{ SendTo: []byte(rpc.from), - Meta: traceRPCMeta(rpc), + Meta: t.traceRPCMeta(rpc), }, } t.tracer.Trace(evt) } -func traceRPCMeta(rpc *RPC) *pb.TraceEvent_RPCMeta { +func (t *pubsubTracer) traceRPCMeta(rpc *RPC) *pb.TraceEvent_RPCMeta { rpcMeta := new(pb.TraceEvent_RPCMeta) var msgs []*pb.TraceEvent_MessageMeta for _, m := range rpc.Publish { msgs = append(msgs, &pb.TraceEvent_MessageMeta{ - MessageID: []byte(msgID(m)), + MessageID: []byte(t.msgID(m)), Topics: m.TopicIDs, }) } diff --git a/validation.go b/validation.go index bef86da..70b4cde 100644 --- a/validation.go +++ b/validation.go @@ -201,7 +201,7 @@ func (v *validation) validate(vals []*topicVal, src peer.ID, msg *Message) { // we can mark the message as seen now that we have verified the signature // and avoid invoking user validators more than once - id := msgID(msg.Message) + id := v.p.msgID(msg.Message) if !v.p.markSeen(id) { v.tracer.DuplicateMessage(msg) return