From 24c201477cafd1dc3a86d7b39704532620b50a56 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 10 Feb 2022 16:14:18 +0100 Subject: [PATCH] Recursive stark test (failing) --- plonky2/src/plonk/circuit_data.rs | 10 +- plonky2/src/plonk/prover.rs | 2 +- starky/src/fibonacci_stark.rs | 451 +++++++++--------------------- starky/src/stark.rs | 2 +- 4 files changed, 144 insertions(+), 321 deletions(-) diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index 7e667b8d..a113fcdb 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -96,9 +96,9 @@ impl CircuitConfig { /// Circuit data required by the prover or the verifier. pub struct CircuitData, C: GenericConfig, const D: usize> { - pub(crate) prover_only: ProverOnlyCircuitData, - pub(crate) verifier_only: VerifierOnlyCircuitData, - pub(crate) common: CommonCircuitData, + pub prover_only: ProverOnlyCircuitData, + pub verifier_only: VerifierOnlyCircuitData, + pub common: CommonCircuitData, } impl, C: GenericConfig, const D: usize> @@ -181,7 +181,7 @@ impl, C: GenericConfig, const D: usize> } /// Circuit data required by the prover, but not the verifier. -pub(crate) struct ProverOnlyCircuitData< +pub struct ProverOnlyCircuitData< F: RichField + Extendable, C: GenericConfig, const D: usize, @@ -209,7 +209,7 @@ pub(crate) struct ProverOnlyCircuitData< /// Circuit data required by the verifier, but not the prover. #[derive(Debug)] -pub(crate) struct VerifierOnlyCircuitData, const D: usize> { +pub struct VerifierOnlyCircuitData, const D: usize> { /// A commitment to each constant polynomial and each permutation polynomial. pub(crate) constants_sigmas_cap: MerkleCap, } diff --git a/plonky2/src/plonk/prover.rs b/plonky2/src/plonk/prover.rs index d49014f0..9932462b 100644 --- a/plonky2/src/plonk/prover.rs +++ b/plonky2/src/plonk/prover.rs @@ -26,7 +26,7 @@ use crate::util::partial_products::{partial_products_and_z_gx, quotient_chunk_pr use crate::util::timing::TimingTree; use crate::util::transpose; -pub(crate) fn prove, C: GenericConfig, const D: usize>( +pub fn prove, C: GenericConfig, const D: usize>( prover_data: &ProverOnlyCircuitData, common_data: &CommonCircuitData, inputs: PartialWitness, diff --git a/starky/src/fibonacci_stark.rs b/starky/src/fibonacci_stark.rs index 92432c64..d4c18ff6 100644 --- a/starky/src/fibonacci_stark.rs +++ b/starky/src/fibonacci_stark.rs @@ -105,317 +105,140 @@ impl, const D: usize> Stark for FibonacciStar } } -// #[cfg(test)] -// mod tests { -// use anyhow::Result; -// use plonky2::field::extension_field::Extendable; -// use plonky2::field::field_types::Field; -// use plonky2::hash::hash_types::RichField; -// use plonky2::iop::witness::PartialWitness; -// use plonky2::plonk::circuit_builder::CircuitBuilder; -// use plonky2::plonk::circuit_data::CommonCircuitData; -// use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, PoseidonGoldilocksConfig}; -// use plonky2::plonk::proof::ProofWithPublicInputs; -// use plonky2::util::timing::TimingTree; -// -// use crate::config::StarkConfig; -// use crate::fibonacci_stark::FibonacciStark; -// use crate::proof::StarkProofWithPublicInputs; -// use crate::prover::prove; -// use crate::recursive_verifier::add_virtual_stark_proof_with_pis; -// use crate::stark_testing::test_stark_low_degree; -// use crate::verifier::verify; -// -// fn fibonacci(n: usize, x0: F, x1: F) -> F { -// (0..n).fold((x0, x1), |x, _| (x.1, x.0 + x.1)).1 -// } -// -// #[test] -// fn test_fibonacci_stark() -> Result<()> { -// const D: usize = 2; -// type C = PoseidonGoldilocksConfig; -// type F = >::F; -// type S = FibonacciStark; -// -// let config = StarkConfig::standard_fast_config(); -// let num_rows = 1 << 5; -// let public_inputs = [F::ZERO, F::ONE, fibonacci(num_rows - 1, F::ZERO, F::ONE)]; -// let stark = S::new(num_rows); -// let trace = stark.generate_trace(public_inputs[0], public_inputs[1]); -// let proof = prove::( -// stark, -// &config, -// trace, -// public_inputs, -// &mut TimingTree::default(), -// )?; -// -// verify(stark, proof, &config) -// } -// -// #[test] -// fn test_fibonacci_stark_degree() -> Result<()> { -// const D: usize = 2; -// type C = PoseidonGoldilocksConfig; -// type F = >::F; -// type S = FibonacciStark; -// -// let config = StarkConfig::standard_fast_config(); -// let num_rows = 1 << 5; -// let stark = S::new(num_rows); -// test_stark_low_degree(stark) -// } -// -// #[test] -// fn test_recursive_stark_verifier() -> Result<()> { -// init_logger(); -// const D: usize = 2; -// type C = PoseidonGoldilocksConfig; -// type F = >::F; -// type S = FibonacciStark; -// -// let config = StarkConfig::standard_fast_config(); -// let num_rows = 1 << 5; -// let public_inputs = [F::ZERO, F::ONE, fibonacci(num_rows - 1, F::ZERO, F::ONE)]; -// let stark = S::new(num_rows); -// let trace = stark.generate_trace(public_inputs[0], public_inputs[1]); -// let proof = prove::( -// stark, -// &config, -// trace, -// public_inputs, -// &mut TimingTree::default(), -// )?; -// -// let (proof, _vd, cd) = -// recursive_proof::(proof, vd, cd, &config, None, true, true)?; -// test_serialization(&proof, &cd)?; -// -// Ok(()) -// } -// -// #[test] -// fn test_recursive_recursive_verifier() -> Result<()> { -// init_logger(); -// const D: usize = 2; -// type C = PoseidonGoldilocksConfig; -// type F = >::F; -// -// let config = CircuitConfig::standard_recursion_config(); -// -// // Start with a degree 2^14 proof -// let (proof, vd, cd) = dummy_proof::(&config, 16_000)?; -// assert_eq!(cd.degree_bits, 14); -// -// // Shrink it to 2^13. -// let (proof, vd, cd) = -// recursive_proof::(proof, vd, cd, &config, Some(13), false, false)?; -// assert_eq!(cd.degree_bits, 13); -// -// // Shrink it to 2^12. -// let (proof, _vd, cd) = -// recursive_proof::(proof, vd, cd, &config, None, true, true)?; -// assert_eq!(cd.degree_bits, 12); -// -// test_serialization(&proof, &cd)?; -// -// Ok(()) -// } -// -// /// Creates a chain of recursive proofs where the last proof is made as small as reasonably -// /// possible, using a high rate, high PoW bits, etc. -// #[test] -// #[ignore] -// fn test_size_optimized_recursion() -> Result<()> { -// init_logger(); -// const D: usize = 2; -// type C = PoseidonGoldilocksConfig; -// type KC = KeccakGoldilocksConfig; -// type F = >::F; -// -// let standard_config = CircuitConfig::standard_recursion_config(); -// -// // An initial dummy proof. -// let (proof, vd, cd) = dummy_proof::(&standard_config, 4_000)?; -// assert_eq!(cd.degree_bits, 12); -// -// // A standard recursive proof. -// let (proof, vd, cd) = recursive_proof(proof, vd, cd, &standard_config, None, false, false)?; -// assert_eq!(cd.degree_bits, 12); -// -// // A high-rate recursive proof, designed to be verifiable with fewer routed wires. -// let high_rate_config = CircuitConfig { -// fri_config: FriConfig { -// rate_bits: 7, -// proof_of_work_bits: 16, -// num_query_rounds: 12, -// ..standard_config.fri_config.clone() -// }, -// ..standard_config -// }; -// let (proof, vd, cd) = -// recursive_proof::(proof, vd, cd, &high_rate_config, None, true, true)?; -// assert_eq!(cd.degree_bits, 12); -// -// // A final proof, optimized for size. -// let final_config = CircuitConfig { -// num_routed_wires: 37, -// fri_config: FriConfig { -// rate_bits: 8, -// cap_height: 0, -// proof_of_work_bits: 20, -// reduction_strategy: FriReductionStrategy::MinSize(None), -// num_query_rounds: 10, -// }, -// ..high_rate_config -// }; -// let (proof, _vd, cd) = -// recursive_proof::(proof, vd, cd, &final_config, None, true, true)?; -// assert_eq!(cd.degree_bits, 12, "final proof too large"); -// -// test_serialization(&proof, &cd)?; -// -// Ok(()) -// } -// -// #[test] -// fn test_recursive_verifier_multi_hash() -> Result<()> { -// init_logger(); -// const D: usize = 2; -// type PC = PoseidonGoldilocksConfig; -// type KC = KeccakGoldilocksConfig; -// type F = >::F; -// -// let config = CircuitConfig::standard_recursion_config(); -// let (proof, vd, cd) = dummy_proof::(&config, 4_000)?; -// -// let (proof, vd, cd) = -// recursive_proof::(proof, vd, cd, &config, None, false, false)?; -// test_serialization(&proof, &cd)?; -// -// let (proof, _vd, cd) = -// recursive_proof::(proof, vd, cd, &config, None, false, false)?; -// test_serialization(&proof, &cd)?; -// -// Ok(()) -// } -// -// /// Creates a dummy proof which should have roughly `num_dummy_gates` gates. -// fn dummy_proof, C: GenericConfig, const D: usize>( -// config: &CircuitConfig, -// num_dummy_gates: u64, -// ) -> Result<( -// ProofWithPublicInputs, -// VerifierOnlyCircuitData, -// CommonCircuitData, -// )> { -// let mut builder = CircuitBuilder::::new(config.clone()); -// for _ in 0..num_dummy_gates { -// builder.add_gate(NoopGate, vec![]); -// } -// -// let data = builder.build::(); -// let inputs = PartialWitness::new(); -// let proof = data.prove(inputs)?; -// data.verify(proof.clone())?; -// -// Ok((proof, data.verifier_only, data.common)) -// } -// -// fn recursive_proof< -// F: RichField + Extendable, -// C: GenericConfig, -// InnerC: GenericConfig, -// const D: usize, -// >( -// inner_proof: StarkProofWithPublicInputs, -// config: &StarkConfig, -// print_gate_counts: bool, -// print_timing: bool, -// ) -> Result<( -// ProofWithPublicInputs, -// VerifierOnlyCircuitData, -// CommonCircuitData, -// )> -// where -// InnerC::Hasher: AlgebraicHasher, -// { -// let mut builder = CircuitBuilder::::new(config.clone()); -// let mut pw = PartialWitness::new(); -// let degree_bits = inner_proof.proof.recover_degree_bits(config); -// let pt = add_virtual_stark_proof_with_pis(&mut builder, stark, config, degree_bits); -// pw.set_proof_with_pis_target(&pt, &inner_proof); -// -// let inner_data = VerifierCircuitTarget { -// constants_sigmas_cap: builder.add_virtual_cap(inner_cd.config.fri_config.cap_height), -// }; -// pw.set_cap_target( -// &inner_data.constants_sigmas_cap, -// &inner_vd.constants_sigmas_cap, -// ); -// -// builder.verify_proof(pt, &inner_data, &inner_cd); -// -// if print_gate_counts { -// builder.print_gate_counts(0); -// } -// -// if let Some(min_degree_bits) = min_degree_bits { -// // We don't want to pad all the way up to 2^min_degree_bits, as the builder will add a -// // few special gates afterward. So just pad to 2^(min_degree_bits - 1) + 1. Then the -// // builder will pad to the next power of two, 2^min_degree_bits. -// let min_gates = (1 << (min_degree_bits - 1)) + 1; -// for _ in builder.num_gates()..min_gates { -// builder.add_gate(NoopGate, vec![]); -// } -// } -// -// let data = builder.build::(); -// -// let mut timing = TimingTree::new("prove", Level::Debug); -// let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?; -// if print_timing { -// timing.print(); -// } -// -// data.verify(proof.clone())?; -// -// Ok((proof, data.verifier_only, data.common)) -// } -// -// /// Test serialization and print some size info. -// fn test_serialization< -// F: RichField + Extendable, -// C: GenericConfig, -// const D: usize, -// >( -// proof: &ProofWithPublicInputs, -// cd: &CommonCircuitData, -// ) -> Result<()> { -// let proof_bytes = proof.to_bytes()?; -// info!("Proof length: {} bytes", proof_bytes.len()); -// let proof_from_bytes = ProofWithPublicInputs::from_bytes(proof_bytes, cd)?; -// assert_eq!(proof, &proof_from_bytes); -// -// let now = std::time::Instant::now(); -// let compressed_proof = proof.clone().compress(cd)?; -// let decompressed_compressed_proof = compressed_proof.clone().decompress(cd)?; -// info!("{:.4}s to compress proof", now.elapsed().as_secs_f64()); -// assert_eq!(proof, &decompressed_compressed_proof); -// -// let compressed_proof_bytes = compressed_proof.to_bytes()?; -// info!( -// "Compressed proof length: {} bytes", -// compressed_proof_bytes.len() -// ); -// let compressed_proof_from_bytes = -// CompressedProofWithPublicInputs::from_bytes(compressed_proof_bytes, cd)?; -// assert_eq!(compressed_proof, compressed_proof_from_bytes); -// -// Ok(()) -// } -// -// fn init_logger() { -// let _ = env_logger::builder().format_timestamp(None).try_init(); -// } -// } +#[cfg(test)] +mod tests { + use anyhow::Result; + use log::Level; + use plonky2::field::extension_field::Extendable; + use plonky2::field::field_types::Field; + use plonky2::hash::hash_types::RichField; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::{CircuitConfig, CommonCircuitData}; + use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, PoseidonGoldilocksConfig}; + use plonky2::plonk::proof::ProofWithPublicInputs; + use plonky2::util::timing::TimingTree; + use plonky2_util::reverse_index_bits_in_place; + + use crate::config::StarkConfig; + use crate::fibonacci_stark::FibonacciStark; + use crate::proof::StarkProofWithPublicInputs; + use crate::prover::prove; + use crate::recursive_verifier::{ + add_virtual_stark_proof_with_pis, set_startk_proof_with_pis_target, verify_stark_proof, + }; + use crate::stark::Stark; + use crate::stark_testing::test_stark_low_degree; + use crate::verifier::verify; + + fn fibonacci(n: usize, x0: F, x1: F) -> F { + (0..n).fold((x0, x1), |x, _| (x.1, x.0 + x.1)).1 + } + + #[test] + fn test_fibonacci_stark() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = FibonacciStark; + + let config = StarkConfig::standard_fast_config(); + let num_rows = 1 << 5; + let public_inputs = [F::ZERO, F::ONE, fibonacci(num_rows - 1, F::ZERO, F::ONE)]; + let stark = S::new(num_rows); + let trace = stark.generate_trace(public_inputs[0], public_inputs[1]); + let proof = prove::( + stark, + &config, + trace, + public_inputs, + &mut TimingTree::default(), + )?; + + verify(stark, proof, &config) + } + + #[test] + fn test_fibonacci_stark_degree() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = FibonacciStark; + + let config = StarkConfig::standard_fast_config(); + let num_rows = 1 << 5; + let stark = S::new(num_rows); + test_stark_low_degree(stark) + } + + #[test] + fn test_recursive_stark_verifier() -> Result<()> { + init_logger(); + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = FibonacciStark; + + let config = StarkConfig::standard_fast_config(); + let num_rows = 1 << 5; + let public_inputs = [F::ZERO, F::ONE, fibonacci(num_rows - 1, F::ZERO, F::ONE)]; + let stark = S::new(num_rows); + let trace = stark.generate_trace(public_inputs[0], public_inputs[1]); + let proof = prove::( + stark, + &config, + trace, + public_inputs, + &mut TimingTree::default(), + )?; + + recursive_proof::(stark, proof, &config, true, true) + } + + fn recursive_proof< + F: RichField + Extendable, + C: GenericConfig, + S: Stark, + InnerC: GenericConfig, + const D: usize, + >( + stark: S, + inner_proof: StarkProofWithPublicInputs, + inner_config: &StarkConfig, + print_gate_counts: bool, + print_timing: bool, + ) -> Result<()> + where + InnerC::Hasher: AlgebraicHasher, + [(); S::COLUMNS]:, + [(); S::PUBLIC_INPUTS]:, + { + let circuit_config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(circuit_config); + let mut pw = PartialWitness::new(); + let degree_bits = inner_proof.proof.recover_degree_bits(inner_config); + let pt = add_virtual_stark_proof_with_pis(&mut builder, stark, inner_config, degree_bits); + set_startk_proof_with_pis_target(&mut pw, &pt, &inner_proof); + + verify_stark_proof::(&mut builder, stark, pt, inner_config); + + if print_gate_counts { + builder.print_gate_counts(0); + } + + let data = builder.build::(); + + let mut timing = TimingTree::new("prove", Level::Debug); + let proof = + plonky2::plonk::prover::prove(&data.prover_only, &data.common, pw, &mut timing)?; + if print_timing { + timing.print(); + } + + data.verify(proof.clone()) + } + + fn init_logger() { + let _ = env_logger::builder().format_timestamp(None).try_init(); + } +} diff --git a/starky/src/stark.rs b/starky/src/stark.rs index 3ef976e0..888dc004 100644 --- a/starky/src/stark.rs +++ b/starky/src/stark.rs @@ -14,7 +14,7 @@ use crate::vars::StarkEvaluationVars; /// Represents a STARK system. // TODO: Add a `constraint_degree` fn that returns the maximum constraint degree. -pub trait Stark, const D: usize>: Sync { +pub trait Stark, const D: usize>: Sync + Copy { /// The total number of columns in the trace. const COLUMNS: usize; /// The number of public inputs.