diff --git a/codex-plonky2-circuits/src/circuits/merkle_circuit.rs b/codex-plonky2-circuits/src/circuits/merkle_circuit.rs index 51f21cb..ea02916 100755 --- a/codex-plonky2-circuits/src/circuits/merkle_circuit.rs +++ b/codex-plonky2-circuits/src/circuits/merkle_circuit.rs @@ -1,10 +1,10 @@ -// Plonky2 Circuit implementation of "safe" merkle tree +// Plonky2 Circuit implementation of the Codex-specific "safe" merkle tree // consistent with the one in codex: // https://github.com/codex-storage/codex-storage-proofs-circuits/blob/master/circuit/codex/merkle.circom use plonky2::{ field::extension::Extendable, - hash::hash_types::{HashOutTarget, RichField, NUM_HASH_OUT_ELTS}, + hash::hash_types::{HashOutTarget, RichField}, iop::target::BoolTarget, plonk::{ circuit_builder::CircuitBuilder, @@ -16,7 +16,7 @@ use serde::{Deserialize, Serialize}; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; use crate::circuits::keyed_compress::key_compress_circuit; use crate::circuits::serialization::SerializableHashOutTarget; -use crate::circuits::utils::{add_assign_hash_out_target, mul_hash_out_target}; +use crate::circuits::utils::{add_assign_hash_out_target, mul_hash_out_target, select_hash}; use crate::Result; use crate::error::CircuitError; @@ -27,6 +27,11 @@ pub const KEY_ODD: u64 = 0x2; pub const KEY_ODD_AND_BOTTOM_LAYER: u64 = 0x3; /// Merkle tree targets representing the input to the circuit +/// * `leaf`: the leaf hash +/// * `path_bits`: the linear index of the leaf, in binary decomposition (least significant bit first) +/// * `last_bits`: the index of the last leaf (= nLeaves-1), in binary decomposition +/// * `mask_bits`: the bits of the mask `2^ceilingLog2(size) - 1` +/// * `merkle_path`: the Merkle inclusion proof (required hashes, starting from the leaf and ending near the root) #[derive(Clone)] pub struct MerkleTreeTargets{ pub leaf: HashOutTarget, @@ -42,8 +47,7 @@ pub struct MerkleProofTarget { pub path: Vec, } -/// Merkle tree circuit contains the functions for -/// building, proving and verifying the circuit. +/// contains the functions for reconstructing the Merkle root and returns it. #[derive(Clone)] pub struct MerkleTreeCircuit< F: RichField + Extendable + Poseidon2, @@ -89,7 +93,7 @@ impl< let one = builder.one(); let two = builder.two(); - // --- Basic checks on input sizes. + // --- Basic checks on input sizes ------- let path_len = targets.path_bits.len(); let proof_len = targets.merkle_path.path.len(); let mask_len = targets.mask_bits.len(); @@ -111,16 +115,31 @@ impl< return Err(CircuitError::PathBitsMaxDepthMismatch(path_len, max_depth)); } - // compute is_last - let mut is_last = vec![BoolTarget::new_unsafe(zero); max_depth + 1]; - is_last[max_depth] = BoolTarget::new_unsafe(one); // set isLast[max_depth] to 1 (true) + // in case of a singleton tree, we receive maskBits = [0,0,0,...,0] + // but what we really need is [1,0,0,0,...,0] + // because we always expect [1,1,...,1,0,0,...,0], + // we can just set the first entry to 1 and that should fix this issue. + let mut mask_bit_corrected: Vec = targets.mask_bits.clone(); + mask_bit_corrected[0] = builder.constant_bool(true); + + // ------ Compute is_last -------- + // Determine whether nodes from the path are last in their row and are odd, + // by computing which binary prefixes of the index are the same as the + // corresponding prefix of the last index. + // This is done in reverse bit order, because pathBits and lastBits have the + // least significant bit first. + let mut is_last: Vec = vec![builder.constant_bool(false); max_depth + 1]; + is_last[max_depth] = builder.constant_bool(true); for i in (0..max_depth).rev() { let eq_out = builder.is_equal(targets.path_bits[i].target , targets.last_bits[i].target); is_last[i] = builder.and( is_last[i + 1] , eq_out); } - let mut i: usize = 0; - for (&bit, &sibling) in targets.path_bits.iter().zip(&targets.merkle_path.path) { + // ------ Compute the sequence of hashes -------- + for i in 0..path_len { + + let bit = targets.path_bits[i]; + let sibling = targets.merkle_path.path[i]; // logic: we add KEY_BOTTOM_LAYER if i == 0, otherwise KEY_NONE. let bottom_key_val = if i == 0 { @@ -138,28 +157,23 @@ impl< let key = builder.add(bottom,odd); // select left and right based on path_bit - let mut left = vec![]; - let mut right = vec![]; - for j in 0..NUM_HASH_OUT_ELTS { - left.push( builder.select(bit, sibling.0.elements[j], state[i].elements[j])); - right.push( builder.select(bit, state[i].elements[j], sibling.0.elements[j])); - } + let left = select_hash(builder, bit, sibling.0, state[i]); + let right = select_hash(builder, bit,state[i], sibling.0); // Compress them with a keyed-hash function let combined_hash = key_compress_circuit:: (builder, - HashOutTarget::from_vec(left), - HashOutTarget::from_vec(right), + left, + right, key); state.push(combined_hash); - i += 1; } // select the right layer using the mask bits - let mut reconstructed_root = HashOutTarget::from_vec([builder.zero();4].to_vec()); + let mut reconstructed_root = HashOutTarget::from_vec([zero;4].to_vec()); for k in 0..max_depth { - let diff = builder.sub(targets.mask_bits[k].target, targets.mask_bits[k+1].target); + let diff = builder.sub(mask_bit_corrected[k].target, mask_bit_corrected[k+1].target); let mul_result = mul_hash_out_target(builder,&diff,&mut state[k+1]); add_assign_hash_out_target(builder,&mut reconstructed_root, &mul_result); } @@ -168,3 +182,4 @@ impl< } } +