diff --git a/waku/waku_rln_relay/group_manager/on_chain/group_manager.nim b/waku/waku_rln_relay/group_manager/on_chain/group_manager.nim index 65c5fd551..611e24fc2 100644 --- a/waku/waku_rln_relay/group_manager/on_chain/group_manager.nim +++ b/waku/waku_rln_relay/group_manager/on_chain/group_manager.nim @@ -46,8 +46,6 @@ contract(WakuRlnContract): proc MAX_MESSAGE_LIMIT(): UInt256 {.view.} # this function returns the merkleProof for a given index proc merkleProofElements(index: Uint256): seq[Uint256] {.view.} - # this function returns the current Merkle root of the on-chain Merkle tree - proc root(): UInt256 {.view.} type WakuRlnContractWithSender = Sender[WakuRlnContract] @@ -70,7 +68,7 @@ type validRootBuffer*: Deque[MerkleNode] # interval loop to shut down gracefully blockFetchingActive*: bool - merkleProofCache*: Table[Uint256, seq[Uint256]] + merkleProofsByIndex*: Table[Uint256, seq[Uint256]] type Witness* = object ## Represents the custom witness for generating an RLN proof identity_secret*: seq[byte] # Identity secret (private key) @@ -117,20 +115,15 @@ template retryWrapper( retryWrapper(res, RetryStrategy.new(), errStr, g.onFatalErrorAction): body -proc fetchMerkleRootFromContract(g: OnchainGroupManager): Future[UInt256] {.async.} = - ## Fetches the latest Merkle root from the smart contract - let contract = g.wakuRlnContract.get() - let rootInvocation = contract.root() # This returns a ContractInvocation - let root = - await rootInvocation.call() # Convert ContractInvocation to Future and await - return root - -proc cacheMerkleProofs*(g: OnchainGroupManager, index: Uint256) {.async.} = +proc fetchMerkleProof*(g: OnchainGroupManager, index: Uint256) {.async.} = ## Fetches and caches the Merkle proof elements for a given index - let merkleProofInvocation = g.wakuRlnContract.get().merkleProofElements(index) - let merkleProof = - await merkleProofInvocation.call() # Await the contract call and extract the result - g.merkleProofCache[index] = merkleProof + try: + let merkleProofInvocation = g.wakuRlnContract.get().merkleProofElements(index) + let merkleProof = await merkleProofInvocation.call() + # Await the contract call and extract the result + g.merkleProofsByIndex[index] = merkleProof + except CatchableError: + error "Failed to fetch merkle proof: " & getCurrentExceptionMsg() proc setMetadata*( g: OnchainGroupManager, lastProcessedBlock = none(BlockNumber) @@ -275,7 +268,7 @@ method generateProof*( epoch: Epoch, messageId: MessageId, rlnIdentifier = DefaultRlnIdentifier, -): GroupManagerResult[RateLimitProof] {.gcsafe, raises: [].} = +): Future[GroupManagerResult[RateLimitProof]] {.async, gcsafe, raises: [].} = ## Generates an RLN proof using the cached Merkle proof and custom witness # Ensure identity credentials and membership index are set if g.idCredentials.isNone(): @@ -286,10 +279,14 @@ method generateProof*( return err("user message limit is not set") # Retrieve the cached Merkle proof for the membership index - let index = g.membershipIndex.get() - let merkleProof = g.merkleProofCache.getOrDefault(stuint(uint64(index), 256)) - if merkleProof.len == 0: - return err("Merkle proof not found in cache") + let index = stuint(g.membershipIndex.get(), 256) + + if not g.merkleProofsByIndex.hasKey(index): + await g.fetchMerkleProof(index) + let merkle_proof = g.merkleProofsByIndex[index] + + if merkle_proof.len == 0: + return err("Merkle proof not found") # Prepare the witness let witness = Witness( @@ -311,7 +308,6 @@ method generateProof*( if not success: return err("Failed to generate proof") - # Parse the proof into a RateLimitProof object var proofValue = cast[ptr array[320, byte]](outputBuffer.`ptr`) let proofBytes: array[320, byte] = proofValue[] @@ -334,7 +330,8 @@ method generateProof*( discard zkproof.copyFrom(proofBytes[0 .. proofOffset - 1]) discard proofRoot.copyFrom(proofBytes[proofOffset .. rootOffset - 1]) - discard externalNullifier.copyFrom(proofBytes[rootOffset .. externalNullifierOffset - 1]) + discard + externalNullifier.copyFrom(proofBytes[rootOffset .. externalNullifierOffset - 1]) discard shareX.copyFrom(proofBytes[externalNullifierOffset .. shareXOffset - 1]) discard shareY.copyFrom(proofBytes[shareXOffset .. shareYOffset - 1]) discard nullifier.copyFrom(proofBytes[shareYOffset .. nullifierOffset - 1]) @@ -473,6 +470,11 @@ proc handleEvents( rateCommitments = rateCommitments, toRemoveIndices = removalIndices, ) + + for i in 0 ..< rateCommitments.len: + let index = startIndex + MembershipIndex(i) + await g.fetchMerkleProof(stuint(index, 256)) + g.latestIndex = startIndex + MembershipIndex(rateCommitments.len) trace "new members added to the Merkle tree", commitments = rateCommitments.mapIt(it.inHex) @@ -493,6 +495,12 @@ proc handleRemovedEvents( if members.anyIt(it[1]): numRemovedBlocks += 1 + # Remove cached merkleProof for each removed member + for member in members: + if member[1]: # Check if the member is removed + let index = member[0].index + g.merkleProofsByIndex.del(stuint(index, 256)) + await g.backfillRootQueue(numRemovedBlocks) proc getAndHandleEvents(