This commit is contained in:
wborgeaud 2022-10-26 10:58:38 +02:00
parent 7afbddb0b6
commit 341e1ebeec

View File

@ -10,7 +10,7 @@ use crate::iop::target::{BoolTarget, Target};
use crate::iop::witness::{PartialWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::{
CircuitData, CommonCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData,
CommonCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData,
};
use crate::plonk::config::Hasher;
use crate::plonk::config::{AlgebraicHasher, GenericConfig};
@ -43,7 +43,6 @@ pub struct CyclicPublicInputs<
> {
pub circuit_digest: HashOut<F>,
pub constants_sigmas_cap: MerkleCap<F, C::Hasher>,
pub base_case: bool,
}
impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
@ -53,30 +52,23 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
where
C::Hasher: AlgebraicHasher<F>,
{
// The structure of the public inputs is `[...,circuit_digest, constants_sigmas_cap, base_case]`.
// The structure of the public inputs is `[..., circuit_digest, constants_sigmas_cap]`.
let cap_len = common_data.config.fri_config.num_cap_elements();
let len = slice.len();
ensure!(len >= 4 + 4 * cap_len + 1, "Not enough public inputs");
let base_case = slice[len - 1];
ensure!(
base_case.is_one() || base_case.is_zero(),
"Base case flag {:?} is not binary",
base_case
);
ensure!(len >= 4 + 4 * cap_len, "Not enough public inputs");
let constants_sigmas_cap = MerkleCap(
(0..cap_len)
.map(|i| HashOut {
elements: std::array::from_fn(|j| slice[len - 1 - 4 * (cap_len - i) + j]),
elements: std::array::from_fn(|j| slice[len - 4 * (cap_len - i) + j]),
})
.collect(),
);
let circuit_digest =
HashOut::from_partial(&slice[len - 5 - 4 * cap_len..len - 1 - 4 * cap_len]);
HashOut::from_partial(&slice[len - 4 - 4 * cap_len..len - 4 * cap_len]);
Ok(Self {
circuit_digest,
constants_sigmas_cap,
base_case: base_case.is_one(),
})
}
}
@ -84,7 +76,6 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
pub struct CyclicPublicInputsTarget {
pub circuit_digest: HashOutTarget,
pub constants_sigmas_cap: MerkleCapTarget,
pub base_case: BoolTarget,
}
impl CyclicPublicInputsTarget {
@ -94,34 +85,31 @@ impl CyclicPublicInputsTarget {
) -> Result<Self> {
let cap_len = common_data.config.fri_config.num_cap_elements();
let len = slice.len();
ensure!(len >= 4 + 4 * cap_len + 1, "Not enough public inputs");
let base_case = BoolTarget::new_unsafe(slice[len - 1]);
ensure!(len >= 4 + 4 * cap_len, "Not enough public inputs");
let constants_sigmas_cap = MerkleCapTarget(
(0..cap_len)
.map(|i| HashOutTarget {
elements: std::array::from_fn(|j| slice[len - 1 - 4 * (cap_len - i) + j]),
elements: std::array::from_fn(|j| slice[len - 4 * (cap_len - i) + j]),
})
.collect(),
);
let circuit_digest = HashOutTarget {
elements: std::array::from_fn(|i| slice[len - 5 - 4 * cap_len + i]),
elements: std::array::from_fn(|i| slice[len - 4 - 4 * cap_len + i]),
};
Ok(Self {
circuit_digest,
constants_sigmas_cap,
base_case,
})
}
}
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn cyclic_recursion<C: GenericConfig<D, F = F>>(
mut self,
&mut self,
previous_virtual_public_inputs: &[Target],
previous_base_case: Target,
mut common_data: CommonCircuitData<F, D>,
) -> Result<(CircuitData<F, C, D>, CyclicRecursionTarget<D>)>
common_data: &mut CommonCircuitData<F, D>,
) -> Result<CyclicRecursionTarget<D>>
where
C::Hasher: AlgebraicHasher<F>,
[(); C::Hasher::HASH_SIZE]:,
@ -147,22 +135,15 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
};
// Flag set to true for the base case of the cycle where we verify a dummy proof to bootstrap the cycle. Set to false otherwise.
// Unsafe is ok since `base_case` is a public input and its booleaness should be checked in the verifier.
let base_case = self.add_virtual_bool_target_unsafe();
self.register_public_input(base_case.target);
let base_case = self.add_virtual_bool_target_safe();
common_data.num_public_inputs = self.num_public_inputs();
let proof = self.add_virtual_proof_with_pis::<C>(&common_data);
let dummy_proof = self.add_virtual_proof_with_pis::<C>(&common_data);
let proof = self.add_virtual_proof_with_pis::<C>(common_data);
let dummy_proof = self.add_virtual_proof_with_pis::<C>(common_data);
let pis =
CyclicPublicInputsTarget::from_slice::<F, C, D>(&proof.public_inputs, &common_data)?;
// Check that the previous base case flag was boolean.
self.assert_bool(pis.base_case);
// Check that we cannot go from a non-base case to a base case by checking `previous_base_case - base_case \in {0,1}`.
let decrease = BoolTarget::new_unsafe(self.sub(pis.base_case.target, base_case.target));
self.assert_bool(decrease);
CyclicPublicInputsTarget::from_slice::<F, C, D>(&proof.public_inputs, common_data)?;
// Connect previous verifier data to current one. This guarantees that every proof in the cycle uses the same verifier data.
self.connect_hashes(pis.circuit_digest, verifier_data.circuit_digest);
for (h0, h1) in pis
@ -174,7 +155,6 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.connect_hashes(*h0, *h1);
}
self.connect(previous_base_case, pis.base_case.target);
for (x, y) in previous_virtual_public_inputs
.iter()
.zip(&proof.public_inputs)
@ -189,7 +169,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
&dummy_verifier_data,
&proof,
&verifier_data,
&common_data,
common_data,
);
// Make sure we have enough gates to match `common_data`.
@ -201,24 +181,13 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.add_gate_to_gate_set(g.clone());
}
let data = self.build::<C>();
ensure!(
data.common == common_data,
"Common data does not match. Final circuit has common data {:?} instead of {:?}.",
data.common,
common_data
);
Ok((
data,
CyclicRecursionTarget {
proof,
verifier_data,
dummy_proof,
dummy_verifier_data,
base_case,
},
))
Ok(CyclicRecursionTarget {
proof,
verifier_data,
dummy_proof,
dummy_verifier_data,
base_case,
})
}
}
@ -256,8 +225,6 @@ where
let mut proof = dummy_proof.clone();
proof.public_inputs[0..public_inputs.len()].copy_from_slice(public_inputs);
let pis_len = proof.public_inputs.len();
// A base case must be following another base case.
proof.public_inputs[pis_len - 1] = F::ONE;
// The circuit checks that the verifier data is the same throughout the cycle, so
// we set the verifier data to the "real" verifier data even though it's unused in the base case.
let num_cap = cyclic_recursion_data
@ -265,7 +232,7 @@ where
.config
.fri_config
.num_cap_elements();
let s = pis_len - 5 - 4 * num_cap;
let s = pis_len - 4 - 4 * num_cap;
proof.public_inputs[s..s + 4]
.copy_from_slice(&cyclic_recursion_data.verifier_data.circuit_digest.elements);
for i in 0..num_cap {
@ -305,10 +272,8 @@ where
C::Hasher: AlgebraicHasher<F>,
{
let pis = CyclicPublicInputs::<F, C, D>::from_slice(&proof.public_inputs, common_data)?;
if !pis.base_case {
ensure!(verifier_data.constants_sigmas_cap == pis.constants_sigmas_cap);
ensure!(verifier_data.circuit_digest == pis.circuit_digest);
}
ensure!(verifier_data.constants_sigmas_cap == pis.constants_sigmas_cap);
ensure!(verifier_data.circuit_digest == pis.circuit_digest);
Ok(())
}
@ -325,7 +290,6 @@ mod tests {
use crate::hash::hash_types::RichField;
use crate::hash::hashing::hash_n_to_hash_no_pad;
use crate::hash::poseidon::{PoseidonHash, PoseidonPermutation};
use crate::iop::target::BoolTarget;
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData, VerifierCircuitTarget};
@ -386,24 +350,15 @@ mod tests {
builder.register_public_inputs(&initial_hash.elements);
// Hash from the previous proof.
let old_hash = builder.add_virtual_hash();
// Flag set to true if the last proof was a base case.
let old_base_case = builder.add_virtual_target();
// The input hash is either the previous hash or the initial hash depending on whether
// the last proof was a base case.
let input_hash = builder.select_hash(
BoolTarget::new_unsafe(old_base_case),
initial_hash,
old_hash,
);
let input_hash = builder.add_virtual_hash();
let h = builder.hash_n_to_hash_no_pad::<PoseidonHash>(input_hash.elements.to_vec());
builder.register_public_inputs(&h.elements);
// Previous counter.
let old_counter = builder.add_virtual_target();
let one = builder.one();
let old_not_base_case = builder.sub(one, old_base_case);
// New counter is the previous counter +1 if the previous proof wasn't a base case.
let new_counter = builder.add(old_counter, old_not_base_case);
builder.register_public_input(new_counter);
let new_counter = builder.add_virtual_public_input();
let old_pis = [
initial_hash.elements.as_slice(),
old_hash.elements.as_slice(),
@ -411,11 +366,19 @@ mod tests {
]
.concat();
let common_data = common_data_for_recursion::<F, C, D>();
let mut common_data = common_data_for_recursion::<F, C, D>();
// Add cyclic recursion gadget.
let (cyclic_circuit_data, cyclic_data_target) =
builder.cyclic_recursion::<C>(&old_pis, old_base_case, common_data)?;
let cyclic_data_target = builder.cyclic_recursion::<C>(&old_pis, &mut common_data)?;
let input_hash_bis =
builder.select_hash(cyclic_data_target.base_case, initial_hash, old_hash);
builder.connect_hashes(input_hash, input_hash_bis);
let not_base_case = builder.sub(one, cyclic_data_target.base_case.target);
// New counter is the previous counter +1 if the previous proof wasn't a base case.
let new_counter_bis = builder.add(old_counter, not_base_case);
builder.connect(new_counter, new_counter_bis);
let cyclic_circuit_data = builder.build::<C>();
let cyclic_recursion_data = CyclicRecursionData {
proof: &None, // Base case: We don't have a proof to put here yet.
@ -482,6 +445,7 @@ mod tests {
let initial_hash = &proof.public_inputs[..4];
let hash = &proof.public_inputs[4..8];
let counter = proof.public_inputs[8];
dbg!(counter);
let mut h: [F; 4] = initial_hash.try_into().unwrap();
assert_eq!(
hash,