fix: consolidate util functions

This commit is contained in:
Arseniy Klempner 2025-12-15 16:26:55 -08:00
parent 80bf606270
commit 367ea4a0a8
No known key found for this signature in database
GPG Key ID: 51653F18863BD24B
4 changed files with 35 additions and 131 deletions

View File

@ -5,7 +5,7 @@ import { RLNInstance } from "./rln.js";
import { BytesUtils } from "./utils/index.js"; import { BytesUtils } from "./utils/index.js";
import { import {
calculateRateCommitment, calculateRateCommitment,
extractPathDirectionsFromProof, getPathDirectionsFromIndex,
MERKLE_TREE_DEPTH, MERKLE_TREE_DEPTH,
reconstructMerkleRoot reconstructMerkleRoot
} from "./utils/merkle.js"; } from "./utils/merkle.js";
@ -15,24 +15,23 @@ describe("RLN Proof Integration Tests", function () {
this.timeout(30000); this.timeout(30000);
it("validate stored merkle proof data", function () { it("validate stored merkle proof data", function () {
// Convert stored merkle proof strings to bigints
const merkleProof = TEST_KEYSTORE_DATA.merkleProof.map((p) => BigInt(p)); const merkleProof = TEST_KEYSTORE_DATA.merkleProof.map((p) => BigInt(p));
expect(merkleProof).to.be.an("array"); expect(merkleProof).to.be.an("array");
expect(merkleProof).to.have.lengthOf(MERKLE_TREE_DEPTH); // RLN uses fixed depth merkle tree expect(merkleProof).to.have.lengthOf(MERKLE_TREE_DEPTH);
merkleProof.forEach((element, i) => { for (let i = 0; i < merkleProof.length; i++) {
const element = merkleProof[i];
expect(element).to.be.a( expect(element).to.be.a(
"bigint", "bigint",
`Proof element ${i} should be a bigint` `Proof element ${i} should be a bigint`
); );
expect(element).to.not.equal(0n, `Proof element ${i} should not be zero`); expect(element).to.not.equal(0n, `Proof element ${i} should not be zero`);
}); }
}); });
it("should generate a valid RLN proof", async function () { it("should generate a valid RLN proof", async function () {
const rlnInstance = await RLNInstance.create(); const rlnInstance = await RLNInstance.create();
// Load credential from test keystore
const keystore = Keystore.fromString(TEST_KEYSTORE_DATA.keystoreJson); const keystore = Keystore.fromString(TEST_KEYSTORE_DATA.keystoreJson);
if (!keystore) { if (!keystore) {
throw new Error("Failed to load test keystore"); throw new Error("Failed to load test keystore");
@ -53,14 +52,7 @@ describe("RLN Proof Integration Tests", function () {
const rateCommitment = calculateRateCommitment(idCommitment, rateLimit); const rateCommitment = calculateRateCommitment(idCommitment, rateLimit);
const proofElementIndexes = extractPathDirectionsFromProof( const proofElementIndexes = getPathDirectionsFromIndex(membershipIndex);
merkleProof,
rateCommitment,
merkleRoot
);
if (!proofElementIndexes) {
throw new Error("Failed to extract proof element indexes");
}
expect(proofElementIndexes).to.have.lengthOf(MERKLE_TREE_DEPTH); expect(proofElementIndexes).to.have.lengthOf(MERKLE_TREE_DEPTH);
@ -82,7 +74,9 @@ describe("RLN Proof Integration Tests", function () {
Number(membershipIndex), Number(membershipIndex),
new Date(), new Date(),
credential.identity.IDSecretHash, credential.identity.IDSecretHash,
merkleProof.map((proof) => BytesUtils.fromBigInt(proof, 32, "little")), merkleProof.map((element) =>
BytesUtils.bytes32FromBigInt(element, "little")
),
proofElementIndexes.map((index) => proofElementIndexes.map((index) =>
BytesUtils.writeUIntLE(new Uint8Array(1), index, 0, 1) BytesUtils.writeUIntLE(new Uint8Array(1), index, 0, 1)
), ),
@ -94,7 +88,7 @@ describe("RLN Proof Integration Tests", function () {
BytesUtils.writeUIntLE(new Uint8Array(8), testMessage.length, 0, 8), BytesUtils.writeUIntLE(new Uint8Array(8), testMessage.length, 0, 8),
testMessage, testMessage,
proof, proof,
[BytesUtils.fromBigInt(merkleRoot, 32, "little")] [BytesUtils.bytes32FromBigInt(merkleRoot, "little")]
); );
expect(isValid).to.be.true; expect(isValid).to.be.true;
}); });

View File

@ -50,30 +50,34 @@ export class BytesUtils {
} }
/** /**
* Convert a BigInt to a Uint8Array with configurable output endianness * Convert a BigInt to a bytes32 (32-byte Uint8Array)
* @param value - The BigInt to convert * @param value - The BigInt to convert (must fit in 32 bytes)
* @param byteLength - The desired byte length of the output (optional, auto-calculated if not provided)
* @param outputEndianness - Endianness of the output bytes ('big' or 'little') * @param outputEndianness - Endianness of the output bytes ('big' or 'little')
* @returns Uint8Array representation of the BigInt * @returns 32-byte Uint8Array representation of the BigInt
*/ */
public static fromBigInt( public static bytes32FromBigInt(
value: bigint, value: bigint,
byteLength: number,
outputEndianness: "big" | "little" = "little" outputEndianness: "big" | "little" = "little"
): Uint8Array { ): Uint8Array {
if (value < 0n) { if (value < 0n) {
throw new Error("Cannot convert negative BigInt to bytes"); throw new Error("Cannot convert negative BigInt to bytes");
} }
if (value === 0n) { if (value >> 256n !== 0n) {
return new Uint8Array(byteLength); throw new Error(
`BigInt value is too large to fit in 32 bytes (max bit length: 256)`
);
} }
const result = new Uint8Array(byteLength); if (value === 0n) {
return new Uint8Array(32);
}
const result = new Uint8Array(32);
let workingValue = value; let workingValue = value;
// Extract bytes in big-endian order // Extract bytes in big-endian order
for (let i = byteLength - 1; i >= 0; i--) { for (let i = 31; i >= 0; i--) {
result[i] = Number(workingValue & 0xffn); result[i] = Number(workingValue & 0xffn);
workingValue = workingValue >> 8n; workingValue = workingValue >> 8n;
} }

View File

@ -26,64 +26,23 @@ export function reconstructMerkleRoot(
); );
} }
let currentValue = leafValue; let currentValue = BytesUtils.bytes32FromBigInt(leafValue);
// Process each level of the tree (0 to MERKLE_TREE_DEPTH-1)
for (let level = 0; level < MERKLE_TREE_DEPTH; level++) { for (let level = 0; level < MERKLE_TREE_DEPTH; level++) {
// Check if bit `level` is set in the leaf index
const bit = (leafIndex >> BigInt(level)) & 1n; const bit = (leafIndex >> BigInt(level)) & 1n;
// Convert bigints to Uint8Array for hashing const proofBytes = BytesUtils.bytes32FromBigInt(proof[level]);
const currentBytes = bigIntToBytes32(currentValue);
const proofBytes = bigIntToBytes32(proof[level]);
let hashResult: Uint8Array;
if (bit === 0n) { if (bit === 0n) {
// Current node is a left child: hash(current, proof[level]) // Current node is a left child: hash(current, proof[level])
hashResult = poseidonHash(currentBytes, proofBytes); currentValue = poseidonHash(currentValue, proofBytes);
} else { } else {
// Current node is a right child: hash(proof[level], current) // Current node is a right child: hash(proof[level], current)
hashResult = poseidonHash(proofBytes, currentBytes); currentValue = poseidonHash(proofBytes, currentValue);
}
// Convert hash result back to bigint for next iteration
currentValue = BytesUtils.toBigInt(hashResult, "little");
}
return currentValue;
}
/**
* Extracts index information from a Merkle proof by attempting to reconstruct
* the root with different possible indices and comparing against the expected root
*
* @param proof - Array of MERKLE_TREE_DEPTH bigint elements representing the Merkle proof
* @param leafValue - The value of the leaf (typically the rate commitment)
* @param expectedRoot - The expected root to match against
* @param maxIndex - Maximum index to try (default: 2^MERKLE_TREE_DEPTH - 1)
* @returns The index that produces the expected root, or null if not found
*/
function extractIndexFromProof(
proof: readonly bigint[],
leafValue: bigint,
expectedRoot: bigint,
maxIndex: bigint = (1n << BigInt(MERKLE_TREE_DEPTH)) - 1n
): bigint | null {
// Try different indices to see which one produces the expected root
for (let index = 0n; index <= maxIndex; index++) {
try {
const reconstructedRoot = reconstructMerkleRoot(proof, index, leafValue);
if (reconstructedRoot === expectedRoot) {
return index;
}
} catch (error) {
// Continue trying other indices if reconstruction fails
continue;
} }
} }
return null; return BytesUtils.toBigInt(currentValue, "little");
} }
/** /**
@ -98,65 +57,13 @@ export function calculateRateCommitment(
idCommitment: bigint, idCommitment: bigint,
rateLimit: bigint rateLimit: bigint
): bigint { ): bigint {
const idBytes = bigIntToBytes32(idCommitment); const idBytes = BytesUtils.bytes32FromBigInt(idCommitment);
const rateLimitBytes = bigIntToBytes32(rateLimit); const rateLimitBytes = BytesUtils.bytes32FromBigInt(rateLimit);
const hashResult = poseidonHash(idBytes, rateLimitBytes); const hashResult = poseidonHash(idBytes, rateLimitBytes);
return BytesUtils.toBigInt(hashResult, "little"); return BytesUtils.toBigInt(hashResult, "little");
} }
/**
* Converts a bigint to a 32-byte Uint8Array in little-endian format
*
* @param value - The bigint value to convert
* @returns 32-byte Uint8Array representation
*/
function bigIntToBytes32(value: bigint): Uint8Array {
const bytes = new Uint8Array(32);
let temp = value;
for (let i = 0; i < 32; i++) {
bytes[i] = Number(temp & 0xffn);
temp >>= 8n;
}
return bytes;
}
/**
* Extracts the path direction bits from a Merkle proof by finding the leaf index
* that produces the expected root, then converting that index to path directions
*
* @param proof - Array of MERKLE_TREE_DEPTH bigint elements representing the Merkle proof
* @param leafValue - The value of the leaf (typically the rate commitment)
* @param expectedRoot - The expected root to match against
* @param maxIndex - Maximum index to try (default: 2^MERKLE_TREE_DEPTH - 1)
* @returns Array of MERKLE_TREE_DEPTH numbers (0 or 1) representing path directions, or null if no valid path found
* - 0 means the node is a left child (hash order: current, sibling)
* - 1 means the node is a right child (hash order: sibling, current)
*/
export function extractPathDirectionsFromProof(
proof: readonly bigint[],
leafValue: bigint,
expectedRoot: bigint,
maxIndex: bigint = (1n << BigInt(MERKLE_TREE_DEPTH)) - 1n
): number[] | null {
// First, find the leaf index that produces the expected root
const leafIndex = extractIndexFromProof(
proof,
leafValue,
expectedRoot,
maxIndex
);
if (leafIndex === null) {
return null;
}
// Convert the leaf index to path directions
return getPathDirectionsFromIndex(leafIndex);
}
/** /**
* Converts a leaf index to an array of path direction bits * Converts a leaf index to an array of path direction bits
* *
@ -165,7 +72,7 @@ export function extractPathDirectionsFromProof(
* - 0 means the node is a left child (hash order: current, sibling) * - 0 means the node is a left child (hash order: current, sibling)
* - 1 means the node is a right child (hash order: sibling, current) * - 1 means the node is a right child (hash order: sibling, current)
*/ */
function getPathDirectionsFromIndex(leafIndex: bigint): number[] { export function getPathDirectionsFromIndex(leafIndex: bigint): number[] {
const pathDirections: number[] = []; const pathDirections: number[] = [];
// For each level (0 to MERKLE_TREE_DEPTH-1), extract the bit that determines left/right // For each level (0 to MERKLE_TREE_DEPTH-1), extract the bit that determines left/right

View File

@ -41,7 +41,7 @@ export class Zerokit {
idSecretHash: Uint8Array, idSecretHash: Uint8Array,
pathElements: Uint8Array[], pathElements: Uint8Array[],
identityPathIndex: Uint8Array[], identityPathIndex: Uint8Array[],
x: Uint8Array, msg: Uint8Array,
epoch: Uint8Array, epoch: Uint8Array,
rateLimit: number, rateLimit: number,
messageId: number // number of message sent by the user in this epoch messageId: number // number of message sent by the user in this epoch
@ -69,6 +69,7 @@ export class Zerokit {
// We assume that each identity path index is already in little-endian format // We assume that each identity path index is already in little-endian format
identityPathIndexBytes.set(identityPathIndex[i], 8 + i * 1); identityPathIndexBytes.set(identityPathIndex[i], 8 + i * 1);
} }
const x = sha256(msg);
return BytesUtils.concatenate( return BytesUtils.concatenate(
idSecretHash, idSecretHash,
BytesUtils.writeUIntLE(new Uint8Array(32), rateLimit, 0, 32), BytesUtils.writeUIntLE(new Uint8Array(32), rateLimit, 0, 32),
@ -108,13 +109,11 @@ export class Zerokit {
); );
} }
const x = sha256(msg);
const serializedWitness = await this.serializeWitness( const serializedWitness = await this.serializeWitness(
idSecretHash, idSecretHash,
pathElements, pathElements,
identityPathIndex, identityPathIndex,
x, msg,
epoch, epoch,
rateLimit, rateLimit,
messageId messageId