add entropy to cell sampling

This commit is contained in:
M Alghazwi 2024-10-22 13:56:39 +02:00
parent 3c5a130a70
commit 76ed312e2c
3 changed files with 180 additions and 32 deletions

View File

@ -207,8 +207,6 @@ impl<
return state; return state;
} }
} }
// NOTE: for now these tests don't check the reconstructed root is equal to expected_root // NOTE: for now these tests don't check the reconstructed root is equal to expected_root

View File

@ -26,7 +26,7 @@ use plonky2::plonk::config::PoseidonGoldilocksConfig;
use plonky2::hash::hashing::PlonkyPermutation; use plonky2::hash::hashing::PlonkyPermutation;
use crate::circuits::prove_single_cell::{SingleCellTargets, SlotTreeCircuit}; use crate::circuits::prove_single_cell::{SingleCellTargets, SlotTreeCircuit};
use crate::circuits::params::{ DATASET_DEPTH, N_SAMPLES, TESTING_SLOT_INDEX}; use crate::circuits::params::{BOT_DEPTH, DATASET_DEPTH, MAX_DEPTH, N_FIELD_ELEMS_PER_CELL, N_SAMPLES, TESTING_SLOT_INDEX};
use crate::circuits::safe_tree_circuit::{MerkleTreeCircuit, MerkleTreeTargets}; use crate::circuits::safe_tree_circuit::{MerkleTreeCircuit, MerkleTreeTargets};
use crate::circuits::utils::{bits_le_padded_to_usize, calculate_cell_index_bits}; use crate::circuits::utils::{bits_le_padded_to_usize, calculate_cell_index_bits};
@ -203,8 +203,6 @@ impl<
} }
} }
//------- single cell struct ------
#[derive(Clone)] #[derive(Clone)]
pub struct DatasetTargets< pub struct DatasetTargets<
F: RichField + Extendable<D> + Poseidon2, F: RichField + Extendable<D> + Poseidon2,
@ -212,9 +210,14 @@ pub struct DatasetTargets<
const D: usize, const D: usize,
H: Hasher<F> + AlgebraicHasher<F>, H: Hasher<F> + AlgebraicHasher<F>,
> { > {
pub dataset_proof: MerkleTreeTargets<F, C, D, H>, pub dataset_proof: MerkleProofTarget, // proof that slot_root in dataset tree
pub dataset_root: HashOutTarget, pub dataset_root: HashOutTarget,
pub slot_proofs: Vec<SingleCellTargets<F, C, D, H>>,
pub cell_data: Vec<Vec<Target>>,
pub entropy:Target,
pub slot_index: Target,
pub slot_root: HashOutTarget,
pub slot_proofs: Vec<MerkleProofTarget>,
_phantom: PhantomData<(C,H)>, _phantom: PhantomData<(C,H)>,
} }
@ -227,38 +230,159 @@ impl<
H: Hasher<F> + AlgebraicHasher<F> + Hasher<F>, H: Hasher<F> + AlgebraicHasher<F> + Hasher<F>,
> DatasetTreeCircuit<F, C, D, H> { > DatasetTreeCircuit<F, C, D, H> {
// the in-circuit sampling of a slot in a dataset // in-circuit sampling
// TODO: make it more modular
pub fn sample_slot_circuit( pub fn sample_slot_circuit(
&mut self, &mut self,
builder: &mut CircuitBuilder::<F, D>, builder: &mut CircuitBuilder::<F, D>,
)-> DatasetTargets<F,C,D,H>{ )-> DatasetTargets<F,C,D,H>{
let (dataset_proof, dataset_root_target) = self.tree.build_circuit(builder); // constants
let one = builder.one();
let two = builder.two();
// ***** prove slot root is in dataset tree *********
// Retrieve dataset tree depth
let d_depth = DATASET_DEPTH;
// Create virtual target for slot root and index
let slot_root = builder.add_virtual_hash();
let slot_index = builder.add_virtual_target();
// dataset path bits (binary decomposition of leaf_index)
let d_path_bits = builder.split_le(slot_index,d_depth);
// dataset last bits (binary decomposition of last_index = nleaves - 1)
let depth_target = builder.constant(F::from_canonical_u64(d_depth as u64));
let mut d_last_index = builder.exp(two,depth_target,d_depth);
d_last_index = builder.sub(d_last_index, one);
let d_last_bits = builder.split_le(d_last_index,d_depth);
// dataset Merkle path (sibling hashes from leaf to root)
let d_merkle_path = MerkleProofTarget {
path: (0..d_depth).map(|_| builder.add_virtual_hash()).collect(),
};
// create MerkleTreeTargets struct
let mut d_targets = MerkleTreeTargets {
leaf: slot_root,
path_bits: d_path_bits,
last_bits: d_last_bits,
merkle_path: d_merkle_path,
_phantom: PhantomData,
};
// dataset reconstructed root
let d_reconstructed_root =
MerkleTreeCircuit::<F,C,D,H>::reconstruct_merkle_root_circuit(builder, &mut d_targets);
// expected Merkle root // expected Merkle root
let dataset_expected_root = builder.add_virtual_hash(); let d_expected_root = builder.add_virtual_hash();
// check equality with expected root // check equality with expected root
for i in 0..NUM_HASH_OUT_ELTS { for i in 0..NUM_HASH_OUT_ELTS {
builder.connect(dataset_expected_root.elements[i], dataset_root_target.elements[i]); builder.connect(d_expected_root.elements[i], d_reconstructed_root.elements[i]);
} }
let mut slot_proofs =vec![]; //*********** do the sampling ************
let mut data_targets =vec![];
let mut slot_sample_proofs = vec![];
let entropy_target = builder.add_virtual_target();
for i in 0..N_SAMPLES{ for i in 0..N_SAMPLES{
let proof_i = SlotTreeCircuit::<F,C,D,H>::prove_single_cell(builder); // cell data targets
slot_proofs.push(proof_i); let mut data_i = (0..N_FIELD_ELEMS_PER_CELL).map(|_| builder.add_virtual_target()).collect::<Vec<_>>();
let mut perm_inputs:Vec<Target>= Vec::new();
perm_inputs.extend_from_slice(&data_i);
let data_i_hash = builder.hash_n_to_hash_no_pad::<H>(perm_inputs);
// counter constant
let ctr = builder.constant(F::from_canonical_u64(i as u64));
// paths
let mut b_path_bits = Self::calculate_cell_index_bits(builder, &entropy_target, &d_targets.leaf, &ctr);
let mut s_path_bits = b_path_bits.split_off(BOT_DEPTH);
//TODO: this can probably be optimized by supplying nCellsPerSlot as input to the circuit
let b_depth_target = builder.constant(F::from_canonical_u64(BOT_DEPTH as u64));
let mut b_last_index = builder.exp(two,b_depth_target,BOT_DEPTH);
b_last_index = builder.sub(b_last_index, one);
let b_last_bits = builder.split_le(b_last_index,BOT_DEPTH);
let s_depth_target = builder.constant(F::from_canonical_u64(MAX_DEPTH as u64));
let mut s_last_index = builder.exp(two,s_depth_target,MAX_DEPTH);
s_last_index = builder.sub(s_last_index, one);
let s_last_bits = builder.split_le(s_last_index,MAX_DEPTH);
let mut b_merkle_path = MerkleProofTarget {
path: (0..BOT_DEPTH).map(|_| builder.add_virtual_hash()).collect(),
};
let mut s_merkle_path = MerkleProofTarget {
path: (0..(MAX_DEPTH - BOT_DEPTH)).map(|_| builder.add_virtual_hash()).collect(),
};
let mut block_targets = MerkleTreeTargets {
leaf: data_i_hash,
path_bits:b_path_bits,
last_bits: b_last_bits,
merkle_path: b_merkle_path,
_phantom: PhantomData,
};
// reconstruct block root
let b_root = MerkleTreeCircuit::<F,C,D,H>::reconstruct_merkle_root_circuit(builder, &mut block_targets);
let mut slot_targets = MerkleTreeTargets {
leaf: b_root,
path_bits:s_path_bits,
last_bits:s_last_bits,
merkle_path:s_merkle_path,
_phantom: PhantomData,
};
// reconstruct slot root with block root as leaf
let slot_reconstructed_root = MerkleTreeCircuit::<F,C,D,H>::reconstruct_merkle_root_circuit(builder, &mut slot_targets);
// check equality with expected root
for i in 0..NUM_HASH_OUT_ELTS {
builder.connect( d_targets.leaf.elements[i], slot_reconstructed_root.elements[i]);
}
// combine block and slot path to get the full path so we can assign it later.
let mut slot_sample_proof_target = MerkleProofTarget{
path: block_targets.merkle_path.path,
};
slot_sample_proof_target.path.extend_from_slice(&slot_targets.merkle_path.path);
data_targets.push(data_i);
slot_sample_proofs.push(slot_sample_proof_target);
} }
DatasetTargets::<F,C,D,H>{ DatasetTargets::<F,C,D,H>{
dataset_proof, dataset_proof: d_targets.merkle_path,
dataset_root: dataset_expected_root, dataset_root: d_expected_root,
slot_proofs, cell_data: data_targets,
entropy: entropy_target,
slot_index,
slot_root: d_targets.leaf,
slot_proofs: slot_sample_proofs,
_phantom: Default::default(), _phantom: Default::default(),
} }
} }
// assign the witnesses to the targets pub fn calculate_cell_index_bits(builder: &mut CircuitBuilder::<F, D>, p0: &Target, p1: &HashOutTarget, p2: &Target) -> Vec<BoolTarget> {
// takes pw, the dataset targets, slot index, and entropy let mut perm_inputs:Vec<Target>= Vec::new();
perm_inputs.extend_from_slice(&p1.elements);
perm_inputs.push(*p0);
perm_inputs.push(*p2);
let data_i_hash = builder.hash_n_to_hash_no_pad::<H>(perm_inputs);
let p_bits = builder.low_bits(data_i_hash.elements[NUM_HASH_OUT_ELTS-1], MAX_DEPTH, 64);
p_bits
}
pub fn sample_slot_assign_witness( pub fn sample_slot_assign_witness(
&mut self, &mut self,
pw: &mut PartialWitness<F>, pw: &mut PartialWitness<F>,
@ -266,8 +390,19 @@ impl<
slot_index:usize, slot_index:usize,
entropy:usize, entropy:usize,
){ ){
// assign witness for dataset level target (proving slot root is in dataset tree) // dataset proof
self.tree.assign_witness(pw,&mut targets.dataset_proof,slot_index); let d_proof = self.tree.tree.get_proof(slot_index).unwrap();
// assign dataset proof
for (i, sibling_hash) in d_proof.path.iter().enumerate() {
// TODO: fix this HashOutTarget later
let sibling_hash_out = sibling_hash.to_vec();
for j in 0..sibling_hash_out.len() {
pw.set_target(targets.dataset_proof.path[i].elements[j], sibling_hash_out[j]);
}
}
// assign slot index
pw.set_target(targets.slot_index, F::from_canonical_u64(slot_index as u64));
// assign the expected Merkle root of dataset to the target // assign the expected Merkle root of dataset to the target
let expected_root = self.tree.tree.root().unwrap(); let expected_root = self.tree.tree.root().unwrap();
@ -279,14 +414,28 @@ impl<
// the sampled slot // the sampled slot
let slot = &self.slot_trees[slot_index]; let slot = &self.slot_trees[slot_index];
let slot_root = slot.tree.tree.root().unwrap(); let slot_root = slot.tree.tree.root().unwrap();
pw.set_hash_target(targets.slot_root, slot_root);
// assign entropy
pw.set_target(targets.entropy, F::from_canonical_u64(entropy as u64));
// do the sample N times // do the sample N times
for i in 0..N_SAMPLES { for i in 0..N_SAMPLES {
let cell_index_bits = calculate_cell_index_bits(entropy, slot_root, i); let cell_index_bits = calculate_cell_index_bits(entropy,slot_root,i);
let cell_index = bits_le_padded_to_usize(&cell_index_bits); let cell_index = bits_le_padded_to_usize(&cell_index_bits);
// assign cell data
let leaf = &slot.cell_data[cell_index]; let leaf = &slot.cell_data[cell_index];
let proof = slot.get_proof(cell_index); for j in 0..leaf.len(){
slot.single_cell_assign_witness(pw, &mut targets.slot_proofs[i],cell_index,leaf, proof.clone()); pw.set_target(targets.cell_data[i][j], leaf[j]);
}
// assign proof for that cell
let cell_proof = slot.get_proof(cell_index);
for (k, sibling_hash) in cell_proof.path.iter().enumerate() {
let sibling_hash_out = sibling_hash.to_vec();
for j in 0..sibling_hash_out.len() {
pw.set_target(targets.slot_proofs[i].path[k].elements[j], sibling_hash_out[j]);
}
}
} }
} }

View File

@ -1,4 +1,4 @@
use plonky2::hash::hash_types::{HashOut, RichField}; use plonky2::hash::hash_types::{HashOut, NUM_HASH_OUT_ELTS, RichField};
use plonky2::iop::witness::PartialWitness; use plonky2::iop::witness::PartialWitness;
use plonky2::plonk::circuit_data::{CircuitData, VerifierCircuitData}; use plonky2::plonk::circuit_data::{CircuitData, VerifierCircuitData};
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher}; use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher};
@ -22,18 +22,19 @@ pub(crate) fn usize_to_bits_le_padded(index: usize, bit_length: usize) -> Vec<bo
} }
bits bits
} }
/// calculate the sampled cell index from entropy, slot root, and counter
pub(crate) fn calculate_cell_index_bits<F: RichField>(p0: usize, p1: HashOut<F>, p2: usize) -> Vec<bool> { pub(crate) fn calculate_cell_index_bits<F: RichField>(entropy: usize, slot_root: HashOut<F>, ctr: usize) -> Vec<bool> {
let p0_field = F::from_canonical_u64(p0 as u64); let p0_field = F::from_canonical_u64(entropy as u64);
let p2_field = F::from_canonical_u64(p2 as u64); let p2_field = F::from_canonical_u64(ctr as u64);
let mut inputs = Vec::new(); let mut inputs = Vec::new();
inputs.extend_from_slice(&p1.elements); inputs.extend_from_slice(&slot_root.elements);
inputs.push(p0_field); inputs.push(p0_field);
inputs.push(p2_field); inputs.push(p2_field);
let p_hash = HF::hash_no_pad(&inputs); let p_hash = HF::hash_no_pad(&inputs);
let p_bytes = p_hash.to_bytes(); let p_bytes = p_hash.elements[NUM_HASH_OUT_ELTS - 1].to_canonical_u64();
let p_bits = take_n_bits_from_bytes(&p_bytes, MAX_DEPTH); // let p_bits = take_n_bits_from_bytes(&p_bytes, MAX_DEPTH);
let p_bits = usize_to_bits_le_padded(p_bytes as usize, MAX_DEPTH);
p_bits p_bits
} }
pub(crate) fn take_n_bits_from_bytes(bytes: &[u8], n: usize) -> Vec<bool> { pub(crate) fn take_n_bits_from_bytes(bytes: &[u8], n: usize) -> Vec<bool> {