chore: RLN tests coverage improvement (#1003)

This commit is contained in:
Roman Zajic 2024-01-11 22:52:52 +08:00 committed by GitHub
parent 82fc800b08
commit 75f975ce7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 171 additions and 6 deletions

View File

@ -170,6 +170,11 @@ func (s *WakuRLNRelaySuite) TestValidateMessage() {
groupKeyPairs, _, err := r.CreateMembershipList(100) groupKeyPairs, _, err := r.CreateMembershipList(100)
s.Require().NoError(err) s.Require().NoError(err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pubSubTopic := "/waku/2/go/rln/test"
var groupIDCommitments []r.IDCommitment var groupIDCommitments []r.IDCommitment
for _, c := range groupKeyPairs { for _, c := range groupKeyPairs {
groupIDCommitments = append(groupIDCommitments, c.IDCommitment) groupIDCommitments = append(groupIDCommitments, c.IDCommitment)
@ -191,6 +196,7 @@ func (s *WakuRLNRelaySuite) TestValidateMessage() {
s.Require().NoError(err) s.Require().NoError(err)
rlnRelay := &WakuRLNRelay{ rlnRelay := &WakuRLNRelay{
timesource: timesource.NewDefaultClock(),
Details: group_manager.Details{ Details: group_manager.Details{
GroupManager: groupManager, GroupManager: groupManager,
RootTracker: rootTracker, RootTracker: rootTracker,
@ -201,14 +207,17 @@ func (s *WakuRLNRelaySuite) TestValidateMessage() {
metrics: newMetrics(prometheus.DefaultRegisterer), metrics: newMetrics(prometheus.DefaultRegisterer),
} }
//get the current epoch time // get the current epoch time
now := time.Now() now := time.Now()
err = groupManager.Start(context.Background()) err = groupManager.Start(context.Background())
s.Require().NoError(err) s.Require().NoError(err)
// create some messages from the same peer and append rln proof to them, except wm4 // Get Validator func instance
validator := rlnRelay.Validator(nil)
s.Require().NotNil(validator)
// create some messages from the same peer and append rln proof to them, except wm4
wm1 := &pb.WakuMessage{Payload: []byte("Valid message")} wm1 := &pb.WakuMessage{Payload: []byte("Valid message")}
err = rlnRelay.AppendRLNProof(wm1, now) err = rlnRelay.AppendRLNProof(wm1, now)
s.Require().NoError(err) s.Require().NoError(err)
@ -245,4 +254,159 @@ func (s *WakuRLNRelaySuite) TestValidateMessage() {
s.Require().Equal(spamMessage, msgValidate2) s.Require().Equal(spamMessage, msgValidate2)
s.Require().Equal(validMessage, msgValidate3) s.Require().Equal(validMessage, msgValidate3)
s.Require().Equal(invalidMessage, msgValidate4) s.Require().Equal(invalidMessage, msgValidate4)
// Create valid message and check it with validator func
wm10 := &pb.WakuMessage{Payload: []byte("Valid message 2")}
err = rlnRelay.AppendRLNProof(wm10, now.Add(2*time.Second*time.Duration(r.EPOCH_UNIT_SECONDS)))
s.Require().NoError(err)
isValid := validator(ctx, wm10, pubSubTopic)
s.Require().True(isValid)
// Detect spam message with validator func
wm11 := &pb.WakuMessage{Payload: []byte("Spam 2")}
err = rlnRelay.AppendRLNProof(wm11, now.Add(2*time.Second*time.Duration(r.EPOCH_UNIT_SECONDS)))
s.Require().NoError(err)
isValid = validator(ctx, wm11, pubSubTopic)
s.Require().False(isValid)
// Detect invalid message (no proof) with validator func
wm12 := &pb.WakuMessage{Payload: []byte("Invalid message 2")}
isValid = validator(ctx, wm12, pubSubTopic)
s.Require().False(isValid)
}
func (s *WakuRLNRelaySuite) TestRLNRelayGetters() {
port, err := tests.FindFreePort(s.T(), "", 5)
s.Require().NoError(err)
ctx := context.Background()
host, err := tests.MakeHost(ctx, port, rand.Reader)
s.Require().NoError(err)
bcaster := relay.NewBroadcaster(1024)
relay := relay.NewWakuRelay(bcaster, 0, timesource.NewDefaultClock(), prometheus.DefaultRegisterer, utils.Logger())
relay.SetHost(host)
err = bcaster.Start(ctx)
s.Require().NoError(err)
err = relay.Start(ctx)
s.Require().NoError(err)
defer relay.Stop()
groupKeyPairs, _, err := r.CreateMembershipList(100)
s.Require().NoError(err)
var groupIDCommitments []r.IDCommitment
for _, c := range groupKeyPairs {
groupIDCommitments = append(groupIDCommitments, c.IDCommitment)
}
rlnInstance, rootTracker, err := GetRLNInstanceAndRootTracker("root")
s.Require().NoError(err)
// Set index
index := r.MembershipIndex(5)
idCredential := groupKeyPairs[index]
groupManager, err := static.NewStaticGroupManager(groupIDCommitments, idCredential, index, rlnInstance, rootTracker, utils.Logger())
s.Require().NoError(err)
wakuRLNRelay := New(group_manager.Details{
GroupManager: groupManager,
RootTracker: rootTracker,
RLN: rlnInstance,
}, timesource.NewDefaultClock(), prometheus.DefaultRegisterer, utils.Logger())
err = wakuRLNRelay.Start(ctx)
s.Require().NoError(err)
// Test IdentityCredential
_, err = wakuRLNRelay.IdentityCredential()
s.Require().NoError(err)
// Test MembershipIndex
i := wakuRLNRelay.MembershipIndex()
s.Require().Equal(i, uint(5))
// Test IsReady
_, err = wakuRLNRelay.IsReady(ctx)
s.Require().NoError(err)
// Test Stop
err = wakuRLNRelay.Stop()
s.Require().NoError(err)
}
func (s *WakuRLNRelaySuite) TestEdgeCasesValidateMessage() {
groupKeyPairs, _, err := r.CreateMembershipList(10)
s.Require().NoError(err)
var groupIDCommitments []r.IDCommitment
for _, c := range groupKeyPairs {
groupIDCommitments = append(groupIDCommitments, c.IDCommitment)
}
index := r.MembershipIndex(5)
// Create a RLN instance
rlnInstance, err := r.NewRLN()
s.Require().NoError(err)
rootTracker := group_manager.NewMerkleRootTracker(acceptableRootWindowSize, rlnInstance)
idCredential := groupKeyPairs[index]
groupManager, err := static.NewStaticGroupManager(groupIDCommitments, idCredential, index, rlnInstance, rootTracker, utils.Logger())
s.Require().NoError(err)
rlnRelay := &WakuRLNRelay{
timesource: timesource.NewDefaultClock(),
Details: group_manager.Details{
GroupManager: groupManager,
RootTracker: rootTracker,
RLN: rlnInstance,
},
nullifierLog: NewNullifierLog(context.TODO(), utils.Logger()),
log: utils.Logger(),
metrics: newMetrics(prometheus.DefaultRegisterer),
}
// Get the current epoch time
now := time.Now()
err = groupManager.Start(context.Background())
s.Require().NoError(err)
// Valid message
wm1 := &pb.WakuMessage{Payload: []byte("Valid message")}
err = rlnRelay.AppendRLNProof(wm1, now)
s.Require().NoError(err)
// Valid message with very old epoch
wm2 := &pb.WakuMessage{Payload: []byte("Invalid message")}
err = rlnRelay.AppendRLNProof(wm2, now.Add(-100*time.Second*time.Duration(r.EPOCH_UNIT_SECONDS)))
s.Require().NoError(err)
// Test when no msg is provided
_, err = rlnRelay.ValidateMessage(nil, &now)
s.Require().Error(err)
// Test valid message with no optionalTime provided
msgValidate1, err := rlnRelay.ValidateMessage(wm1, nil)
s.Require().NoError(err)
s.Require().Equal(validMessage, msgValidate1)
// Test corrupted RateLimitProof case
wm1.RateLimitProof[1] = 'o'
_, err = rlnRelay.ValidateMessage(wm1, &now)
s.Require().Error(err)
// Test message's epoch is too old
msgValidate2, err := rlnRelay.ValidateMessage(wm2, nil)
s.Require().NoError(err)
s.Require().Equal(invalidMessage, msgValidate2)
} }

View File

@ -108,8 +108,9 @@ func (rlnRelay *WakuRLNRelay) ValidateMessage(msg *pb.WakuMessage, optionalTime
msgProof, err := BytesToRateLimitProof(msg.RateLimitProof) msgProof, err := BytesToRateLimitProof(msg.RateLimitProof)
if err != nil { if err != nil {
rlnRelay.log.Debug("invalid message: could not extract proof", zap.Error(err)) rlnRelay.log.Debug("invalid message: could not extract proof")
rlnRelay.metrics.RecordInvalidMessage(proofExtractionErr) rlnRelay.metrics.RecordInvalidMessage(proofExtractionErr)
return validationError, err
} }
if msgProof == nil { if msgProof == nil {
@ -146,9 +147,9 @@ func (rlnRelay *WakuRLNRelay) ValidateMessage(msg *pb.WakuMessage, optionalTime
start := time.Now() start := time.Now()
valid, err := rlnRelay.verifyProof(msg, msgProof) valid, err := rlnRelay.verifyProof(msg, msgProof)
if err != nil { if err != nil {
rlnRelay.log.Debug("could not verify proof", zap.Error(err)) rlnRelay.log.Debug("could not verify proof")
rlnRelay.metrics.RecordError(proofVerificationErr) rlnRelay.metrics.RecordError(proofVerificationErr)
return invalidMessage, nil return validationError, err
} }
rlnRelay.metrics.RecordProofVerification(time.Since(start)) rlnRelay.metrics.RecordProofVerification(time.Since(start))
@ -270,7 +271,7 @@ func (rlnRelay *WakuRLNRelay) Validator(
return false return false
default: default:
log.Debug("unhandled validation result", zap.Int("validationResult", int(validationRes))) log.Error("unhandled validation result", zap.Int("validationResult", int(validationRes)))
return false return false
} }
} }