add bit mask

This commit is contained in:
M Alghazwi 2024-11-03 11:50:46 +01:00
parent 4d723f8c7f
commit 4c3f2043ee
3 changed files with 128 additions and 10 deletions

View File

@ -282,6 +282,11 @@ impl<
let block_last_bits = (0..BOT_DEPTH).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>(); let block_last_bits = (0..BOT_DEPTH).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
let slot_last_bits = (0..(depth-BOT_DEPTH)).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>(); let slot_last_bits = (0..(depth-BOT_DEPTH)).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
// mask bits (binary decomposition of last_index = nleaves - 1)
let block_mask_bits = (0..BOT_DEPTH+1).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
let slot_mask_bits = (0..(depth-BOT_DEPTH)+1).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
// Merkle path (sibling hashes from leaf to root) // Merkle path (sibling hashes from leaf to root)
let mut block_merkle_path = MerkleProofTarget { let mut block_merkle_path = MerkleProofTarget {
path: (0..BOT_DEPTH).map(|_| builder.add_virtual_hash()).collect(), path: (0..BOT_DEPTH).map(|_| builder.add_virtual_hash()).collect(),
@ -297,6 +302,7 @@ impl<
leaf: leaf_hash, leaf: leaf_hash,
path_bits:block_path_bits, path_bits:block_path_bits,
last_bits: block_last_bits, last_bits: block_last_bits,
mask_bits: block_mask_bits,
merkle_path: block_merkle_path, merkle_path: block_merkle_path,
}; };
@ -308,6 +314,7 @@ impl<
leaf: block_root, leaf: block_root,
path_bits:slot_path_bits, path_bits:slot_path_bits,
last_bits:slot_last_bits, last_bits:slot_last_bits,
mask_bits:slot_mask_bits,
merkle_path:slot_merkle_path, merkle_path:slot_merkle_path,
}; };

View File

@ -16,10 +16,11 @@ use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, VerifierCircuitDa
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig}; use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig};
use plonky2::plonk::proof::{Proof, ProofWithPublicInputs}; use plonky2::plonk::proof::{Proof, ProofWithPublicInputs};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::os::macos::raw::stat;
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
use serde::Serialize; use serde::Serialize;
use crate::circuits::keyed_compress::key_compress_circuit; use crate::circuits::keyed_compress::key_compress_circuit;
use crate::circuits::params::HF; use crate::circuits::params::{HF, MAX_DEPTH};
use crate::circuits::utils::usize_to_bits_le_padded; use crate::circuits::utils::usize_to_bits_le_padded;
use crate::merkle_tree::merkle_safe::{MerkleTree, MerkleProofTarget}; use crate::merkle_tree::merkle_safe::{MerkleTree, MerkleProofTarget};
@ -35,6 +36,7 @@ pub struct MerkleTreeTargets{
pub leaf: HashOutTarget, pub leaf: HashOutTarget,
pub path_bits: Vec<BoolTarget>, pub path_bits: Vec<BoolTarget>,
pub last_bits: Vec<BoolTarget>, pub last_bits: Vec<BoolTarget>,
pub mask_bits: Vec<BoolTarget>,
pub merkle_path: MerkleProofTarget, pub merkle_path: MerkleProofTarget,
} }
@ -65,14 +67,17 @@ impl<
let leaf = builder.add_virtual_hash(); let leaf = builder.add_virtual_hash();
// path bits (binary decomposition of leaf_index) // path bits (binary decomposition of leaf_index)
let path_bits = (0..depth).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>(); let path_bits = (0..MAX_DEPTH).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
// last bits (binary decomposition of last_index = nleaves - 1) // last bits (binary decomposition of last_index = nleaves - 1)
let last_bits = (0..depth).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>(); let last_bits = (0..MAX_DEPTH).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
// last bits (binary decomposition of last_index = nleaves - 1)
let mask_bits = (0..MAX_DEPTH+1).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
// Merkle path (sibling hashes from leaf to root) // Merkle path (sibling hashes from leaf to root)
let merkle_path = MerkleProofTarget { let merkle_path = MerkleProofTarget {
path: (0..depth).map(|_| builder.add_virtual_hash()).collect(), path: (0..MAX_DEPTH).map(|_| builder.add_virtual_hash()).collect(),
}; };
// create MerkleTreeTargets struct // create MerkleTreeTargets struct
@ -80,11 +85,12 @@ impl<
leaf, leaf,
path_bits, path_bits,
last_bits, last_bits,
mask_bits,
merkle_path, merkle_path,
}; };
// Add Merkle proof verification constraints to the circuit // Add Merkle proof verification constraints to the circuit
let expected_root_target = Self::reconstruct_merkle_root_circuit(builder, &mut targets); let expected_root_target = Self::reconstruct_merkle_root_circuit_with_mask(builder, &mut targets);
// Return MerkleTreeTargets // Return MerkleTreeTargets
(targets, expected_root_target) (targets, expected_root_target)
@ -112,23 +118,36 @@ impl<
pw.set_hash_target(targets.leaf, leaf_hash); pw.set_hash_target(targets.leaf, leaf_hash);
// Convert `leaf_index` to binary bits and assign as path_bits // Convert `leaf_index` to binary bits and assign as path_bits
let path_bits = usize_to_bits_le_padded(leaf_index, depth); let path_bits = usize_to_bits_le_padded(leaf_index, MAX_DEPTH);
for (i, bit) in path_bits.iter().enumerate() { for (i, bit) in path_bits.iter().enumerate() {
pw.set_bool_target(targets.path_bits[i], *bit); pw.set_bool_target(targets.path_bits[i], *bit);
} }
// get `last_index` (nleaves - 1) in binary bits and assign // get `last_index` (nleaves - 1) in binary bits and assign
let last_index = nleaves - 1; let last_index = nleaves - 1;
let last_bits = usize_to_bits_le_padded(last_index, depth); let last_bits = usize_to_bits_le_padded(last_index, MAX_DEPTH);
for (i, bit) in last_bits.iter().enumerate() { for (i, bit) in last_bits.iter().enumerate() {
pw.set_bool_target(targets.last_bits[i], *bit); pw.set_bool_target(targets.last_bits[i], *bit);
} }
// get mask bits
let mask_bits = usize_to_bits_le_padded(last_index, MAX_DEPTH+1);
for (i, bit) in mask_bits.iter().enumerate() {
pw.set_bool_target(targets.mask_bits[i], *bit);
}
// assign the Merkle path (sibling hashes) to the targets // assign the Merkle path (sibling hashes) to the targets
for (i, sibling_hash) in proof.path.iter().enumerate() { for i in 0..MAX_DEPTH {
if i>=proof.path.len() {
for j in 0..NUM_HASH_OUT_ELTS {
pw.set_target(targets.merkle_path.path[i].elements[j], F::ZERO);
}
continue
}
// This is a bit hacky because it should be HashOutTarget, but it is H:Hash // This is a bit hacky because it should be HashOutTarget, but it is H:Hash
// pw.set_hash_target(targets.merkle_path.path[i],sibling_hash); // pw.set_hash_target(targets.merkle_path.path[i],sibling_hash);
// TODO: fix this HashOutTarget later // TODO: fix this HashOutTarget later
let sibling_hash = proof.path[i];
let sibling_hash_out = sibling_hash.to_vec(); let sibling_hash_out = sibling_hash.to_vec();
for j in 0..sibling_hash_out.len() { for j in 0..sibling_hash_out.len() {
pw.set_target(targets.merkle_path.path[i].elements[j], sibling_hash_out[j]); pw.set_target(targets.merkle_path.path[i].elements[j], sibling_hash_out[j]);
@ -191,6 +210,89 @@ impl<
return state; return state;
} }
/// takes the params from the targets struct
/// outputs the reconstructed merkle root
/// this one uses the mask bits
pub fn reconstruct_merkle_root_circuit_with_mask(
builder: &mut CircuitBuilder<F, D>,
targets: &mut MerkleTreeTargets,
) -> HashOutTarget {
let max_depth = targets.path_bits.len();
let mut state: Vec<HashOutTarget> = Vec::with_capacity(max_depth+1);
state.push(targets.leaf);
let zero = builder.zero();
let one = builder.one();
let two = builder.two();
debug_assert_eq!(targets.path_bits.len(), targets.merkle_path.path.len());
// compute is_last
let mut is_last = vec![BoolTarget::new_unsafe(zero); max_depth + 1];
is_last[max_depth] = BoolTarget::new_unsafe(one); // set isLast[max_depth] to 1 (true)
for i in (0..max_depth).rev() {
let eq_out = builder.is_equal(targets.path_bits[i].target , targets.last_bits[i].target);
is_last[i] = builder.and( is_last[i + 1] , eq_out);
}
let mut i: usize = 0;
for (&bit, &sibling) in targets.path_bits.iter().zip(&targets.merkle_path.path) {
debug_assert_eq!(sibling.elements.len(), NUM_HASH_OUT_ELTS);
let bottom = if i == 0 {
builder.constant(F::from_canonical_u64(KEY_BOTTOM_LAYER))
} else {
builder.constant(F::from_canonical_u64(KEY_NONE))
};
// compute: odd = isLast[i] * (1-pathBits[i]);
// compute: key = bottom + 2*odd
let mut odd = builder.sub(one, targets.path_bits[i].target);
odd = builder.mul(is_last[i].target, odd);
odd = builder.mul(two, odd);
let key = builder.add(bottom,odd);
// select left and right based on path_bit
let mut left = vec![];
let mut right = vec![];
for j in 0..NUM_HASH_OUT_ELTS {
left.push( builder.select(bit, sibling.elements[j], state[i].elements[j]));
right.push( builder.select(bit, state[i].elements[j], sibling.elements[j]));
}
state.push(key_compress_circuit::<F,D,HF>(builder,left,right,key));
i += 1;
}
// println!("mask = {}, last = {}", targets.mask_bits.len(), targets.last_bits.len(), );
// another way to do this is to use builder.select
// but that might be less efficient & more constraints
let mut reconstructed_root = HashOutTarget::from_vec([builder.zero();4].to_vec());
for k in 0..MAX_DEPTH {
let diff = builder.sub(targets.mask_bits[k].target, targets.mask_bits[k+1].target);
let mul_result = Self::mul_hash_out_target(builder,&diff,&mut state[k+1]);
Self::add_assign_hash_out_target(builder,&mut reconstructed_root, &mul_result);
}
reconstructed_root
}
/// helper fn to multiply a HashOutTarget by a Target
pub(crate) fn mul_hash_out_target(builder: &mut CircuitBuilder<F, D>, t: &Target, hash_target: &mut HashOutTarget) -> HashOutTarget {
let mut mul_elements = vec![];
for i in 0..NUM_HASH_OUT_ELTS {
mul_elements.push(builder.mul(hash_target.elements[i], *t));
}
HashOutTarget::from_vec(mul_elements)
}
/// helper fn to add AND assign a HashOutTarget (hot) to a mutable HashOutTarget (mut_hot)
pub(crate) fn add_assign_hash_out_target(builder: &mut CircuitBuilder<F, D>, mut_hot: &mut HashOutTarget, hot: &HashOutTarget) {
for i in 0..NUM_HASH_OUT_ELTS {
mut_hot.elements[i] = (builder.add(mut_hot.elements[i], hot.elements[i]));
}
}
} }
// NOTE: for now these tests don't check the reconstructed root is equal to expected_root // NOTE: for now these tests don't check the reconstructed root is equal to expected_root

View File

@ -245,6 +245,8 @@ impl<
d_last_index = builder.sub(d_last_index, one); d_last_index = builder.sub(d_last_index, one);
let d_last_bits = builder.split_le(d_last_index,d_depth); let d_last_bits = builder.split_le(d_last_index,d_depth);
let d_mask_bits = builder.split_le(d_last_index,d_depth+1);
// dataset Merkle path (sibling hashes from leaf to root) // dataset Merkle path (sibling hashes from leaf to root)
let d_merkle_path = MerkleProofTarget { let d_merkle_path = MerkleProofTarget {
path: (0..d_depth).map(|_| builder.add_virtual_hash()).collect(), path: (0..d_depth).map(|_| builder.add_virtual_hash()).collect(),
@ -255,6 +257,7 @@ impl<
leaf: slot_root, leaf: slot_root,
path_bits: d_path_bits, path_bits: d_path_bits,
last_bits: d_last_bits, last_bits: d_last_bits,
mask_bits: d_mask_bits,
merkle_path: d_merkle_path, merkle_path: d_merkle_path,
}; };
@ -282,11 +285,15 @@ impl<
b_last_index = builder.sub(b_last_index, one); b_last_index = builder.sub(b_last_index, one);
let b_last_bits = builder.split_le(b_last_index,BOT_DEPTH); let b_last_bits = builder.split_le(b_last_index,BOT_DEPTH);
let b_mask_bits = builder.split_le(b_last_index,BOT_DEPTH+1);
let s_depth_target = builder.constant(F::from_canonical_u64(MAX_DEPTH as u64)); let s_depth_target = builder.constant(F::from_canonical_u64(MAX_DEPTH as u64));
let mut s_last_index = builder.exp(two,s_depth_target,MAX_DEPTH); let mut s_last_index = builder.exp(two,s_depth_target,MAX_DEPTH);
s_last_index = builder.sub(s_last_index, one); s_last_index = builder.sub(s_last_index, one);
let s_last_bits = builder.split_le(s_last_index,MAX_DEPTH); let s_last_bits = builder.split_le(s_last_index,MAX_DEPTH);
let s_mask_bits = builder.split_le(b_last_index,BOT_DEPTH+1);
for i in 0..N_SAMPLES{ for i in 0..N_SAMPLES{
// cell data targets // cell data targets
let mut data_i = (0..N_FIELD_ELEMS_PER_CELL).map(|_| builder.add_virtual_target()).collect::<Vec<_>>(); let mut data_i = (0..N_FIELD_ELEMS_PER_CELL).map(|_| builder.add_virtual_target()).collect::<Vec<_>>();
@ -320,21 +327,23 @@ impl<
leaf: data_i_hash, leaf: data_i_hash,
path_bits:b_path_bits, path_bits:b_path_bits,
last_bits: b_last_bits.clone(), last_bits: b_last_bits.clone(),
mask_bits: b_mask_bits.clone(),
merkle_path: b_merkle_path, merkle_path: b_merkle_path,
}; };
// reconstruct block root // reconstruct block root
let b_root = MerkleTreeCircuit::<F,D>::reconstruct_merkle_root_circuit(builder, &mut block_targets); let b_root = MerkleTreeCircuit::<F,D>::reconstruct_merkle_root_circuit_with_mask(builder, &mut block_targets);
let mut slot_targets = MerkleTreeTargets { let mut slot_targets = MerkleTreeTargets {
leaf: b_root, leaf: b_root,
path_bits:s_path_bits, path_bits:s_path_bits,
last_bits:s_last_bits.clone(), last_bits:s_last_bits.clone(),
mask_bits:s_mask_bits.clone(),
merkle_path:s_merkle_path, merkle_path:s_merkle_path,
}; };
// reconstruct slot root with block root as leaf // reconstruct slot root with block root as leaf
let slot_reconstructed_root = MerkleTreeCircuit::<F,D>::reconstruct_merkle_root_circuit(builder, &mut slot_targets); let slot_reconstructed_root = MerkleTreeCircuit::<F,D>::reconstruct_merkle_root_circuit_with_mask(builder, &mut slot_targets);
// check equality with expected root // check equality with expected root
for i in 0..NUM_HASH_OUT_ELTS { for i in 0..NUM_HASH_OUT_ELTS {