diff --git a/floodsub_test.go b/floodsub_test.go index 9af2cac..e671372 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -905,7 +905,7 @@ func TestWithSigning(t *testing.T) { defer cancel() hosts := getNetHosts(t, ctx, 2) - psubs := getPubsubs(ctx, hosts, WithMessageSigning()) + psubs := getPubsubs(ctx, hosts, WithMessageSigning(true)) connect(t, hosts[0], hosts[1]) diff --git a/pubsub.go b/pubsub.go index cd1291a..091bbc2 100644 --- a/pubsub.go +++ b/pubsub.go @@ -92,6 +92,8 @@ type PubSub struct { // key for signing messages; nil when signing is disabled (default for now) signKey crypto.PrivKey + // strict mode rejects all unsigned messages prior to validation + signStrict bool ctx context.Context } @@ -190,9 +192,10 @@ func WithValidateThrottle(n int) Option { } } -func WithMessageSigning() Option { +func WithMessageSigning(strict bool) Option { return func(p *PubSub) error { p.signKey = p.host.Peerstore().PrivKey(p.host.ID()) + p.signStrict = strict return nil } } @@ -457,6 +460,12 @@ func msgID(pmsg *pb.Message) string { // pushMsg pushes a message performing validation as necessary func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) { + // reject unsigned messages when strict before we even process the id + if p.signStrict && msg.Signature == nil { + log.Debugf("dropping unsigned message from %s", src) + return + } + id := msgID(msg.Message) if p.seenMessage(id) { return