mirror of
https://github.com/logos-storage/proof-aggregation.git
synced 2026-01-02 13:53:13 +00:00
165 lines
6.6 KiB
Rust
165 lines
6.6 KiB
Rust
// some tests for approach 1 of the tree recursion
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use std::time::Instant;
|
|
use plonky2::plonk::config::{GenericConfig, GenericHashOut, Hasher};
|
|
use plonky2_field::types::Field;
|
|
use codex_plonky2_circuits::circuits::sample_cells::SampleCircuitInput;
|
|
use crate::params::{C, D, F, HF, Params};
|
|
use codex_plonky2_circuits::recursion::circuits::sampling_inner_circuit::SamplingRecursion;
|
|
use codex_plonky2_circuits::recursion::circuits::inner_circuit::InnerCircuit;
|
|
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
|
|
use crate::gen_input::get_m_default_circ_input;
|
|
use codex_plonky2_circuits::recursion::tree1::tree_circuit::TreeRecursion;
|
|
|
|
/// Uses node recursion to sample the dataset
|
|
#[test]
|
|
fn test_tree1_node_recursion() -> anyhow::Result<()> {
|
|
const M: usize = 1;
|
|
const N: usize = 2;
|
|
|
|
let default_params = Params::default().circuit_params;
|
|
let inner_sampling_circuit = SamplingRecursion::<F,D,HF,C>::new(default_params);
|
|
|
|
let circ_input = get_m_default_circ_input::<M>();
|
|
|
|
let s = Instant::now();
|
|
let mut tree_circ = TreeRecursion::<F,D,_,M,N,C>::build::<HF>(inner_sampling_circuit)?;
|
|
println!("build = {:?}", s.elapsed());
|
|
println!("tree circuit size = {:?}", tree_circ.node_circ.cyclic_circuit_data.common.degree_bits());
|
|
let s = Instant::now();
|
|
let proof = tree_circ.prove(&circ_input,None, true)?;
|
|
println!("prove = {:?}", s.elapsed());
|
|
println!("num of pi = {}", proof.public_inputs.len());
|
|
let s = Instant::now();
|
|
assert!(
|
|
tree_circ.verify_proof(proof).is_ok(),
|
|
"proof verification failed"
|
|
);
|
|
println!("verify = {:?}", s.elapsed());
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Uses node recursion to sample the dataset
|
|
#[test]
|
|
fn test_tree_recursion_approach1() -> anyhow::Result<()> {
|
|
const M: usize = 1;
|
|
const N: usize = 2;
|
|
|
|
const DEPTH: usize = 3;
|
|
const TOTAL_INPUT: usize = (N.pow(DEPTH as u32) - 1) / (N - 1);
|
|
|
|
let default_params = Params::default().circuit_params;
|
|
let inner_sampling_circuit = SamplingRecursion::<F,D,HF,C>::new(default_params);
|
|
|
|
let all_circ_input = get_m_default_circ_input::<TOTAL_INPUT>().to_vec();
|
|
|
|
let s = Instant::now();
|
|
let mut tree_circ = TreeRecursion::<F,D,_,M,N,C>::build::<HF>(inner_sampling_circuit)?;
|
|
println!("build = {:?}", s.elapsed());
|
|
println!("tree circuit size = {:?}", tree_circ.node_circ.cyclic_circuit_data.common.degree_bits());
|
|
let s = Instant::now();
|
|
let proof = tree_circ.prove_tree(all_circ_input.clone(),DEPTH)?;
|
|
println!("prove = {:?}", s.elapsed());
|
|
println!("num of pi = {}", proof.public_inputs.len());
|
|
|
|
// Extract the final public input hash from the proof
|
|
let final_proof_hash = &proof.public_inputs[0..4];
|
|
|
|
// Recompute the expected final public input hash (outside the circuit)
|
|
let expected_hash = compute_expected_pub_input_hash(
|
|
&all_circ_input,
|
|
DEPTH,
|
|
M,
|
|
N
|
|
)?;
|
|
|
|
// Check that the final hash in the proof matches the expected hash
|
|
assert_eq!(final_proof_hash, expected_hash.as_slice(), "Public input hash mismatch");
|
|
|
|
let s = Instant::now();
|
|
assert!(
|
|
tree_circ.verify_proof(proof).is_ok(),
|
|
"proof verification failed"
|
|
);
|
|
println!("verify = {:?}", s.elapsed());
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Recursively compute the final public input hash for a single node in the recursion tree.
|
|
/// This is the same logic from `NodeCircuit::build_circuit`
|
|
/// TODO: optimize this
|
|
fn compute_node_hash(
|
|
all_circ_inputs: &[SampleCircuitInput<F, D>],
|
|
depth: usize,
|
|
current_depth: usize,
|
|
node_idx: usize,
|
|
M: usize,
|
|
N: usize,
|
|
) -> [F; 4] {
|
|
// Calculate the index in all_circ_inputs for this node's M inputs.
|
|
// Total inputs per layer: sum_{k=0}^{current_depth-1} M*N^k = M * ((N^current_depth - 1)/(N-1))
|
|
let offset_for_layer = ((N.pow(current_depth as u32) - 1) / (N - 1)) * M;
|
|
let node_start = offset_for_layer + node_idx * M;
|
|
let node_inputs = &all_circ_inputs[node_start..node_start + M];
|
|
|
|
// Compute the outer public input hash:
|
|
// public inputs are [slot_index, dataset_root.elements, entropy.elements].
|
|
let mut outer_pi_hashes = vec![];
|
|
for inp in node_inputs {
|
|
let mut pi_vec = vec![inp.slot_index];
|
|
pi_vec.extend_from_slice(&inp.dataset_root.elements);
|
|
pi_vec.extend_from_slice(&inp.entropy.elements);
|
|
let hash_res = HF::hash_no_pad(&pi_vec);
|
|
outer_pi_hashes.extend_from_slice(&hash_res.elements);
|
|
}
|
|
|
|
// hash all these M hashes into one
|
|
let outer_pi_hash = HF::hash_no_pad(&outer_pi_hashes);
|
|
|
|
let is_leaf = current_depth == depth - 1;
|
|
|
|
// Compute the inner proof hash (or zero hash if leaf)
|
|
let inner_pi_hash_or_zero = if is_leaf {
|
|
// condition = false at leaf, so inner proofs = zero hash
|
|
[F::ZERO; 4]
|
|
} else {
|
|
// condition = true at non-leaf node -> recursively compute child hashes
|
|
let next_depth = current_depth + 1;
|
|
let child_start = node_idx * N;
|
|
|
|
let mut inner_pub_input_hashes = vec![];
|
|
for i in child_start..child_start + N {
|
|
let child_hash = compute_node_hash(all_circ_inputs, depth, next_depth, i, M, N);
|
|
inner_pub_input_hashes.extend_from_slice(&child_hash);
|
|
}
|
|
|
|
let inner_pub_input_hash = HF::hash_no_pad(&inner_pub_input_hashes);
|
|
inner_pub_input_hash.elements
|
|
};
|
|
|
|
// Combine outer_pi_hash and inner_pi_hash_or_zero
|
|
let mut final_input = vec![];
|
|
final_input.extend_from_slice(&outer_pi_hash.elements);
|
|
final_input.extend_from_slice(&inner_pi_hash_or_zero);
|
|
|
|
let final_hash = HF::hash_no_pad(&final_input);
|
|
final_hash.elements
|
|
}
|
|
|
|
/// Compute the expected public input hash for the entire recursion tree.
|
|
/// This function calls `compute_node_hash` starting from the root (layer 0, node 0).
|
|
pub fn compute_expected_pub_input_hash(
|
|
all_circ_inputs: &[SampleCircuitInput<F, D>],
|
|
depth: usize,
|
|
M: usize,
|
|
N: usize,
|
|
) -> anyhow::Result<Vec<F>> {
|
|
// The root node is at layer = 0 and node_idx = 0
|
|
let final_hash = compute_node_hash(all_circ_inputs, depth, 0, 0, M, N);
|
|
Ok(final_hash.to_vec())
|
|
}
|
|
} |