diff --git a/cmd/waku/flags_rln.go b/cmd/waku/flags_rln.go index e1ef58f7..d5dd3d16 100644 --- a/cmd/waku/flags_rln.go +++ b/cmd/waku/flags_rln.go @@ -73,5 +73,11 @@ func rlnFlags() []cli.Flag { Value: &options.RLNRelay.MembershipContractAddress, }, }, + &cli.IntFlag{ + Name: "rln-relay-bandwidth-threshold", + Value: 0, + Usage: "Message rate in bytes/sec after which verification of proofs should happen. Use 0 to disable bandwidth rate limits", + Destination: &options.RLNRelay.BandwidthThreshold, + }, } } diff --git a/cmd/waku/node_rln.go b/cmd/waku/node_rln.go index 282d5189..afbae66e 100644 --- a/cmd/waku/node_rln.go +++ b/cmd/waku/node_rln.go @@ -26,6 +26,8 @@ func checkForRLN(logger *zap.Logger, options Options, nodeOpts *[]node.WakuNodeO ethPrivKey = options.RLNRelay.ETHPrivateKey } + *nodeOpts = append(*nodeOpts, node.WithRLNBandwidthThreshold(options.RLNRelay.BandwidthThreshold)) + *nodeOpts = append(*nodeOpts, node.WithDynamicRLNRelay( options.RLNRelay.PubsubTopic, options.RLNRelay.ContentTopic, diff --git a/cmd/waku/options.go b/cmd/waku/options.go index bec9492d..1877a2ad 100644 --- a/cmd/waku/options.go +++ b/cmd/waku/options.go @@ -44,6 +44,7 @@ type RLNRelayOptions struct { ETHPrivateKey *ecdsa.PrivateKey ETHClientAddress string MembershipContractAddress common.Address + BandwidthThreshold int } // FilterOptions are settings used to enable filter protocol. This is a protocol diff --git a/waku/v2/node/wakunode2_rln.go b/waku/v2/node/wakunode2_rln.go index 3ba3d6e9..ef72ed80 100644 --- a/waku/v2/node/wakunode2_rln.go +++ b/waku/v2/node/wakunode2_rln.go @@ -12,6 +12,7 @@ import ( "github.com/waku-org/go-waku/waku/v2/protocol/rln/group_manager/static" r "github.com/waku-org/go-zerokit-rln/rln" "go.uber.org/zap" + "golang.org/x/time/rate" ) // RLNRelay is used to access any operation related to Waku RLN protocol @@ -73,7 +74,12 @@ func (w *WakuNode) mountRlnRelay(ctx context.Context) error { } } - rlnRelay, err := rln.New(w.Relay(), groupManager, w.opts.rlnRelayPubsubTopic, w.opts.rlnRelayContentTopic, w.opts.rlnSpamHandler, w.timesource, w.log) + var limiter *rate.Limiter + if w.opts.rlnRelayBandwidthThreshold != 0 { + limiter = rate.NewLimiter(rate.Limit(w.opts.rlnRelayBandwidthThreshold), w.opts.rlnRelayBandwidthThreshold) + } + + rlnRelay, err := rln.New(w.Relay(), groupManager, w.opts.rlnRelayPubsubTopic, w.opts.rlnRelayContentTopic, w.opts.rlnSpamHandler, limiter, w.timesource, w.log) if err != nil { return err } diff --git a/waku/v2/node/wakuoptions.go b/waku/v2/node/wakuoptions.go index 5b39af51..4d5fe4bd 100644 --- a/waku/v2/node/wakuoptions.go +++ b/waku/v2/node/wakuoptions.go @@ -106,6 +106,7 @@ type WakuNodeParameters struct { keystorePassword string rlnMembershipContractAddress common.Address rlnRegistrationHandler func(tx *types.Transaction) + rlnRelayBandwidthThreshold int keepAliveInterval time.Duration diff --git a/waku/v2/node/wakuoptions_rln.go b/waku/v2/node/wakuoptions_rln.go index 263a4f82..636ba4fd 100644 --- a/waku/v2/node/wakuoptions_rln.go +++ b/waku/v2/node/wakuoptions_rln.go @@ -11,6 +11,14 @@ import ( r "github.com/waku-org/go-zerokit-rln/rln" ) +// WithRLNBandwidthThreshold sets the message rate in bytes/sec after which verification of proofs should happen +func WithRLNBandwidthThreshold(rateLimit int) WakuNodeOption { + return func(params *WakuNodeParameters) error { + params.rlnRelayBandwidthThreshold = rateLimit + return nil + } +} + // WithStaticRLNRelay enables the Waku V2 RLN protocol in offchain mode // Requires the `gowaku_rln` build constrain (or the env variable RLN=true if building go-waku) func WithStaticRLNRelay(pubsubTopic string, contentTopic string, memberIndex r.MembershipIndex, spamHandler rln.SpamHandler) WakuNodeOption { diff --git a/waku/v2/protocol/rln/onchain_test.go b/waku/v2/protocol/rln/onchain_test.go index 4e80e815..a444d912 100644 --- a/waku/v2/protocol/rln/onchain_test.go +++ b/waku/v2/protocol/rln/onchain_test.go @@ -242,7 +242,7 @@ func (s *WakuRLNRelayDynamicSuite) TestMerkleTreeConstruction() { gm, err := dynamic.NewDynamicGroupManager(s.clientAddr, s.u1PrivKey, s.rlnAddr, "./test_onchain.json", "", false, nil, utils.Logger()) s.Require().NoError(err) - rlnRelay, err := New(relay, gm, RLNRELAY_PUBSUB_TOPIC, RLNRELAY_CONTENT_TOPIC, nil, timesource.NewDefaultClock(), utils.Logger()) + rlnRelay, err := New(relay, gm, RLNRELAY_PUBSUB_TOPIC, RLNRELAY_CONTENT_TOPIC, nil, nil, timesource.NewDefaultClock(), utils.Logger()) s.Require().NoError(err) // PreRegistering the keypair @@ -286,7 +286,7 @@ func (s *WakuRLNRelayDynamicSuite) TestCorrectRegistrationOfPeers() { gm1, err := dynamic.NewDynamicGroupManager(s.clientAddr, s.u1PrivKey, s.rlnAddr, "./test_onchain.json", "", false, nil, utils.Logger()) s.Require().NoError(err) - rlnRelay1, err := New(relay1, gm1, RLNRELAY_PUBSUB_TOPIC, RLNRELAY_CONTENT_TOPIC, nil, timesource.NewDefaultClock(), utils.Logger()) + rlnRelay1, err := New(relay1, gm1, RLNRELAY_PUBSUB_TOPIC, RLNRELAY_CONTENT_TOPIC, nil, nil, timesource.NewDefaultClock(), utils.Logger()) s.Require().NoError(err) err = rlnRelay1.Start(context.TODO()) s.Require().NoError(err) @@ -312,7 +312,7 @@ func (s *WakuRLNRelayDynamicSuite) TestCorrectRegistrationOfPeers() { gm2, err := dynamic.NewDynamicGroupManager(s.clientAddr, s.u2PrivKey, s.rlnAddr, "./test_onchain.json", "", false, nil, utils.Logger()) s.Require().NoError(err) - rlnRelay2, err := New(relay2, gm2, RLNRELAY_PUBSUB_TOPIC, RLNRELAY_CONTENT_TOPIC, nil, timesource.NewDefaultClock(), utils.Logger()) + rlnRelay2, err := New(relay2, gm2, RLNRELAY_PUBSUB_TOPIC, RLNRELAY_CONTENT_TOPIC, nil, nil, timesource.NewDefaultClock(), utils.Logger()) s.Require().NoError(err) err = rlnRelay2.Start(context.TODO()) s.Require().NoError(err) diff --git a/waku/v2/protocol/rln/rln_relay_test.go b/waku/v2/protocol/rln/rln_relay_test.go index f353254f..9a7a5672 100644 --- a/waku/v2/protocol/rln/rln_relay_test.go +++ b/waku/v2/protocol/rln/rln_relay_test.go @@ -58,7 +58,7 @@ func (s *WakuRLNRelaySuite) TestOffchainMode() { groupManager, err := static.NewStaticGroupManager(groupIDCommitments, idCredential, index, utils.Logger()) s.Require().NoError(err) - wakuRLNRelay, err := New(relay, groupManager, RLNRELAY_PUBSUB_TOPIC, RLNRELAY_CONTENT_TOPIC, nil, timesource.NewDefaultClock(), utils.Logger()) + wakuRLNRelay, err := New(relay, groupManager, RLNRELAY_PUBSUB_TOPIC, RLNRELAY_CONTENT_TOPIC, nil, nil, timesource.NewDefaultClock(), utils.Logger()) s.Require().NoError(err) err = wakuRLNRelay.Start(context.TODO()) diff --git a/waku/v2/protocol/rln/waku_rln_relay.go b/waku/v2/protocol/rln/waku_rln_relay.go index fa2eb9be..ec1540ff 100644 --- a/waku/v2/protocol/rln/waku_rln_relay.go +++ b/waku/v2/protocol/rln/waku_rln_relay.go @@ -18,6 +18,7 @@ import ( "github.com/waku-org/go-waku/waku/v2/timesource" "github.com/waku-org/go-zerokit-rln/rln" "go.uber.org/zap" + "golang.org/x/time/rate" proto "google.golang.org/protobuf/proto" ) @@ -34,6 +35,7 @@ type WakuRLNRelay struct { groupManager GroupManager rootTracker *group_manager.MerkleRootTracker + rateLimiter *rate.Limiter // pubsubTopic is the topic for which rln relay is mounted pubsubTopic string @@ -55,6 +57,7 @@ func New( pubsubTopic string, contentTopic string, spamHandler SpamHandler, + rateLimiter *rate.Limiter, timesource timesource.Timesource, log *zap.Logger) (*WakuRLNRelay, error) { rlnInstance, err := rln.NewRLN() @@ -72,6 +75,7 @@ func New( RLN: rlnInstance, groupManager: groupManager, rootTracker: rootTracker, + rateLimiter: rateLimiter, pubsubTopic: pubsubTopic, contentTopic: contentTopic, relay: relay, @@ -280,26 +284,32 @@ func (rlnRelay *WakuRLNRelay) addValidator( pubsubTopic string, contentTopic string, spamHandler SpamHandler) error { - validator := func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool { - rlnRelay.log.Debug("rln-relay topic validator called") + validator := func(ctx context.Context, peerID peer.ID, message *pubsub.Message) pubsub.ValidationResult { + rlnRelay.log.Debug("topic validator called") wakuMessage := &pb.WakuMessage{} if err := proto.Unmarshal(message.Data, wakuMessage); err != nil { rlnRelay.log.Debug("could not unmarshal message") - return true + return pubsub.ValidationReject } // check the contentTopic if (wakuMessage.ContentTopic != "") && (contentTopic != "") && (wakuMessage.ContentTopic != contentTopic) { rlnRelay.log.Debug("content topic did not match", zap.String("contentTopic", contentTopic)) - return true + return pubsub.ValidationAccept } + if rlnRelay.rateLimiter != nil && rlnRelay.rateLimiter.AllowN(time.Now(), len(message.Data)) { + return pubsub.ValidationAccept + } + + rlnRelay.log.Debug("message bandwidth limit exceeded, running rate limit proof validation") + // validate the message validationRes, err := rlnRelay.ValidateMessage(wakuMessage, nil) if err != nil { rlnRelay.log.Debug("validating message", zap.Error(err)) - return false + return pubsub.ValidationReject } switch validationRes { @@ -308,13 +318,13 @@ func (rlnRelay *WakuRLNRelay) addValidator( zap.String("pubsubTopic", pubsubTopic), zap.String("id", hex.EncodeToString(wakuMessage.Hash(pubsubTopic))), ) - return true + return pubsub.ValidationAccept case invalidMessage: rlnRelay.log.Debug("message could not be verified", zap.String("pubsubTopic", pubsubTopic), zap.String("id", hex.EncodeToString(wakuMessage.Hash(pubsubTopic))), ) - return false + return pubsub.ValidationReject case spamMessage: rlnRelay.log.Debug("spam message found", zap.String("pubsubTopic", pubsubTopic), @@ -327,10 +337,10 @@ func (rlnRelay *WakuRLNRelay) addValidator( } } - return false + return pubsub.ValidationReject default: rlnRelay.log.Debug("unhandled validation result", zap.Int("validationResult", int(validationRes))) - return false + return pubsub.ValidationIgnore } }