refactor circuits

This commit is contained in:
M Alghazwi 2024-10-17 21:38:14 +02:00
parent 5a13ac3650
commit 9eefa78c24
5 changed files with 149 additions and 100 deletions

View File

@ -15,7 +15,7 @@ use plonky2::hash::hash_types::RichField;
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
use std::marker::PhantomData;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use codex_plonky2_circuits::circuits::prove_single_cell::SlotTree;
use codex_plonky2_circuits::circuits::prove_single_cell::SlotTreeCircuit;
macro_rules! pretty_print {
($($arg:tt)*) => {
@ -28,7 +28,7 @@ macro_rules! pretty_print {
type HF = PoseidonHash;
fn prepare_data<F, H>(N: usize) -> Result<(
SlotTree<F, H>,
SlotTreeCircuit<F, H>,
Vec<usize>,
Vec<MerkleProof<F, H>>,
)>
@ -37,7 +37,7 @@ where
H: Hasher<F> + AlgebraicHasher<F> + Hasher<F>,
{
// Initialize the slot tree with default data
let slot_tree = SlotTree::<F, H>::default();
let slot_tree = SlotTreeCircuit::<F, H>::default();
// Select N leaf indices to prove
let leaf_indices: Vec<usize> = (0..N).collect();
@ -52,7 +52,7 @@ where
}
fn build_circuit<F, C, const D: usize, H>(
slot_tree: &SlotTree<F, H>,
slot_tree: &SlotTreeCircuit<F, H>,
leaf_indices: &[usize],
proofs: &[MerkleProof<F, H>],
) -> Result<(CircuitData<F, C, D>, PartialWitness<F>)>

View File

@ -35,14 +35,24 @@ use crate::circuits::params::{MAX_DEPTH, BOT_DEPTH, N_FIELD_ELEMS_PER_CELL, N_CE
// ------ Slot Tree --------
#[derive(Clone)]
pub struct SlotTree<F: RichField, H: Hasher<F>> {
pub tree: MerkleTree<F,H>, // slot tree
pub block_trees: Vec<MerkleTree<F,H>>, // vec of block trees
pub struct SlotTreeCircuit<
F: RichField + Extendable<D> + Poseidon2,
C: GenericConfig<D, F = F>,
const D: usize,
H: Hasher<F> + AlgebraicHasher<F>,
> {
pub tree: MerkleTreeCircuit<F,C,D,H>, // slot tree
pub block_trees: Vec<MerkleTreeCircuit<F,C,D,H>>, // vec of block trees
pub cell_data: Vec<Vec<F>>, // cell data as field elements
pub cell_hash: Vec<HashOut<F>>, // hash of above
}
impl<F: RichField, H: Hasher<F>> Default for SlotTree<F,H>{
impl<
F: RichField + Extendable<D> + Poseidon2,
C: GenericConfig<D, F = F>,
const D: usize,
H: Hasher<F> + AlgebraicHasher<F>,
> Default for SlotTreeCircuit<F,C,D,H>{
/// slot tree with fake data, for testing only
fn default() -> Self {
// generate fake cell data
@ -69,19 +79,38 @@ impl<F: RichField, H: Hasher<F>> Default for SlotTree<F,H>{
.map(|i| {
let start = i * N_CELLS_IN_BLOCKS;
let end = (i + 1) * N_CELLS_IN_BLOCKS;
Self::get_block_tree(&leaves[start..end].to_vec()) // use helper function
let b_tree = Self::get_block_tree(&leaves[start..end].to_vec()); // use helper function
MerkleTreeCircuit::<F,C,D,H>{ tree:b_tree, _phantom:Default::default()}
})
.collect::<Vec<_>>();
// get the roots or block trees
let block_roots = block_trees.iter()
.map(|t| {
t.root().unwrap()
t.tree.root().unwrap()
})
.collect::<Vec<_>>();
// create slot tree
let slot_tree = MerkleTree::<F, H>::new(&block_roots, zero).unwrap();
// let mt =
// MerkleTree::<F,H>{
// tree: slot_tree,
// block_trees,
// cell_data,
// cell_hash: leaves,
// }
// create block circuits
// let block_circuits = block_trees.iter()
// .map(|b_tree| {
// // let start = i * N_CELLS_IN_BLOCKS;
// // let end = (i + 1) * N_CELLS_IN_BLOCKS;
// // Self::get_block_tree(&leaves[start..end].to_vec()) // use helper function
// MerkleTreeCircuit::<F,C,D,H>{ tree:b_tree.clone(), _phantom:Default::default()},
// })
// .collect::<Vec<_>>();
Self{
tree: slot_tree,
tree: MerkleTreeCircuit::<F,C,D,H>{ tree:slot_tree, _phantom:Default::default()},
block_trees,
cell_data,
cell_hash: leaves,
@ -89,7 +118,12 @@ impl<F: RichField, H: Hasher<F>> Default for SlotTree<F,H>{
}
}
impl<F: RichField, H: Hasher<F>> SlotTree<F, H> {
impl<
F: RichField + Extendable<D> + Poseidon2,
C: GenericConfig<D, F = F>,
const D: usize,
H: Hasher<F> + AlgebraicHasher<F>,
> SlotTreeCircuit<F,C,D, H> {
/// same as default but with supplied cell data
pub fn new(cell_data: Vec<Vec<F>>) -> Self{
@ -106,17 +140,18 @@ impl<F: RichField, H: Hasher<F>> SlotTree<F, H> {
.map(|i| {
let start = i * N_CELLS_IN_BLOCKS;
let end = (i + 1) * N_CELLS_IN_BLOCKS;
Self::get_block_tree(&leaves[start..end].to_vec())
let b_tree = Self::get_block_tree(&leaves[start..end].to_vec());
MerkleTreeCircuit::<F,C,D,H>{ tree:b_tree, _phantom:Default::default()}
})
.collect::<Vec<_>>();
let block_roots = block_trees.iter()
.map(|t| {
t.root().unwrap()
t.tree.root().unwrap()
})
.collect::<Vec<_>>();
let slot_tree = MerkleTree::<F, H>::new(&block_roots, zero).unwrap();
Self{
tree: slot_tree,
tree: MerkleTreeCircuit::<F,C,D,H>{ tree:slot_tree, _phantom:Default::default()},
block_trees,
cell_data,
cell_hash: leaves,
@ -128,8 +163,8 @@ impl<F: RichField, H: Hasher<F>> SlotTree<F, H> {
pub fn get_proof(&self, index: usize) -> MerkleProof<F, H> {
let block_index = index/ N_CELLS_IN_BLOCKS;
let leaf_index = index % N_CELLS_IN_BLOCKS;
let block_proof = self.block_trees[block_index].get_proof(leaf_index).unwrap();
let slot_proof = self.tree.get_proof(block_index).unwrap();
let block_proof = self.block_trees[block_index].tree.get_proof(leaf_index).unwrap();
let slot_proof = self.tree.tree.get_proof(block_index).unwrap();
// Combine the paths from the block and slot proofs
let mut combined_path = block_proof.path.clone();
@ -213,10 +248,10 @@ impl<
C: GenericConfig<D, F=F>,
const D: usize,
H: Hasher<F> + AlgebraicHasher<F> + Hasher<F>,
> MerkleTreeCircuit<F, C, D, H> {
> SlotTreeCircuit<F, C, D, H> {
pub fn prove_single_cell2(
&mut self,
pub fn prove_single_cell(
// &mut self,
builder: &mut CircuitBuilder::<F, D>
) -> SingleCellTargets<F, C, D, H> {
@ -258,7 +293,7 @@ impl<
};
// reconstruct block root
let block_root = self.reconstruct_merkle_root_circuit(builder, &mut block_targets);
let block_root = MerkleTreeCircuit::<F,C,D,H>::reconstruct_merkle_root_circuit(builder, &mut block_targets);
// create MerkleTreeTargets struct
let mut slot_targets = MerkleTreeTargets {
@ -270,7 +305,7 @@ impl<
};
// reconstruct slot root with block root as leaf
let slot_root = self.reconstruct_merkle_root_circuit(builder, &mut slot_targets);
let slot_root = MerkleTreeCircuit::<F,C,D,H>::reconstruct_merkle_root_circuit(builder, &mut slot_targets);
// check equality with expected root
for i in 0..NUM_HASH_OUT_ELTS {
@ -305,7 +340,7 @@ impl<
/// this takes leaf_index, leaf, and proof (generated from slot_tree)
/// and fills all required circuit targets(circuit inputs)
pub fn single_cell_assign_witness(
&mut self,
&self,
pw: &mut PartialWitness<F>,
targets: &mut SingleCellTargets<F, C, D, H>,
leaf_index: usize,
@ -343,7 +378,7 @@ impl<
}
// assign the expected Merkle root to the target
let expected_root = self.tree.root()?;
let expected_root = self.tree.tree.root()?;
// TODO: fix this HashOutTarget later same issue as above
let expected_root_hash_out = expected_root.to_vec();
for j in 0..expected_root_hash_out.len() {
@ -367,49 +402,46 @@ mod tests {
use plonky2::iop::witness::PartialWitness;
//types for tests
type F = GoldilocksField;
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
type H = PoseidonHash;
#[test]
fn test_prove_single_cell(){
let slot_t = SlotTree::<F,H>::default();
let slot_t = SlotTreeCircuit::<F,C,D,H>::default();
let index = 8;
let proof = slot_t.get_proof(index);
let res = slot_t.verify_cell_proof(proof,slot_t.tree.root().unwrap()).unwrap();
let res = slot_t.verify_cell_proof(proof,slot_t.tree.tree.root().unwrap()).unwrap();
assert_eq!(res, true);
}
#[test]
fn test_cell_build_circuit() -> Result<()> {
// circuit params
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
type H = PoseidonHash;
let slot_t = SlotTree::<F,H>::default();
let slot_t = SlotTreeCircuit::<F,C,D,H>::default();
// select leaf index to prove
let leaf_index: usize = 8;
let proof = slot_t.get_proof(leaf_index);
// get the expected Merkle root
let expected_root = slot_t.tree.root().unwrap();
let expected_root = slot_t.tree.tree.root().unwrap();
let res = slot_t.verify_cell_proof(proof.clone(),expected_root).unwrap();
assert_eq!(res, true);
// create the circuit
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut circuit_instance = MerkleTreeCircuit::<F, C, D, H> {
tree: slot_t.tree.clone(),
_phantom: PhantomData,
};
let mut targets = circuit_instance.prove_single_cell2(&mut builder);
// let mut circuit_instance = MerkleTreeCircuit::<F, C, D, H> {
// tree: slot_t.tree.clone(),
// _phantom: PhantomData,
// };
let mut targets = SlotTreeCircuit::<F,C,D,H>::prove_single_cell(&mut builder);
// create a PartialWitness and assign
let mut pw = PartialWitness::new();
circuit_instance.single_cell_assign_witness(&mut pw, &mut targets, leaf_index, &slot_t.cell_data[leaf_index], proof)?;
slot_t.single_cell_assign_witness(&mut pw, &mut targets, leaf_index, &slot_t.cell_data[leaf_index], proof)?;
// build the circuit
let data = builder.build::<C>();

View File

@ -28,7 +28,7 @@ use crate::merkle_tree::merkle_safe::{KEY_NONE,KEY_BOTTOM_LAYER};
// note: this omits the mask bits since in plonky2 we can
// uses the Plonk's permutation argument to check that two elements are equal.
// TODO: double check the need for mask
// #[derive(Clone)]
#[derive(Clone)]
pub struct MerkleTreeTargets<
F: RichField + Extendable<D> + Poseidon2,
C: GenericConfig<D, F = F>,
@ -44,7 +44,7 @@ pub struct MerkleTreeTargets<
/// Merkle tree circuit contains the tree and functions for
/// building, proving and verifying the circuit.
// #[derive(Clone)]
#[derive(Clone)]
pub struct MerkleTreeCircuit<
F: RichField + Extendable<D> + Poseidon2,
C: GenericConfig<D, F = F>,
@ -94,7 +94,7 @@ impl<
};
// Add Merkle proof verification constraints to the circuit
self.reconstruct_merkle_root_circuit(builder, &mut targets);
Self::reconstruct_merkle_root_circuit(builder, &mut targets);
// Return MerkleTreeTargets
targets
@ -174,7 +174,7 @@ impl<
/// takes the params from the targets struct
/// outputs the reconstructed merkle root
pub fn reconstruct_merkle_root_circuit(
&self,
// &self,
builder: &mut CircuitBuilder<F, D>,
targets: &mut MerkleTreeTargets<F, C, D, H>,
) -> HashOutTarget {

View File

@ -27,17 +27,23 @@ use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::plonk::config::PoseidonGoldilocksConfig;
use plonky2::hash::hashing::PlonkyPermutation;
use crate::circuits::prove_single_cell::{SingleCellTargets, SlotTree};
use crate::circuits::prove_single_cell::{SingleCellTargets, SlotTreeCircuit};
use crate::circuits::params::{MAX_DEPTH, BOT_DEPTH, N_FIELD_ELEMS_PER_CELL, N_CELLS_IN_BLOCKS, N_BLOCKS, N_CELLS, HF, DATASET_DEPTH, N_SAMPLES};
use crate::circuits::safe_tree_circuit::{MerkleTreeCircuit, MerkleTreeTargets};
use crate::circuits::utils::{bits_le_padded_to_usize, calculate_cell_index_bits};
// ------ Dataset Tree --------
///dataset tree containing all slot trees
#[derive(Clone)]
pub struct DatasetTree<F: RichField, H: Hasher<F>> {
pub tree: MerkleTree<F,H>, // dataset tree
pub slot_trees: Vec<SlotTree<F,H>>, // vec of slot trees
pub struct DatasetTreeCircuit<
F: RichField + Extendable<D> + Poseidon2,
C: GenericConfig<D, F = F>,
const D: usize,
H: Hasher<F> + AlgebraicHasher<F>,
> {
pub tree: MerkleTreeCircuit<F, C, D, H>, // dataset tree
pub slot_trees: Vec<SlotTreeCircuit<F,C,D,H>>, // vec of slot trees
}
/// Dataset Merkle proof struct, containing the dataset proof and N_SAMPLES proofs.
@ -49,18 +55,23 @@ pub struct DatasetMerkleProof<F: RichField, H: Hasher<F>> {
pub slot_proofs: Vec<MerkleProof<F,H>>, // proofs for sampled slot, contains N_SAMPLES proofs
}
impl<F: RichField, H: Hasher<F>> Default for DatasetTree<F,H> {
impl<
F: RichField + Extendable<D> + Poseidon2,
C: GenericConfig<D, F = F>,
const D: usize,
H: Hasher<F> + AlgebraicHasher<F>,
> Default for DatasetTreeCircuit<F,C,D,H> {
/// dataset tree with fake data, for testing only
fn default() -> Self {
let mut slot_trees = vec![];
let n_slots = 1<<DATASET_DEPTH;
for i in 0..n_slots {
slot_trees.push(SlotTree::<F,H>::default());
slot_trees.push(SlotTreeCircuit::<F,C,D,H>::default());
}
// get the roots or slot trees
let slot_roots = slot_trees.iter()
.map(|t| {
t.tree.root().unwrap()
t.tree.tree.root().unwrap()
})
.collect::<Vec<_>>();
// zero hash
@ -69,20 +80,25 @@ impl<F: RichField, H: Hasher<F>> Default for DatasetTree<F,H> {
};
let dataset_tree = MerkleTree::<F, H>::new(&slot_roots, zero).unwrap();
Self{
tree: dataset_tree,
tree: MerkleTreeCircuit::<F,C,D,H>{ tree:dataset_tree, _phantom:Default::default()},
slot_trees,
}
}
}
impl<F: RichField, H: Hasher<F>> DatasetTree<F, H> {
impl<
F: RichField + Extendable<D> + Poseidon2,
C: GenericConfig<D, F = F>,
const D: usize,
H: Hasher<F> + AlgebraicHasher<F>,
> DatasetTreeCircuit<F,C,D,H> {
/// same as default but with supplied slot trees
pub fn new(slot_trees: Vec<SlotTree<F,H>>) -> Self{
pub fn new(slot_trees: Vec<SlotTreeCircuit<F,C,D,H>>) -> Self{
// get the roots or slot trees
let slot_roots = slot_trees.iter()
.map(|t| {
t.tree.root().unwrap()
t.tree.tree.root().unwrap()
})
.collect::<Vec<_>>();
// zero hash
@ -91,7 +107,7 @@ impl<F: RichField, H: Hasher<F>> DatasetTree<F, H> {
};
let dataset_tree = MerkleTree::<F, H>::new(&slot_roots, zero).unwrap();
Self{
tree: dataset_tree,
tree: MerkleTreeCircuit::<F,C,D,H>{ tree:dataset_tree, _phantom:Default::default()},
slot_trees,
}
}
@ -99,16 +115,16 @@ impl<F: RichField, H: Hasher<F>> DatasetTree<F, H> {
/// generates a dataset level proof for given slot index
/// just a regular merkle tree proof
pub fn get_proof(&self, index: usize) -> MerkleProof<F, H> {
let dataset_proof = self.tree.get_proof(index).unwrap();
let dataset_proof = self.tree.tree.get_proof(index).unwrap();
dataset_proof
}
/// generates a proof for given slot index
/// also takes entropy so it can use it sample the slot
pub fn sample_slot(&self, index: usize, entropy: usize) -> DatasetMerkleProof<F, H> {
let dataset_proof = self.get_proof(index);
let dataset_proof = self.tree.tree.get_proof(index).unwrap();
let slot = &self.slot_trees[index];
let slot_root = slot.tree.root().unwrap();
let slot_root = slot.tree.tree.root().unwrap();
let mut slot_proofs = vec![];
// get the index for cell from H(slot_root|counter|entropy)
for i in 0..N_SAMPLES {
@ -128,9 +144,9 @@ impl<F: RichField, H: Hasher<F>> DatasetTree<F, H> {
// verify the sampling - non-circuit version
pub fn verify_sampling(&self, proof: DatasetMerkleProof<F,H>) -> Result<bool>{
let slot = &self.slot_trees[proof.slot_index];
let slot_root = slot.tree.root().unwrap();
let slot_root = slot.tree.tree.root().unwrap();
// check dataset level proof
let d_res = proof.dataset_proof.verify(slot_root,self.tree.root().unwrap());
let d_res = proof.dataset_proof.verify(slot_root,self.tree.tree.root().unwrap());
if(d_res.unwrap() == false){
return Ok(false);
}
@ -180,7 +196,7 @@ impl<
C: GenericConfig<D, F=F>,
const D: usize,
H: Hasher<F> + AlgebraicHasher<F> + Hasher<F>,
> MerkleTreeCircuit<F, C, D, H> {
> DatasetTreeCircuit<F, C, D, H> {
// the in-circuit sampling of a slot in a dataset
pub fn sample_slot_circuit(
@ -192,7 +208,7 @@ impl<
// let slot_root = builder.add_virtual_hash();
let mut slot_proofs =vec![];
for i in 0..N_SAMPLES{
let proof_i = self.prove_single_cell2(builder);
let proof_i = SlotTreeCircuit::<F,C,D,H>::prove_single_cell(builder);
slot_proofs.push(proof_i);
}
@ -212,7 +228,7 @@ impl<
&mut self,
pw: &mut PartialWitness<F>,
targets: DatasetTargets<F,C,D,H>,
dataset_tree: DatasetTree<F,H>,
dataset_tree: DatasetTreeCircuit<F,C,D,H>,
slot_index:usize,
entropy:usize,
){
@ -222,38 +238,6 @@ impl<
}
// --------- helper functions --------------
fn calculate_cell_index_bits<F: RichField>(p0: usize, p1: HashOut<F>, p2: usize) -> Vec<bool> {
let p0_field = F::from_canonical_u64(p0 as u64);
let p2_field = F::from_canonical_u64(p2 as u64);
let mut inputs = Vec::new();
inputs.extend_from_slice(&p1.elements);
inputs.push(p0_field);
inputs.push(p2_field);
let p_hash = HF::hash_no_pad(&inputs);
let p_bytes = p_hash.to_bytes();
let p_bits = take_n_bits_from_bytes(&p_bytes, MAX_DEPTH);
p_bits
}
fn take_n_bits_from_bytes(bytes: &[u8], n: usize) -> Vec<bool> {
bytes.iter()
.flat_map(|byte| (0..8u8).map(move |i| (byte >> i) & 1 == 1))
.take(n)
.collect()
}
/// Converts a vector of bits (LSB first) into an index (usize).
fn bits_le_padded_to_usize(bits: &[bool]) -> usize {
bits.iter().enumerate().fold(0usize, |acc, (i, &bit)| {
if bit {
acc | (1 << i)
} else {
acc
}
})
}
#[cfg(test)]
mod tests {
use std::time::Instant;
@ -263,12 +247,14 @@ mod tests {
use plonky2::iop::witness::PartialWitness;
//types for tests
type F = GoldilocksField;
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
type H = PoseidonHash;
#[test]
fn test_sample_cells() {
let dataset_t = DatasetTree::<F, H>::default();
let dataset_t = DatasetTreeCircuit::<F,C,D,H>::default();
let slot_index = 2;
let entropy = 123;
let proof = dataset_t.sample_slot(slot_index,entropy);

View File

@ -1,5 +1,6 @@
use plonky2::hash::hash_types::{HashOut, RichField};
use plonky2::plonk::config::{GenericHashOut, Hasher};
use crate::circuits::params::{HF, MAX_DEPTH};
// --------- helper functions ---------
@ -14,4 +15,34 @@ pub(crate) fn usize_to_bits_le_padded(index: usize, bit_length: usize) -> Vec<bo
bits.push(false);
}
bits
}
pub(crate) fn calculate_cell_index_bits<F: RichField>(p0: usize, p1: HashOut<F>, p2: usize) -> Vec<bool> {
let p0_field = F::from_canonical_u64(p0 as u64);
let p2_field = F::from_canonical_u64(p2 as u64);
let mut inputs = Vec::new();
inputs.extend_from_slice(&p1.elements);
inputs.push(p0_field);
inputs.push(p2_field);
let p_hash = HF::hash_no_pad(&inputs);
let p_bytes = p_hash.to_bytes();
let p_bits = take_n_bits_from_bytes(&p_bytes, MAX_DEPTH);
p_bits
}
pub(crate) fn take_n_bits_from_bytes(bytes: &[u8], n: usize) -> Vec<bool> {
bytes.iter()
.flat_map(|byte| (0..8u8).map(move |i| (byte >> i) & 1 == 1))
.take(n)
.collect()
}
/// Converts a vector of bits (LSB first) into an index (usize).
pub(crate) fn bits_le_padded_to_usize(bits: &[bool]) -> usize {
bits.iter().enumerate().fold(0usize, |acc, (i, &bit)| {
if bit {
acc | (1 << i)
} else {
acc
}
})
}