diff --git a/waku/waku_rln_relay/group_manager/on_chain/group_manager.nim b/waku/waku_rln_relay/group_manager/on_chain/group_manager.nim index ac608da59..3180cded9 100644 --- a/waku/waku_rln_relay/group_manager/on_chain/group_manager.nim +++ b/waku/waku_rln_relay/group_manager/on_chain/group_manager.nim @@ -106,10 +106,11 @@ proc seqToField*(s: seq[byte]): array[32, byte] = for i in 0 ..< len: result[i] = s[i] -proc uint64ToIndex*(value: uint64, numBits: int = 64): seq[uint8] = - result = newSeq[uint8](numBits) - for i in 0 ..< numBits: - result[i] = uint8((value shr i) and 1) +# Convert membership index to 20-bit LSB-first binary sequence +proc uint64ToIndex(index: MembershipIndex, depth: int): seq[byte] = + result = newSeq[byte](depth) + for i in 0 ..< depth: + result[i] = byte((index shr i) and 1) # LSB-first bit decomposition proc fetchMerkleProofElements*( g: OnchainGroupManager @@ -325,6 +326,30 @@ method withdrawBatch*( ): Future[void] {.async: (raises: [Exception]).} = initializedGuard(g) +proc poseidonHash( + g: OnchainGroupManager, elements: seq[byte], bits: seq[byte] +): GroupManagerResult[array[32, byte]] = + # Compute leaf hash from idCommitment + let leafHashRes = poseidon(@[g.idCredentials.get().idCommitment]) + if leafHashRes.isErr(): + return err("Failed to compute leaf hash: " & leafHashRes.error) + + var hash = leafHashRes.get() + for i in 0 ..< bits.len: + let sibling = elements[i * 32 .. (i + 1) * 32 - 1] + + let hashRes = + if bits[i] == 0: + poseidon(@[@hash, sibling]) + else: + poseidon(@[sibling, @hash]) + + hash = hashRes.valueOr: + return err("Failed to compute poseidon hash: " & error) + hash = hashRes.get() + + return ok(hash) + method generateProof*( g: OnchainGroupManager, data: seq[byte], @@ -372,12 +397,40 @@ method generateProof*( if (g.merkleProofCache.len mod 32) != 0: return err("Invalid merkle proof cache length") - g.merkleProofCache.reverse() - var i = 0 - while i + 31 < g.merkleProofCache.len: - for j in countdown(31, 0): - path_elements.add(g.merkleProofCache[i+j]) - i += 32 + # Proposed fix using index bits + let identity_path_index = uint64ToIndex(g.membershipIndex.get(), 20) + # 20-bit for depth 20 + var pathIndex = 0 + for i in 0 ..< g.merkleProofCache.len div 32: + let bit = identity_path_index[i] + let chunk = g.merkleProofCache[i * 32 .. (i + 1) * 32 - 1] + path_elements.add( + if bit == 0: + chunk.reversed() + else: + chunk + ) + + # After proof generation, verify against contract root + + var generatedRoot: array[32, byte] + try: + let generatedRootRes = g.poseidonHash(path_elements, identity_path_index) + generatedRoot = generatedRootRes.get() + except CatchableError: + error "Failed to update roots", error = getCurrentExceptionMsg() + + var contractRoot: array[32, byte] + try: + let contractRootRes = waitFor g.fetchMerkleRoot() + if contractRootRes.isErr(): + return err("Failed to fetch Merkle proof: " & contractRootRes.error) + contractRoot = UInt256ToField(contractRootRes.get()) + except CatchableError: + error "Failed to update roots", error = getCurrentExceptionMsg() + + if contractRoot != generatedRoot: + return err("Root mismatch: contract=" & $contractRoot & " local=" & $generatedRoot) debug "--- pathElements ---", before = g.merkleProofCache, @@ -385,9 +438,6 @@ method generateProof*( before_len = g.merkleProofCache.len, after_len = path_elements.len - let index_len = int(g.merkleProofCache.len / 32) - let identity_path_index = uint64ToIndex(uint64(g.membershipIndex.get()), index_len) - debug "--- identityPathIndex ---", before = g.membershipIndex.get(), after = identity_path_index,