From 75f975ce7a520fd6aac450147eff534efb5e342d Mon Sep 17 00:00:00 2001 From: Roman Zajic Date: Thu, 11 Jan 2024 22:52:52 +0800 Subject: [PATCH] chore: RLN tests coverage improvement (#1003) --- waku/v2/protocol/rln/rln_relay_test.go | 168 ++++++++++++++++++++++++- waku/v2/protocol/rln/waku_rln_relay.go | 9 +- 2 files changed, 171 insertions(+), 6 deletions(-) diff --git a/waku/v2/protocol/rln/rln_relay_test.go b/waku/v2/protocol/rln/rln_relay_test.go index e9f8ed1f..78038d58 100644 --- a/waku/v2/protocol/rln/rln_relay_test.go +++ b/waku/v2/protocol/rln/rln_relay_test.go @@ -170,6 +170,11 @@ func (s *WakuRLNRelaySuite) TestValidateMessage() { groupKeyPairs, _, err := r.CreateMembershipList(100) s.Require().NoError(err) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pubSubTopic := "/waku/2/go/rln/test" + var groupIDCommitments []r.IDCommitment for _, c := range groupKeyPairs { groupIDCommitments = append(groupIDCommitments, c.IDCommitment) @@ -191,6 +196,7 @@ func (s *WakuRLNRelaySuite) TestValidateMessage() { s.Require().NoError(err) rlnRelay := &WakuRLNRelay{ + timesource: timesource.NewDefaultClock(), Details: group_manager.Details{ GroupManager: groupManager, RootTracker: rootTracker, @@ -201,14 +207,17 @@ func (s *WakuRLNRelaySuite) TestValidateMessage() { metrics: newMetrics(prometheus.DefaultRegisterer), } - //get the current epoch time + // get the current epoch time now := time.Now() err = groupManager.Start(context.Background()) s.Require().NoError(err) - // create some messages from the same peer and append rln proof to them, except wm4 + // Get Validator func instance + validator := rlnRelay.Validator(nil) + s.Require().NotNil(validator) + // create some messages from the same peer and append rln proof to them, except wm4 wm1 := &pb.WakuMessage{Payload: []byte("Valid message")} err = rlnRelay.AppendRLNProof(wm1, now) s.Require().NoError(err) @@ -245,4 +254,159 @@ func (s *WakuRLNRelaySuite) TestValidateMessage() { s.Require().Equal(spamMessage, msgValidate2) s.Require().Equal(validMessage, msgValidate3) s.Require().Equal(invalidMessage, msgValidate4) + + // Create valid message and check it with validator func + wm10 := &pb.WakuMessage{Payload: []byte("Valid message 2")} + err = rlnRelay.AppendRLNProof(wm10, now.Add(2*time.Second*time.Duration(r.EPOCH_UNIT_SECONDS))) + s.Require().NoError(err) + + isValid := validator(ctx, wm10, pubSubTopic) + s.Require().True(isValid) + + // Detect spam message with validator func + wm11 := &pb.WakuMessage{Payload: []byte("Spam 2")} + err = rlnRelay.AppendRLNProof(wm11, now.Add(2*time.Second*time.Duration(r.EPOCH_UNIT_SECONDS))) + s.Require().NoError(err) + + isValid = validator(ctx, wm11, pubSubTopic) + s.Require().False(isValid) + + // Detect invalid message (no proof) with validator func + wm12 := &pb.WakuMessage{Payload: []byte("Invalid message 2")} + + isValid = validator(ctx, wm12, pubSubTopic) + s.Require().False(isValid) + +} + +func (s *WakuRLNRelaySuite) TestRLNRelayGetters() { + port, err := tests.FindFreePort(s.T(), "", 5) + s.Require().NoError(err) + + ctx := context.Background() + + host, err := tests.MakeHost(ctx, port, rand.Reader) + s.Require().NoError(err) + bcaster := relay.NewBroadcaster(1024) + relay := relay.NewWakuRelay(bcaster, 0, timesource.NewDefaultClock(), prometheus.DefaultRegisterer, utils.Logger()) + relay.SetHost(host) + err = bcaster.Start(ctx) + s.Require().NoError(err) + err = relay.Start(ctx) + s.Require().NoError(err) + defer relay.Stop() + + groupKeyPairs, _, err := r.CreateMembershipList(100) + s.Require().NoError(err) + + var groupIDCommitments []r.IDCommitment + for _, c := range groupKeyPairs { + groupIDCommitments = append(groupIDCommitments, c.IDCommitment) + } + + rlnInstance, rootTracker, err := GetRLNInstanceAndRootTracker("root") + s.Require().NoError(err) + + // Set index + index := r.MembershipIndex(5) + + idCredential := groupKeyPairs[index] + groupManager, err := static.NewStaticGroupManager(groupIDCommitments, idCredential, index, rlnInstance, rootTracker, utils.Logger()) + s.Require().NoError(err) + + wakuRLNRelay := New(group_manager.Details{ + GroupManager: groupManager, + RootTracker: rootTracker, + RLN: rlnInstance, + }, timesource.NewDefaultClock(), prometheus.DefaultRegisterer, utils.Logger()) + + err = wakuRLNRelay.Start(ctx) + s.Require().NoError(err) + + // Test IdentityCredential + _, err = wakuRLNRelay.IdentityCredential() + s.Require().NoError(err) + + // Test MembershipIndex + i := wakuRLNRelay.MembershipIndex() + s.Require().Equal(i, uint(5)) + + // Test IsReady + _, err = wakuRLNRelay.IsReady(ctx) + s.Require().NoError(err) + + // Test Stop + err = wakuRLNRelay.Stop() + s.Require().NoError(err) + +} + +func (s *WakuRLNRelaySuite) TestEdgeCasesValidateMessage() { + groupKeyPairs, _, err := r.CreateMembershipList(10) + s.Require().NoError(err) + + var groupIDCommitments []r.IDCommitment + for _, c := range groupKeyPairs { + groupIDCommitments = append(groupIDCommitments, c.IDCommitment) + } + + index := r.MembershipIndex(5) + + // Create a RLN instance + rlnInstance, err := r.NewRLN() + s.Require().NoError(err) + + rootTracker := group_manager.NewMerkleRootTracker(acceptableRootWindowSize, rlnInstance) + + idCredential := groupKeyPairs[index] + groupManager, err := static.NewStaticGroupManager(groupIDCommitments, idCredential, index, rlnInstance, rootTracker, utils.Logger()) + s.Require().NoError(err) + + rlnRelay := &WakuRLNRelay{ + timesource: timesource.NewDefaultClock(), + Details: group_manager.Details{ + GroupManager: groupManager, + RootTracker: rootTracker, + RLN: rlnInstance, + }, + nullifierLog: NewNullifierLog(context.TODO(), utils.Logger()), + log: utils.Logger(), + metrics: newMetrics(prometheus.DefaultRegisterer), + } + + // Get the current epoch time + now := time.Now() + + err = groupManager.Start(context.Background()) + s.Require().NoError(err) + + // Valid message + wm1 := &pb.WakuMessage{Payload: []byte("Valid message")} + err = rlnRelay.AppendRLNProof(wm1, now) + s.Require().NoError(err) + + // Valid message with very old epoch + wm2 := &pb.WakuMessage{Payload: []byte("Invalid message")} + err = rlnRelay.AppendRLNProof(wm2, now.Add(-100*time.Second*time.Duration(r.EPOCH_UNIT_SECONDS))) + s.Require().NoError(err) + + // Test when no msg is provided + _, err = rlnRelay.ValidateMessage(nil, &now) + s.Require().Error(err) + + // Test valid message with no optionalTime provided + msgValidate1, err := rlnRelay.ValidateMessage(wm1, nil) + s.Require().NoError(err) + s.Require().Equal(validMessage, msgValidate1) + + // Test corrupted RateLimitProof case + wm1.RateLimitProof[1] = 'o' + _, err = rlnRelay.ValidateMessage(wm1, &now) + s.Require().Error(err) + + // Test message's epoch is too old + msgValidate2, err := rlnRelay.ValidateMessage(wm2, nil) + s.Require().NoError(err) + s.Require().Equal(invalidMessage, msgValidate2) + } diff --git a/waku/v2/protocol/rln/waku_rln_relay.go b/waku/v2/protocol/rln/waku_rln_relay.go index ad7bc368..39b71806 100644 --- a/waku/v2/protocol/rln/waku_rln_relay.go +++ b/waku/v2/protocol/rln/waku_rln_relay.go @@ -108,8 +108,9 @@ func (rlnRelay *WakuRLNRelay) ValidateMessage(msg *pb.WakuMessage, optionalTime msgProof, err := BytesToRateLimitProof(msg.RateLimitProof) if err != nil { - rlnRelay.log.Debug("invalid message: could not extract proof", zap.Error(err)) + rlnRelay.log.Debug("invalid message: could not extract proof") rlnRelay.metrics.RecordInvalidMessage(proofExtractionErr) + return validationError, err } if msgProof == nil { @@ -146,9 +147,9 @@ func (rlnRelay *WakuRLNRelay) ValidateMessage(msg *pb.WakuMessage, optionalTime start := time.Now() valid, err := rlnRelay.verifyProof(msg, msgProof) if err != nil { - rlnRelay.log.Debug("could not verify proof", zap.Error(err)) + rlnRelay.log.Debug("could not verify proof") rlnRelay.metrics.RecordError(proofVerificationErr) - return invalidMessage, nil + return validationError, err } rlnRelay.metrics.RecordProofVerification(time.Since(start)) @@ -270,7 +271,7 @@ func (rlnRelay *WakuRLNRelay) Validator( return false default: - log.Debug("unhandled validation result", zap.Int("validationResult", int(validationRes))) + log.Error("unhandled validation result", zap.Int("validationResult", int(validationRes))) return false } }