diff --git a/codex-plonky2-circuits/src/circuits/prove_single_cell.rs b/codex-plonky2-circuits/src/circuits/prove_single_cell.rs index 5fe9b7c..bed9629 100644 --- a/codex-plonky2-circuits/src/circuits/prove_single_cell.rs +++ b/codex-plonky2-circuits/src/circuits/prove_single_cell.rs @@ -282,6 +282,11 @@ impl< let block_last_bits = (0..BOT_DEPTH).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); let slot_last_bits = (0..(depth-BOT_DEPTH)).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); + // mask bits (binary decomposition of last_index = nleaves - 1) + let block_mask_bits = (0..BOT_DEPTH+1).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); + let slot_mask_bits = (0..(depth-BOT_DEPTH)+1).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); + + // Merkle path (sibling hashes from leaf to root) let mut block_merkle_path = MerkleProofTarget { path: (0..BOT_DEPTH).map(|_| builder.add_virtual_hash()).collect(), @@ -297,6 +302,7 @@ impl< leaf: leaf_hash, path_bits:block_path_bits, last_bits: block_last_bits, + mask_bits: block_mask_bits, merkle_path: block_merkle_path, }; @@ -308,6 +314,7 @@ impl< leaf: block_root, path_bits:slot_path_bits, last_bits:slot_last_bits, + mask_bits:slot_mask_bits, merkle_path:slot_merkle_path, }; diff --git a/codex-plonky2-circuits/src/circuits/safe_tree_circuit.rs b/codex-plonky2-circuits/src/circuits/safe_tree_circuit.rs index d74a935..1040cc4 100644 --- a/codex-plonky2-circuits/src/circuits/safe_tree_circuit.rs +++ b/codex-plonky2-circuits/src/circuits/safe_tree_circuit.rs @@ -16,10 +16,11 @@ use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, VerifierCircuitDa use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig}; use plonky2::plonk::proof::{Proof, ProofWithPublicInputs}; use std::marker::PhantomData; +use std::os::macos::raw::stat; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; use serde::Serialize; use crate::circuits::keyed_compress::key_compress_circuit; -use crate::circuits::params::HF; +use crate::circuits::params::{HF, MAX_DEPTH}; use crate::circuits::utils::usize_to_bits_le_padded; use crate::merkle_tree::merkle_safe::{MerkleTree, MerkleProofTarget}; @@ -35,6 +36,7 @@ pub struct MerkleTreeTargets{ pub leaf: HashOutTarget, pub path_bits: Vec, pub last_bits: Vec, + pub mask_bits: Vec, pub merkle_path: MerkleProofTarget, } @@ -65,14 +67,17 @@ impl< let leaf = builder.add_virtual_hash(); // path bits (binary decomposition of leaf_index) - let path_bits = (0..depth).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); + let path_bits = (0..MAX_DEPTH).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); // last bits (binary decomposition of last_index = nleaves - 1) - let last_bits = (0..depth).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); + let last_bits = (0..MAX_DEPTH).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); + + // last bits (binary decomposition of last_index = nleaves - 1) + let mask_bits = (0..MAX_DEPTH+1).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); // Merkle path (sibling hashes from leaf to root) let merkle_path = MerkleProofTarget { - path: (0..depth).map(|_| builder.add_virtual_hash()).collect(), + path: (0..MAX_DEPTH).map(|_| builder.add_virtual_hash()).collect(), }; // create MerkleTreeTargets struct @@ -80,11 +85,12 @@ impl< leaf, path_bits, last_bits, + mask_bits, merkle_path, }; // Add Merkle proof verification constraints to the circuit - let expected_root_target = Self::reconstruct_merkle_root_circuit(builder, &mut targets); + let expected_root_target = Self::reconstruct_merkle_root_circuit_with_mask(builder, &mut targets); // Return MerkleTreeTargets (targets, expected_root_target) @@ -112,23 +118,36 @@ impl< pw.set_hash_target(targets.leaf, leaf_hash); // Convert `leaf_index` to binary bits and assign as path_bits - let path_bits = usize_to_bits_le_padded(leaf_index, depth); + let path_bits = usize_to_bits_le_padded(leaf_index, MAX_DEPTH); for (i, bit) in path_bits.iter().enumerate() { pw.set_bool_target(targets.path_bits[i], *bit); } // get `last_index` (nleaves - 1) in binary bits and assign let last_index = nleaves - 1; - let last_bits = usize_to_bits_le_padded(last_index, depth); + let last_bits = usize_to_bits_le_padded(last_index, MAX_DEPTH); for (i, bit) in last_bits.iter().enumerate() { pw.set_bool_target(targets.last_bits[i], *bit); } + // get mask bits + let mask_bits = usize_to_bits_le_padded(last_index, MAX_DEPTH+1); + for (i, bit) in mask_bits.iter().enumerate() { + pw.set_bool_target(targets.mask_bits[i], *bit); + } + // assign the Merkle path (sibling hashes) to the targets - for (i, sibling_hash) in proof.path.iter().enumerate() { + for i in 0..MAX_DEPTH { + if i>=proof.path.len() { + for j in 0..NUM_HASH_OUT_ELTS { + pw.set_target(targets.merkle_path.path[i].elements[j], F::ZERO); + } + continue + } // This is a bit hacky because it should be HashOutTarget, but it is H:Hash // pw.set_hash_target(targets.merkle_path.path[i],sibling_hash); // TODO: fix this HashOutTarget later + let sibling_hash = proof.path[i]; let sibling_hash_out = sibling_hash.to_vec(); for j in 0..sibling_hash_out.len() { pw.set_target(targets.merkle_path.path[i].elements[j], sibling_hash_out[j]); @@ -191,6 +210,89 @@ impl< return state; } + + /// takes the params from the targets struct + /// outputs the reconstructed merkle root + /// this one uses the mask bits + pub fn reconstruct_merkle_root_circuit_with_mask( + builder: &mut CircuitBuilder, + targets: &mut MerkleTreeTargets, + ) -> HashOutTarget { + let max_depth = targets.path_bits.len(); + let mut state: Vec = Vec::with_capacity(max_depth+1); + state.push(targets.leaf); + let zero = builder.zero(); + let one = builder.one(); + let two = builder.two(); + debug_assert_eq!(targets.path_bits.len(), targets.merkle_path.path.len()); + + // compute is_last + let mut is_last = vec![BoolTarget::new_unsafe(zero); max_depth + 1]; + is_last[max_depth] = BoolTarget::new_unsafe(one); // set isLast[max_depth] to 1 (true) + for i in (0..max_depth).rev() { + let eq_out = builder.is_equal(targets.path_bits[i].target , targets.last_bits[i].target); + is_last[i] = builder.and( is_last[i + 1] , eq_out); + } + + let mut i: usize = 0; + for (&bit, &sibling) in targets.path_bits.iter().zip(&targets.merkle_path.path) { + debug_assert_eq!(sibling.elements.len(), NUM_HASH_OUT_ELTS); + + let bottom = if i == 0 { + builder.constant(F::from_canonical_u64(KEY_BOTTOM_LAYER)) + } else { + builder.constant(F::from_canonical_u64(KEY_NONE)) + }; + + // compute: odd = isLast[i] * (1-pathBits[i]); + // compute: key = bottom + 2*odd + let mut odd = builder.sub(one, targets.path_bits[i].target); + odd = builder.mul(is_last[i].target, odd); + odd = builder.mul(two, odd); + let key = builder.add(bottom,odd); + + // select left and right based on path_bit + let mut left = vec![]; + let mut right = vec![]; + for j in 0..NUM_HASH_OUT_ELTS { + left.push( builder.select(bit, sibling.elements[j], state[i].elements[j])); + right.push( builder.select(bit, state[i].elements[j], sibling.elements[j])); + } + + state.push(key_compress_circuit::(builder,left,right,key)); + + i += 1; + } + + // println!("mask = {}, last = {}", targets.mask_bits.len(), targets.last_bits.len(), ); + + // another way to do this is to use builder.select + // but that might be less efficient & more constraints + let mut reconstructed_root = HashOutTarget::from_vec([builder.zero();4].to_vec()); + for k in 0..MAX_DEPTH { + let diff = builder.sub(targets.mask_bits[k].target, targets.mask_bits[k+1].target); + let mul_result = Self::mul_hash_out_target(builder,&diff,&mut state[k+1]); + Self::add_assign_hash_out_target(builder,&mut reconstructed_root, &mul_result); + } + + reconstructed_root + } + + /// helper fn to multiply a HashOutTarget by a Target + pub(crate) fn mul_hash_out_target(builder: &mut CircuitBuilder, t: &Target, hash_target: &mut HashOutTarget) -> HashOutTarget { + let mut mul_elements = vec![]; + for i in 0..NUM_HASH_OUT_ELTS { + mul_elements.push(builder.mul(hash_target.elements[i], *t)); + } + HashOutTarget::from_vec(mul_elements) + } + + /// helper fn to add AND assign a HashOutTarget (hot) to a mutable HashOutTarget (mut_hot) + pub(crate) fn add_assign_hash_out_target(builder: &mut CircuitBuilder, mut_hot: &mut HashOutTarget, hot: &HashOutTarget) { + for i in 0..NUM_HASH_OUT_ELTS { + mut_hot.elements[i] = (builder.add(mut_hot.elements[i], hot.elements[i])); + } + } } // NOTE: for now these tests don't check the reconstructed root is equal to expected_root diff --git a/codex-plonky2-circuits/src/circuits/sample_cells.rs b/codex-plonky2-circuits/src/circuits/sample_cells.rs index f351105..95cca06 100644 --- a/codex-plonky2-circuits/src/circuits/sample_cells.rs +++ b/codex-plonky2-circuits/src/circuits/sample_cells.rs @@ -245,6 +245,8 @@ impl< d_last_index = builder.sub(d_last_index, one); let d_last_bits = builder.split_le(d_last_index,d_depth); + let d_mask_bits = builder.split_le(d_last_index,d_depth+1); + // dataset Merkle path (sibling hashes from leaf to root) let d_merkle_path = MerkleProofTarget { path: (0..d_depth).map(|_| builder.add_virtual_hash()).collect(), @@ -255,6 +257,7 @@ impl< leaf: slot_root, path_bits: d_path_bits, last_bits: d_last_bits, + mask_bits: d_mask_bits, merkle_path: d_merkle_path, }; @@ -282,11 +285,15 @@ impl< b_last_index = builder.sub(b_last_index, one); let b_last_bits = builder.split_le(b_last_index,BOT_DEPTH); + let b_mask_bits = builder.split_le(b_last_index,BOT_DEPTH+1); + let s_depth_target = builder.constant(F::from_canonical_u64(MAX_DEPTH as u64)); let mut s_last_index = builder.exp(two,s_depth_target,MAX_DEPTH); s_last_index = builder.sub(s_last_index, one); let s_last_bits = builder.split_le(s_last_index,MAX_DEPTH); + let s_mask_bits = builder.split_le(b_last_index,BOT_DEPTH+1); + for i in 0..N_SAMPLES{ // cell data targets let mut data_i = (0..N_FIELD_ELEMS_PER_CELL).map(|_| builder.add_virtual_target()).collect::>(); @@ -320,21 +327,23 @@ impl< leaf: data_i_hash, path_bits:b_path_bits, last_bits: b_last_bits.clone(), + mask_bits: b_mask_bits.clone(), merkle_path: b_merkle_path, }; // reconstruct block root - let b_root = MerkleTreeCircuit::::reconstruct_merkle_root_circuit(builder, &mut block_targets); + let b_root = MerkleTreeCircuit::::reconstruct_merkle_root_circuit_with_mask(builder, &mut block_targets); let mut slot_targets = MerkleTreeTargets { leaf: b_root, path_bits:s_path_bits, last_bits:s_last_bits.clone(), + mask_bits:s_mask_bits.clone(), merkle_path:s_merkle_path, }; // reconstruct slot root with block root as leaf - let slot_reconstructed_root = MerkleTreeCircuit::::reconstruct_merkle_root_circuit(builder, &mut slot_targets); + let slot_reconstructed_root = MerkleTreeCircuit::::reconstruct_merkle_root_circuit_with_mask(builder, &mut slot_targets); // check equality with expected root for i in 0..NUM_HASH_OUT_ELTS {