Better test

This commit is contained in:
wborgeaud 2022-10-17 14:56:16 +02:00
parent 366567935c
commit 09cee22d1f
2 changed files with 91 additions and 14 deletions

View File

@ -182,7 +182,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
.collect()
}
fn select_hash(
pub(crate) fn select_hash(
&mut self,
b: BoolTarget,
h0: HashOutTarget,

View File

@ -118,12 +118,19 @@ impl CyclicPublicInputsTarget {
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn cyclic_recursion<C: GenericConfig<D, F = F>>(
mut self,
previous_virtual_public_inputs: &[Target],
previous_base_case: Target,
mut common_data: CommonCircuitData<F, C, D>,
) -> Result<(CircuitData<F, C, D>, CyclicRecursionTarget<D>)>
where
C::Hasher: AlgebraicHasher<F>,
[(); C::Hasher::HASH_SIZE]:,
{
ensure!(
previous_virtual_public_inputs.len() == self.num_public_inputs(),
"Incorrect number of public inputs."
);
let verifier_data = VerifierCircuitTarget {
constants_sigmas_cap: self.add_virtual_cap(self.config.fri_config.cap_height),
circuit_digest: self.add_virtual_hash(),
@ -168,6 +175,14 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.connect_hashes(*h0, *h1);
}
self.connect(previous_base_case, pis.base_case.target);
for (x, y) in previous_virtual_public_inputs
.iter()
.zip(&proof.public_inputs)
{
self.connect(*x, *y);
}
// Verify the dummy proof if `base_case` is set to true, otherwise verify the "real" proof.
self.conditionally_verify_proof(
base_case,
@ -212,6 +227,8 @@ pub fn set_cyclic_recursion_data_target<
pw: &mut PartialWitness<F>,
cyclic_recursion_data_target: &CyclicRecursionTarget<D>,
cyclic_recursion_data: &CyclicRecursionData<F, C, D>,
// Public inputs to set in the base case to seed some initial data.
public_inputs: &[F],
) -> Result<()>
where
C::Hasher: AlgebraicHasher<F>,
@ -233,6 +250,7 @@ where
let (dummy_proof, dummy_data) = dummy_proof(cyclic_recursion_data.common_data)?;
pw.set_bool_target(cyclic_recursion_data_target.base_case, true);
let mut proof = dummy_proof.clone();
proof.public_inputs[0..public_inputs.len()].copy_from_slice(public_inputs);
let pis_len = proof.public_inputs.len();
// A base case must be following another base case.
proof.public_inputs[pis_len - 1] = F::ONE;
@ -293,13 +311,17 @@ where
#[cfg(test)]
mod tests {
use anyhow::Result;
use plonky2_field::extension::Extendable;
use plonky2_field::types::PrimeField64;
use crate::field::types::Field;
use crate::hash::hash_types::RichField;
use crate::hash::poseidon::PoseidonHash;
use crate::iop::witness::{PartialWitness, Witness};
use crate::hash::hashing::hash_n_to_hash_no_pad;
use crate::hash::poseidon::{PoseidonHash, PoseidonPermutation};
use crate::iop::target::BoolTarget;
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData, VerifierCircuitTarget};
use crate::plonk::config::{AlgebraicHasher, GenericConfig, Hasher, PoseidonGoldilocksConfig};
@ -351,23 +373,54 @@ mod tests {
let mut pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
// Build realistic circuit
let t = builder.add_virtual_target();
pw.set_target(t, F::rand());
let t_inv = builder.inverse(t);
let h = builder.hash_n_to_hash_no_pad::<PoseidonHash>(vec![t_inv]);
// Circuit that computes a repeated hash.
let initial_hash = builder.add_virtual_hash();
builder.register_public_inputs(&initial_hash.elements);
// Hash from the previous proof.
let old_hash = builder.add_virtual_hash();
// Flag set to true if the last proof was a base case.
let old_base_case = builder.add_virtual_target();
// The input hash is either the previous hash or the initial hash depending on whether
// the last proof was a base case.
let input_hash = builder.select_hash(
BoolTarget::new_unsafe(old_base_case),
initial_hash,
old_hash,
);
let h = builder.hash_n_to_hash_no_pad::<PoseidonHash>(input_hash.elements.to_vec());
builder.register_public_inputs(&h.elements);
// Previous counter.
let old_counter = builder.add_virtual_target();
let one = builder.one();
let old_not_base_case = builder.sub(one, old_base_case);
// New counter is the previous counter +1 if the previous proof wasn't a base case.
let new_counter = builder.add(old_counter, old_not_base_case);
builder.register_public_input(new_counter);
let old_pis = [
initial_hash.elements.as_slice(),
old_hash.elements.as_slice(),
[old_counter].as_slice(),
]
.concat();
let common_data = common_data_for_recursion::<F, C, D>();
// Add cyclic recursion gadget.
let (cyclic_circuit_data, cyclic_data_target) = builder.cyclic_recursion(common_data)?;
let (cyclic_circuit_data, cyclic_data_target) =
builder.cyclic_recursion(&old_pis, old_base_case, common_data)?;
let cyclic_recursion_data = CyclicRecursionData {
proof: &None, // Base case: We don't have a proof to put here yet.
verifier_data: &cyclic_circuit_data.verifier_only,
common_data: &cyclic_circuit_data.common,
};
set_cyclic_recursion_data_target(&mut pw, &cyclic_data_target, &cyclic_recursion_data)?;
let initial_hash = [F::ZERO, F::ONE, F::TWO, F::from_canonical_usize(3)];
set_cyclic_recursion_data_target(
&mut pw,
&cyclic_data_target,
&cyclic_recursion_data,
&initial_hash,
)?;
let proof = cyclic_circuit_data.prove(pw)?;
check_cyclic_proof_verifier_data(
&proof,
@ -378,13 +431,17 @@ mod tests {
// 1st recursive layer.
let mut pw = PartialWitness::new();
pw.set_target(t, F::rand());
let cyclic_recursion_data = CyclicRecursionData {
proof: &Some(proof), // Input previous proof.
verifier_data: &cyclic_circuit_data.verifier_only,
common_data: &cyclic_circuit_data.common,
};
set_cyclic_recursion_data_target(&mut pw, &cyclic_data_target, &cyclic_recursion_data)?;
set_cyclic_recursion_data_target(
&mut pw,
&cyclic_data_target,
&cyclic_recursion_data,
&[],
)?;
let proof = cyclic_circuit_data.prove(pw)?;
check_cyclic_proof_verifier_data(
&proof,
@ -395,19 +452,39 @@ mod tests {
// 2nd recursive layer.
let mut pw = PartialWitness::new();
pw.set_target(t, F::rand());
let cyclic_recursion_data = CyclicRecursionData {
proof: &Some(proof), // Input previous proof.
verifier_data: &cyclic_circuit_data.verifier_only,
common_data: &cyclic_circuit_data.common,
};
set_cyclic_recursion_data_target(&mut pw, &cyclic_data_target, &cyclic_recursion_data)?;
set_cyclic_recursion_data_target(
&mut pw,
&cyclic_data_target,
&cyclic_recursion_data,
&[],
)?;
let proof = cyclic_circuit_data.prove(pw)?;
check_cyclic_proof_verifier_data(
&proof,
cyclic_recursion_data.verifier_data,
cyclic_recursion_data.common_data,
)?;
// Verify that the proof correctly computes a repeated hash.
let initial_hash = &proof.public_inputs[..4];
let hash = &proof.public_inputs[4..8];
let counter = proof.public_inputs[8];
let mut h: [F; 4] = initial_hash.try_into().unwrap();
assert_eq!(
hash,
std::iter::repeat_with(|| {
h = hash_n_to_hash_no_pad::<F, PoseidonPermutation>(&h).elements;
h
})
.nth(counter.to_canonical_u64() as usize)
.unwrap()
);
cyclic_circuit_data.verify(proof)
}
}