Add aggregation circuit

Which can be used to compress two proofs into one. Each inner proof can be either
- an "EVM root" proof (which typically proves one transaction, though it could be 0 or more)
- another aggregation proof
This commit is contained in:
Daniel Lubarov 2023-01-03 15:46:59 -08:00
parent 0ca308400a
commit 5df784416a
2 changed files with 134 additions and 7 deletions

View File

@ -3,16 +3,19 @@ use std::ops::Range;
use itertools::Itertools;
use plonky2::field::extension::Extendable;
use plonky2::gates::noop::NoopGate;
use plonky2::hash::hash_types::RichField;
use plonky2::hash::hashing::SPONGE_WIDTH;
use plonky2::iop::challenger::RecursiveChallenger;
use plonky2::iop::target::Target;
use plonky2::iop::target::{BoolTarget, Target};
use plonky2::iop::witness::{PartialWitness, WitnessWrite};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData};
use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, VerifierCircuitTarget};
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher};
use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget};
use plonky2::recursion::cyclic_recursion::check_cyclic_proof_verifier_data;
use plonky2::util::timing::TimingTree;
use plonky2_util::log2_ceil;
use crate::all_stark::{all_cross_table_lookups, AllStark, Table, NUM_TABLES};
use crate::config::StarkConfig;
@ -44,24 +47,47 @@ where
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
{
/// The root circuit, which aggregates the (shrunk) per-table recursive proofs.
/// The EVM root circuit, which aggregates the (shrunk) per-table recursive proofs.
pub root: RootCircuitData<F, C, D>,
pub aggregation: AggregationCircuitData<F, C, D>,
/// Holds chains of circuits for each table and for each initial `degree_bits`.
by_table: [RecursiveCircuitsForTable<F, C, D>; NUM_TABLES],
}
/// Data for the special root circuit, which is used to combine each STARK's shrunk wrapper proof
/// Data for the EVM root circuit, which is used to combine each STARK's shrunk wrapper proof
/// into a single proof.
pub struct RootCircuitData<F, C, const D: usize>
where
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
{
pub circuit: CircuitData<F, C, D>,
circuit: CircuitData<F, C, D>,
proof_with_pis: [ProofWithPublicInputsTarget<D>; NUM_TABLES],
/// For each table, various inner circuits may be used depending on the initial table size.
/// This target holds the index of the circuit (within `final_circuits()`) that was used.
index_verifier_data: [Target; NUM_TABLES],
/// Public inputs used for cyclic verification. These aren't actually used for EVM root
/// proofs; the circuit has them just to match the structure of aggregation proofs.
cyclic_vk: VerifierCircuitTarget,
}
/// Data for the aggregation circuit, which is used to compress two proofs into one. Each inner
/// proof can be either an EVM root proof or another aggregation proof.
pub struct AggregationCircuitData<F, C, const D: usize>
where
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
{
circuit: CircuitData<F, C, D>,
lhs: AggregationChildTarget<D>,
rhs: AggregationChildTarget<D>,
cyclic_vk: VerifierCircuitTarget,
}
pub struct AggregationChildTarget<const D: usize> {
is_agg: BoolTarget,
agg_proof: ProofWithPublicInputsTarget<D>,
evm_proof: ProofWithPublicInputsTarget<D>,
}
impl<F, C, const D: usize> AllRecursiveCircuits<F, C, D>
@ -120,7 +146,12 @@ where
let by_table = [cpu, keccak, keccak_sponge, logic, memory];
let root = Self::create_root_circuit(&by_table, stark_config);
Self { root, by_table }
let aggregation = Self::create_aggregation_circuit(&root);
Self {
root,
aggregation,
by_table,
}
}
fn create_root_circuit(
@ -212,10 +243,58 @@ where
);
}
// We want EVM root proofs to have the exact same structure as aggregation proofs, so we add
// public inputs for cyclic verification, even though they'll be ignored.
let cyclic_vk = builder.add_verifier_data_public_inputs();
RootCircuitData {
circuit: builder.build(),
proof_with_pis: recursive_proofs,
index_verifier_data,
cyclic_vk,
}
}
fn create_aggregation_circuit(
root: &RootCircuitData<F, C, D>,
) -> AggregationCircuitData<F, C, D> {
let mut builder = CircuitBuilder::<F, D>::new(root.circuit.common.config.clone());
let cyclic_vk = builder.add_verifier_data_public_inputs();
let lhs = Self::add_agg_child(&mut builder, root);
let rhs = Self::add_agg_child(&mut builder, root);
// Pad to match the root circuit's degree.
while log2_ceil(builder.num_gates()) < root.circuit.common.degree_bits() {
builder.add_gate(NoopGate, vec![]);
}
let circuit = builder.build::<C>();
AggregationCircuitData {
circuit,
lhs,
rhs,
cyclic_vk,
}
}
fn add_agg_child(
builder: &mut CircuitBuilder<F, D>,
root: &RootCircuitData<F, C, D>,
) -> AggregationChildTarget<D> {
let common = &root.circuit.common;
let root_vk = builder.constant_verifier_data(&root.circuit.verifier_only);
let is_agg = builder.add_virtual_bool_target_safe();
let agg_proof = builder.add_virtual_proof_with_pis::<C>(common);
let evm_proof = builder.add_virtual_proof_with_pis::<C>(common);
builder
.conditionally_verify_cyclic_proof::<C>(
is_agg, &agg_proof, &evm_proof, &root_vk, common,
)
.expect("Failed to build cyclic recursion circuit");
AggregationChildTarget {
is_agg,
agg_proof,
evm_proof,
}
}
@ -229,6 +308,7 @@ where
) -> anyhow::Result<ProofWithPublicInputs<F, C, D>> {
let all_proof = prove::<F, C, D>(all_stark, config, generation_inputs, timing)?;
let mut root_inputs = PartialWitness::new();
for table in 0..NUM_TABLES {
let stark_proof = &all_proof.stark_proofs[table];
let original_degree_bits = stark_proof.proof.recover_degree_bits(config);
@ -246,8 +326,52 @@ where
);
root_inputs.set_proof_with_pis_target(&self.root.proof_with_pis[table], &shrunk_proof);
}
root_inputs
.set_verifier_data_target(&self.root.cyclic_vk, &self.root.circuit.verifier_only);
self.root.circuit.prove(root_inputs)
}
pub fn verify_root(&self, agg_proof: &ProofWithPublicInputs<F, C, D>) -> anyhow::Result<()> {
self.root.circuit.verify(agg_proof.clone())
}
pub fn prove_aggregation(
&self,
lhs_is_agg: bool,
lhs_proof: &ProofWithPublicInputs<F, C, D>,
rhs_is_agg: bool,
rhs_proof: &ProofWithPublicInputs<F, C, D>,
) -> anyhow::Result<ProofWithPublicInputs<F, C, D>> {
let mut agg_inputs = PartialWitness::new();
agg_inputs.set_bool_target(self.aggregation.lhs.is_agg, lhs_is_agg);
agg_inputs.set_proof_with_pis_target(&self.aggregation.lhs.agg_proof, lhs_proof);
agg_inputs.set_proof_with_pis_target(&self.aggregation.lhs.evm_proof, lhs_proof);
agg_inputs.set_bool_target(self.aggregation.rhs.is_agg, rhs_is_agg);
agg_inputs.set_proof_with_pis_target(&self.aggregation.rhs.agg_proof, rhs_proof);
agg_inputs.set_proof_with_pis_target(&self.aggregation.rhs.evm_proof, rhs_proof);
agg_inputs.set_verifier_data_target(
&self.aggregation.cyclic_vk,
&self.aggregation.circuit.verifier_only,
);
self.aggregation.circuit.prove(agg_inputs)
}
pub fn verify_aggregation(
&self,
agg_proof: &ProofWithPublicInputs<F, C, D>,
) -> anyhow::Result<()> {
self.aggregation.circuit.verify(agg_proof.clone())?;
check_cyclic_proof_verifier_data(
agg_proof,
&self.aggregation.circuit.verifier_only,
&self.aggregation.circuit.common,
)
}
}
struct RecursiveCircuitsForTable<F, C, const D: usize>

View File

@ -86,7 +86,10 @@ fn test_empty_txn_list() -> anyhow::Result<()> {
let all_circuits = AllRecursiveCircuits::<F, C, D>::new(&all_stark, 9..19, &config);
let root_proof = all_circuits.prove_root(&all_stark, &config, inputs, &mut timing)?;
all_circuits.root.circuit.verify(root_proof)
all_circuits.verify_root(&root_proof)?;
let agg_proof = all_circuits.prove_aggregation(false, &root_proof, false, &root_proof)?;
all_circuits.verify_aggregation(&agg_proof)
}
fn init_logger() {