From 7beaa3f02939069027a0435d0859c25e7d443068 Mon Sep 17 00:00:00 2001 From: Richard Ramos Date: Thu, 7 Sep 2023 17:39:10 -0400 Subject: [PATCH] feat(rln-relay): ensure execution order for pubsub validators --- examples/chat2/exec.go | 2 +- waku/v2/node/wakunode2.go | 5 +- waku/v2/node/wakunode2_rln.go | 8 +- waku/v2/node/wakuoptions.go | 2 +- waku/v2/protocol/envelope.go | 4 +- waku/v2/protocol/relay/validators.go | 99 +++++++++++++++-------- waku/v2/protocol/relay/validators_test.go | 31 ++----- waku/v2/protocol/relay/waku_relay.go | 18 +++-- waku/v2/protocol/relay/waku_relay_test.go | 9 ++- waku/v2/protocol/rln/common.go | 2 +- waku/v2/protocol/rln/waku_rln_relay.go | 49 +++++------ 11 files changed, 122 insertions(+), 107 deletions(-) diff --git a/examples/chat2/exec.go b/examples/chat2/exec.go index 08e04edd..45caa49b 100644 --- a/examples/chat2/exec.go +++ b/examples/chat2/exec.go @@ -43,7 +43,7 @@ func execute(options Options) { } if options.RLNRelay.Enable { - spamHandler := func(message *pb.WakuMessage) error { + spamHandler := func(message *pb.WakuMessage, topic string) error { return nil } diff --git a/waku/v2/node/wakunode2.go b/waku/v2/node/wakunode2.go index 311e436e..d49043c5 100644 --- a/waku/v2/node/wakunode2.go +++ b/waku/v2/node/wakunode2.go @@ -10,7 +10,6 @@ import ( backoffv4 "github.com/cenkalti/backoff/v4" golog "github.com/ipfs/go-log/v2" "github.com/libp2p/go-libp2p" - pubsub "github.com/libp2p/go-libp2p-pubsub" "go.uber.org/zap" "github.com/ethereum/go-ethereum/crypto" @@ -66,13 +65,13 @@ type IdentityCredential = struct { IDCommitment byte32 `json:"idCommitment"` } -type SpamHandler = func(message *pb.WakuMessage) error +type SpamHandler = func(message *pb.WakuMessage, topic string) error type RLNRelay interface { IdentityCredential() (IdentityCredential, error) MembershipIndex() uint AppendRLNProof(msg *pb.WakuMessage, senderEpochTime time.Time) error - Validator(spamHandler SpamHandler) func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool + Validator(spamHandler SpamHandler) func(ctx context.Context, message *pb.WakuMessage, topic string) bool Start(ctx context.Context) error Stop() error } diff --git a/waku/v2/node/wakunode2_rln.go b/waku/v2/node/wakunode2_rln.go index 695cb214..4d93c04b 100644 --- a/waku/v2/node/wakunode2_rln.go +++ b/waku/v2/node/wakunode2_rln.go @@ -8,7 +8,6 @@ import ( "context" "errors" - pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/waku-org/go-waku/waku/v2/protocol/rln" "github.com/waku-org/go-waku/waku/v2/protocol/rln/group_manager" "github.com/waku-org/go-waku/waku/v2/protocol/rln/group_manager/dynamic" @@ -29,6 +28,10 @@ func (w *WakuNode) setupRLNRelay() error { return nil } + if !w.opts.enableRelay { + return errors.New("rln requires relay") + } + var groupManager group_manager.GroupManager rlnInstance, rootTracker, err := rln.GetRLNInstanceAndRootTracker(w.opts.rlnTreePath) @@ -85,8 +88,7 @@ func (w *WakuNode) setupRLNRelay() error { w.rlnRelay = rlnRelay - // Adding RLN as a default validator - w.opts.pubsubOpts = append(w.opts.pubsubOpts, pubsub.WithDefaultValidator(rlnRelay.Validator(w.opts.rlnSpamHandler))) + w.Relay().RegisterDefaultValidator(w.rlnRelay.Validator(w.opts.rlnSpamHandler)) return nil } diff --git a/waku/v2/node/wakuoptions.go b/waku/v2/node/wakuoptions.go index e02610f4..9b334224 100644 --- a/waku/v2/node/wakuoptions.go +++ b/waku/v2/node/wakuoptions.go @@ -96,7 +96,7 @@ type WakuNodeParameters struct { enableRLN bool rlnRelayMemIndex *uint rlnRelayDynamic bool - rlnSpamHandler func(message *pb.WakuMessage) error + rlnSpamHandler func(message *pb.WakuMessage, topic string) error rlnETHClientAddress string keystorePath string keystorePassword string diff --git a/waku/v2/protocol/envelope.go b/waku/v2/protocol/envelope.go index c87ee143..6b0a0f7b 100644 --- a/waku/v2/protocol/envelope.go +++ b/waku/v2/protocol/envelope.go @@ -20,12 +20,12 @@ type Envelope struct { // as well as generating a hash based on the bytes that compose the message func NewEnvelope(msg *wpb.WakuMessage, receiverTime int64, pubSubTopic string) *Envelope { messageHash := msg.Hash(pubSubTopic) - hash := hash.SHA256([]byte(msg.ContentTopic), msg.Payload) + digest := hash.SHA256([]byte(msg.ContentTopic), msg.Payload) return &Envelope{ msg: msg, hash: messageHash, index: &pb.Index{ - Digest: hash[:], + Digest: digest[:], ReceiverTime: receiverTime, SenderTime: msg.Timestamp, PubsubTopic: pubSubTopic, diff --git a/waku/v2/protocol/relay/validators.go b/waku/v2/protocol/relay/validators.go index 14490870..1405179a 100644 --- a/waku/v2/protocol/relay/validators.go +++ b/waku/v2/protocol/relay/validators.go @@ -10,14 +10,14 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/secp256k1" - pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/libp2p/go-libp2p/core/peer" + proto "google.golang.org/protobuf/proto" + "github.com/waku-org/go-waku/waku/v2/hash" "github.com/waku-org/go-waku/waku/v2/protocol/pb" "github.com/waku-org/go-waku/waku/v2/timesource" "go.uber.org/zap" - proto "google.golang.org/protobuf/proto" ) func msgHash(pubSubTopic string, msg *pb.WakuMessage) []byte { @@ -38,6 +38,68 @@ func msgHash(pubSubTopic string, msg *pb.WakuMessage) []byte { ) } +type validatorFn = func(ctx context.Context, msg *pb.WakuMessage, topic string) bool + +func (w *WakuRelay) RegisterDefaultValidator(fn validatorFn) { + w.topicValidatorMutex.Lock() + defer w.topicValidatorMutex.Unlock() + w.defaultTopicValidators = append(w.defaultTopicValidators, fn) +} + +func (w *WakuRelay) RegisterTopicValidator(topic string, fn validatorFn) { + w.topicValidatorMutex.Lock() + defer w.topicValidatorMutex.Unlock() + + w.topicValidators[topic] = append(w.topicValidators[topic], fn) +} + +func (w *WakuRelay) RemoveTopicValidator(topic string) { + w.topicValidatorMutex.Lock() + defer w.topicValidatorMutex.Unlock() + + delete(w.topicValidators, topic) +} + +func (w *WakuRelay) topicValidator(topic string) func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool { + return func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool { + msg := new(pb.WakuMessage) + err := proto.Unmarshal(message.Data, msg) + if err != nil { + return false + } + + w.topicValidatorMutex.RLock() + validators, exists := w.topicValidators[topic] + validators = append(validators, w.defaultTopicValidators...) + w.topicValidatorMutex.RUnlock() + + if exists { + for _, v := range validators { + if !v(ctx, msg, topic) { + return false + } + } + } + + return true + } +} + +// AddSignedTopicValidator registers a gossipsub validator for a topic which will check that messages Meta field contains a valid ECDSA signature for the specified pubsub topic. This is used as a DoS prevention mechanism +func (w *WakuRelay) AddSignedTopicValidator(topic string, publicKey *ecdsa.PublicKey) error { + w.log.Info("adding validator to signed topic", zap.String("topic", topic), zap.String("publicKey", hex.EncodeToString(elliptic.Marshal(publicKey.Curve, publicKey.X, publicKey.Y)))) + + fn := signedTopicBuilder(w.timesource, publicKey) + + w.RegisterTopicValidator(topic, fn) + + if !w.IsSubscribed(topic) { + w.log.Warn("relay is not subscribed to signed topic", zap.String("topic", topic)) + } + + return nil +} + const messageWindowDuration = time.Minute * 5 func withinTimeWindow(t timesource.Timesource, msg *pb.WakuMessage) bool { @@ -51,17 +113,9 @@ func withinTimeWindow(t timesource.Timesource, msg *pb.WakuMessage) bool { return now.Sub(msgTime).Abs() <= messageWindowDuration } -type validatorFn = func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool - -func validatorFnBuilder(t timesource.Timesource, topic string, publicKey *ecdsa.PublicKey) (validatorFn, error) { +func signedTopicBuilder(t timesource.Timesource, publicKey *ecdsa.PublicKey) validatorFn { publicKeyBytes := crypto.FromECDSAPub(publicKey) - return func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool { - msg := new(pb.WakuMessage) - err := proto.Unmarshal(message.Data, msg) - if err != nil { - return false - } - + return func(ctx context.Context, msg *pb.WakuMessage, topic string) bool { if !withinTimeWindow(t, msg) { return false } @@ -70,28 +124,7 @@ func validatorFnBuilder(t timesource.Timesource, topic string, publicKey *ecdsa. signature := msg.Meta return secp256k1.VerifySignature(publicKeyBytes, msgHash, signature) - }, nil -} - -// AddSignedTopicValidator registers a gossipsub validator for a topic which will check that messages Meta field contains a valid ECDSA signature for the specified pubsub topic. This is used as a DoS prevention mechanism -func (w *WakuRelay) AddSignedTopicValidator(topic string, publicKey *ecdsa.PublicKey) error { - w.log.Info("adding validator to signed topic", zap.String("topic", topic), zap.String("publicKey", hex.EncodeToString(elliptic.Marshal(publicKey.Curve, publicKey.X, publicKey.Y)))) - - fn, err := validatorFnBuilder(w.timesource, topic, publicKey) - if err != nil { - return err } - - err = w.pubsub.RegisterTopicValidator(topic, fn) - if err != nil { - return err - } - - if !w.IsSubscribed(topic) { - w.log.Warn("relay is not subscribed to signed topic", zap.String("topic", topic)) - } - - return nil } // SignMessage adds an ECDSA signature to a WakuMessage as an opt-in mechanism for DoS prevention diff --git a/waku/v2/protocol/relay/validators_test.go b/waku/v2/protocol/relay/validators_test.go index 383c1fcb..9f03590c 100644 --- a/waku/v2/protocol/relay/validators_test.go +++ b/waku/v2/protocol/relay/validators_test.go @@ -7,11 +7,8 @@ import ( "time" "github.com/ethereum/go-ethereum/crypto" - pubsub "github.com/libp2p/go-libp2p-pubsub" - pubsub_pb "github.com/libp2p/go-libp2p-pubsub/pb" "github.com/stretchr/testify/require" "github.com/waku-org/go-waku/waku/v2/protocol/pb" - "google.golang.org/protobuf/proto" ) type FakeTimesource struct { @@ -59,39 +56,23 @@ func TestMsgHash(t *testing.T) { // expectedSignature, _ := hex.DecodeString("127FA211B2514F0E974A055392946DC1A14052182A6ABEFB8A6CD7C51DA1BF2E40595D28EF1A9488797C297EED3AAC45430005FB3A7F037BDD9FC4BD99F59E63") // require.True(t, bytes.Equal(expectedSignature, msg.Meta)) - msgData, _ := proto.Marshal(msg) - //expectedMessageHash, _ := hex.DecodeString("662F8C20A335F170BD60ABC1F02AD66F0C6A6EE285DA2A53C95259E7937C0AE9") //messageHash := MsgHash(pubsubTopic, msg) //require.True(t, bytes.Equal(expectedMessageHash, messageHash)) - myValidator, err := validatorFnBuilder(NewFakeTimesource(timestamp), protectedPubSubTopic, &prvKey.PublicKey) - require.NoError(t, err) - result := myValidator(context.Background(), "", &pubsub.Message{ - Message: &pubsub_pb.Message{ - Data: msgData, - }, - }) + myValidator := signedTopicBuilder(NewFakeTimesource(timestamp), &prvKey.PublicKey) + result := myValidator(context.Background(), msg, protectedPubSubTopic) require.True(t, result) // Exceed 5m window in both directions now5m1sInPast := timestamp.Add(-5 * time.Minute).Add(-1 * time.Second) - myValidator, err = validatorFnBuilder(NewFakeTimesource(now5m1sInPast), protectedPubSubTopic, &prvKey.PublicKey) + myValidator = signedTopicBuilder(NewFakeTimesource(now5m1sInPast), &prvKey.PublicKey) require.NoError(t, err) - result = myValidator(context.Background(), "", &pubsub.Message{ - Message: &pubsub_pb.Message{ - Data: msgData, - }, - }) + result = myValidator(context.Background(), msg, protectedPubSubTopic) require.False(t, result) now5m1sInFuture := timestamp.Add(5 * time.Minute).Add(1 * time.Second) - myValidator, err = validatorFnBuilder(NewFakeTimesource(now5m1sInFuture), protectedPubSubTopic, &prvKey.PublicKey) - require.NoError(t, err) - result = myValidator(context.Background(), "", &pubsub.Message{ - Message: &pubsub_pb.Message{ - Data: msgData, - }, - }) + myValidator = signedTopicBuilder(NewFakeTimesource(now5m1sInFuture), &prvKey.PublicKey) + result = myValidator(context.Background(), msg, protectedPubSubTopic) require.False(t, result) } diff --git a/waku/v2/protocol/relay/waku_relay.go b/waku/v2/protocol/relay/waku_relay.go index 51efae51..672732ac 100644 --- a/waku/v2/protocol/relay/waku_relay.go +++ b/waku/v2/protocol/relay/waku_relay.go @@ -49,6 +49,10 @@ type WakuRelay struct { minPeersToPublish int + topicValidatorMutex sync.RWMutex + topicValidators map[string][]validatorFn + defaultTopicValidators []validatorFn + // TODO: convert to concurrent maps topicsMutex sync.Mutex wakuRelayTopics map[string]*pubsub.Topic @@ -83,6 +87,7 @@ func NewWakuRelay(bcaster Broadcaster, minPeersToPublish int, timesource timesou w.timesource = timesource w.wakuRelayTopics = make(map[string]*pubsub.Topic) w.relaySubs = make(map[string]*pubsub.Subscription) + w.topicValidators = make(map[string][]validatorFn) w.bcaster = bcaster w.minPeersToPublish = minPeersToPublish w.CommonService = waku_proto.NewCommonService() @@ -177,12 +182,6 @@ func NewWakuRelay(bcaster Broadcaster, minPeersToPublish int, timesource timesou pubsub.WithSeenMessagesTTL(2 * time.Minute), pubsub.WithPeerScore(w.peerScoreParams, w.peerScoreThresholds), pubsub.WithPeerScoreInspect(w.peerScoreInspector, 6*time.Second), - // TODO: to improve - setup default validator only if no default validator has been set. - pubsub.WithDefaultValidator(func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool { - msg := new(pb.WakuMessage) - err := proto.Unmarshal(message.Data, msg) - return err == nil - }), }, opts...) return w @@ -270,6 +269,11 @@ func (w *WakuRelay) upsertTopic(topic string) (*pubsub.Topic, error) { pubSubTopic, ok := w.wakuRelayTopics[topic] if !ok { // Joins topic if node hasn't joined yet + err := w.pubsub.RegisterTopicValidator(topic, w.topicValidator(topic)) + if err != nil { + return nil, err + } + newTopic, err := w.pubsub.Join(string(topic)) if err != nil { return nil, err @@ -419,6 +423,8 @@ func (w *WakuRelay) Unsubscribe(ctx context.Context, topic string) error { } delete(w.wakuRelayTopics, topic) + w.RemoveTopicValidator(topic) + err = w.emitters.EvtRelayUnsubscribed.Emit(EvtRelayUnsubscribed{topic}) if err != nil { return err diff --git a/waku/v2/protocol/relay/waku_relay_test.go b/waku/v2/protocol/relay/waku_relay_test.go index aeb7eb9a..c7b335d8 100644 --- a/waku/v2/protocol/relay/waku_relay_test.go +++ b/waku/v2/protocol/relay/waku_relay_test.go @@ -79,10 +79,6 @@ func TestGossipsubScore(t *testing.T) { relay := make([]*WakuRelay, 5) for i := 0; i < 5; i++ { hosts[i], relay[i] = createRelayNode(t) - if i == 0 { - // This is a hack to remove the default validator from the list of default options - relay[i].opts = relay[i].opts[:len(relay[i].opts)-1] - } err := relay[i].Start(context.Background()) require.NoError(t, err) } @@ -119,6 +115,11 @@ func TestGossipsubScore(t *testing.T) { // We obtain the go-libp2p topic directly because we normally can't publish anything other than WakuMessages pubsubTopic, err := relay[0].upsertTopic(testTopic) require.NoError(t, err) + + // Removing validator from relayer0 to allow it to send invalid messages + err = relay[0].pubsub.UnregisterTopicValidator(testTopic) + require.NoError(t, err) + for i := 0; i < 50; i++ { buf := make([]byte, 1000) _, err := rand.Read(buf) diff --git a/waku/v2/protocol/rln/common.go b/waku/v2/protocol/rln/common.go index 69412d4c..1f3afd71 100644 --- a/waku/v2/protocol/rln/common.go +++ b/waku/v2/protocol/rln/common.go @@ -26,7 +26,7 @@ const acceptableRootWindowSize = 5 type RegistrationHandler = func(tx *types.Transaction) -type SpamHandler = func(message *pb.WakuMessage) error +type SpamHandler = func(msg *pb.WakuMessage, topic string) error func toRLNSignal(wakuMessage *pb.WakuMessage) []byte { if wakuMessage == nil { diff --git a/waku/v2/protocol/rln/waku_rln_relay.go b/waku/v2/protocol/rln/waku_rln_relay.go index 4e489a7d..d3c00eb4 100644 --- a/waku/v2/protocol/rln/waku_rln_relay.go +++ b/waku/v2/protocol/rln/waku_rln_relay.go @@ -2,14 +2,11 @@ package rln import ( "context" - "encoding/hex" "errors" "math" "time" "github.com/ethereum/go-ethereum/log" - pubsub "github.com/libp2p/go-libp2p-pubsub" - "github.com/libp2p/go-libp2p/core/peer" "github.com/prometheus/client_golang/prometheus" "github.com/waku-org/go-waku/logging" "github.com/waku-org/go-waku/waku/v2/protocol/pb" @@ -17,7 +14,6 @@ import ( "github.com/waku-org/go-waku/waku/v2/timesource" "github.com/waku-org/go-zerokit-rln/rln" "go.uber.org/zap" - proto "google.golang.org/protobuf/proto" ) type WakuRLNRelay struct { @@ -218,52 +214,49 @@ func (rlnRelay *WakuRLNRelay) AppendRLNProof(msg *pb.WakuMessage, senderEpochTim // Validator returns a validator for the waku messages. // The message validation logic is according to https://rfc.vac.dev/spec/17/ func (rlnRelay *WakuRLNRelay) Validator( - spamHandler SpamHandler) func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool { - return func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool { - rlnRelay.log.Debug("rln-relay topic validator called") + spamHandler SpamHandler) func(ctx context.Context, msg *pb.WakuMessage, topic string) bool { + return func(ctx context.Context, msg *pb.WakuMessage, topic string) bool { + + hash := msg.Hash(topic) + + log := rlnRelay.log.With( + logging.HexBytes("hash", hash), + zap.String("pubsubTopic", topic), + zap.String("contentTopic", msg.ContentTopic), + ) + + log.Debug("rln-relay topic validator called") rlnRelay.metrics.RecordMessage() - wakuMessage := &pb.WakuMessage{} - if err := proto.Unmarshal(message.Data, wakuMessage); err != nil { - rlnRelay.log.Debug("could not unmarshal message") - return true - } - // validate the message - validationRes, err := rlnRelay.ValidateMessage(wakuMessage, nil) + validationRes, err := rlnRelay.ValidateMessage(msg, nil) if err != nil { - rlnRelay.log.Debug("validating message", zap.Error(err)) + log.Debug("validating message", zap.Error(err)) return false } switch validationRes { case validMessage: - rlnRelay.log.Debug("message verified", - zap.String("id", hex.EncodeToString([]byte(message.ID))), - ) + log.Debug("message verified") return true case invalidMessage: - rlnRelay.log.Debug("message could not be verified", - zap.String("id", hex.EncodeToString([]byte(message.ID))), - ) + log.Debug("message could not be verified") return false case spamMessage: - rlnRelay.log.Debug("spam message found", - zap.String("id", hex.EncodeToString([]byte(message.ID))), - ) + log.Debug("spam message found") - rlnRelay.metrics.RecordSpam(wakuMessage.ContentTopic) + rlnRelay.metrics.RecordSpam(msg.ContentTopic) if spamHandler != nil { - if err := spamHandler(wakuMessage); err != nil { - rlnRelay.log.Error("executing spam handler", zap.Error(err)) + if err := spamHandler(msg, topic); err != nil { + log.Error("executing spam handler", zap.Error(err)) } } return false default: - rlnRelay.log.Debug("unhandled validation result", zap.Int("validationResult", int(validationRes))) + log.Debug("unhandled validation result", zap.Int("validationResult", int(validationRes))) return false } }