From fb18232efd20b9b924107655b5476f20354f739f Mon Sep 17 00:00:00 2001
From: wborgeaud
Date: Fri, 5 Nov 2021 10:56:23 +0100
Subject: [PATCH] Generic config

---
 Cargo.toml | 1 +
 src/bin/bench_recursion.rs | 12 +-
 src/field/extension_field/mod.rs | 6 +-
 src/field/field_types.rs | 2 +-
 src/fri/commitment.rs | 35 +--
 src/fri/proof.rs | 53 ++--
 src/fri/prover.rs | 70 +++--
 src/fri/recursive_verifier.rs | 33 ++-
 src/fri/verifier.rs | 43 +--
 src/gadgets/arithmetic_extension.rs | 33 ++-
 src/gadgets/hash.rs | 102 +------
 src/gadgets/insert.rs | 11 +-
 src/gadgets/interpolation.rs | 13 +-
 src/gadgets/permutation.rs | 22 +-
 src/gadgets/random_access.rs | 11 +-
 src/gadgets/select.rs | 11 +-
 src/gadgets/sorting.rs | 8 +-
 src/gadgets/split_base.rs | 17 +-
 src/gates/arithmetic.rs | 6 +-
 src/gates/arithmetic_u32.rs | 14 +-
 src/gates/base_sum.rs | 6 +-
 src/gates/comparison.rs | 14 +-
 src/gates/constant.rs | 6 +-
 src/gates/exponentiation.rs | 14 +-
 src/gates/gate_testing.rs | 10 +-
 src/gates/gate_tree.rs | 6 +-
 src/gates/gmimc.rs | 17 +-
 src/gates/insertion.rs | 14 +-
 src/gates/interpolation.rs | 13 +-
 src/gates/noop.rs | 6 +-
 src/gates/poseidon.rs | 277 ++++++++----------
 src/gates/poseidon_mds.rs | 119 ++++----
 src/gates/public_input.rs | 6 +-
 src/gates/random_access.rs | 13 +-
 src/gates/reducing.rs | 6 +-
 src/gates/switch.rs | 14 +-
 .../x86_64/poseidon_goldilocks_avx2_bmi2.rs | 4 +-
 src/hash/hash_types.rs | 37 ++-
 src/hash/hashing.rs | 84 ++++--
 src/hash/merkle_proofs.rs | 51 ++--
 src/hash/merkle_tree.rs | 46 +--
 src/hash/path_compression.rs | 29 +-
 src/hash/poseidon.rs | 52 ++--
 src/hash/poseidon_crandall.rs | 1 +
 src/hash/poseidon_goldilocks.rs | 149 +---------
 src/iop/challenger.rs | 134 +++++----
 src/iop/generator.rs | 14 +-
 src/iop/witness.rs | 11 +-
 src/plonk/circuit_builder.rs | 13 +-
 src/plonk/circuit_data.rs | 56 ++--
 src/plonk/config.rs | 267 +++++++++++++++++
 src/plonk/get_challenges.rs | 59 ++--
 src/plonk/mod.rs | 1 +
 src/plonk/proof.rs | 106 ++++---
 src/plonk/prover.rs | 68 +++--
 src/plonk/recursive_verifier.rs | 93 +++---
 src/plonk/vanishing_poly.rs | 21 +-
 src/plonk/verifier.rs | 23 +-
 src/util/reducing.rs | 19 +-
 src/util/serialization.rs | 192 ++++++------
 60 files changed, 1433 insertions(+), 1141 deletions(-)
 create mode 100644 src/hash/poseidon_crandall.rs
 create mode 100644 src/plonk/config.rs

diff --git a/Cargo.toml b/Cargo.toml
index 71970b41..a55bd899 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -25,6 +25,7 @@ unroll = "0.1.5"
 anyhow = "1.0.40"
 serde = { version = "1.0", features = ["derive"] }
 serde_cbor = "0.11.1"
+keccak-hash = "0.8.0"
 static_assertions = "1.1.0"
 
 [dev-dependencies]
diff --git a/src/bin/bench_recursion.rs b/src/bin/bench_recursion.rs
index e9fc25a4..ee602731 100644
--- a/src/bin/bench_recursion.rs
+++ b/src/bin/bench_recursion.rs
@@ -10,6 +10,7 @@ use plonky2::hash::hashing::SPONGE_WIDTH;
 use plonky2::iop::witness::PartialWitness;
 use plonky2::plonk::circuit_builder::CircuitBuilder;
 use plonky2::plonk::circuit_data::CircuitConfig;
+use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
 
 fn main() -> Result<()> {
     // Set the default log filter. This can be overridden using the `RUST_LOG` environment variable,
@@ -18,10 +19,13 @@ fn main() -> Result<()> {
     // change this to info or warn later.
     env_logger::Builder::from_env(Env::default().default_filter_or("debug")).init();
 
-    bench_prove::<GoldilocksField, 4>()
+    bench_prove()
 }
 
-fn bench_prove<F: RichField + Extendable<D>, const D: usize>() -> Result<()> {
+fn bench_prove() -> Result<()> {
+    const D: usize = 2;
+    type C = PoseidonGoldilocksConfig;
+    type F = <C as GenericConfig<D>>::F;
     let config = CircuitConfig {
         num_wires: 126,
         num_routed_wires: 33,
@@ -46,14 +50,14 @@ fn bench_prove<F: RichField + Extendable<D>, const D: usize>() -> Result<()> {
 
     let mut state = [zero; SPONGE_WIDTH];
     for _ in 0..10000 {
-        state = builder.permute(state);
+        state = builder.permute::<<C as GenericConfig<D>>::InnerHasher>(state);
     }
 
     // Random other gates.
     builder.add(zero, zero);
     builder.add_extension(zero_ext, zero_ext);
 
-    let circuit = builder.build();
+    let circuit = builder.build::<C>();
     let proof_with_pis = circuit.prove(inputs)?;
     let proof_bytes = serde_cbor::to_vec(&proof_with_pis).unwrap();
     info!("Proof length: {} bytes", proof_bytes.len());
diff --git a/src/field/extension_field/mod.rs b/src/field/extension_field/mod.rs
index 08443386..f8322e6d 100644
--- a/src/field/extension_field/mod.rs
+++ b/src/field/extension_field/mod.rs
@@ -1,6 +1,6 @@
 use std::convert::TryInto;
 
-use crate::field::field_types::{Field, PrimeField};
+use crate::field::field_types::{Field, RichField};
 
 pub mod algebra;
 pub mod quadratic;
@@ -61,7 +61,7 @@ pub trait Frobenius<const D: usize>: OEF<D> {
     }
 }
 
-pub trait Extendable<const D: usize>: PrimeField + Sized {
+pub trait Extendable<const D: usize>: RichField + Sized {
     type Extension: Field + OEF<D> + Frobenius<D> + From<Self>;
 
     const W: Self;
@@ -76,7 +76,7 @@ pub trait Extendable<const D: usize>: PrimeField + Sized {
     const EXT_POWER_OF_TWO_GENERATOR: [Self; D];
 }
 
-impl<F: PrimeField + Frobenius<1> + FieldExtension<1, BaseField = F>> Extendable<1> for F {
+impl<F: RichField + Frobenius<1> + FieldExtension<1, BaseField = F>> Extendable<1> for F {
     type Extension = F;
     const W: Self = F::ZERO;
     const DTH_ROOT: Self = F::ZERO;
diff --git a/src/field/field_types.rs b/src/field/field_types.rs
index 4fe10b17..28aa1e97 100644
--- a/src/field/field_types.rs
+++ b/src/field/field_types.rs
@@ -16,7 +16,7 @@ use crate::hash::poseidon::Poseidon;
 use crate::util::bits_u64;
 
 /// A prime order field with the features we need to use it as a base field in our argument system.
-pub trait RichField: PrimeField + GMiMC<12> + Poseidon<12> {}
+pub trait RichField: PrimeField + GMiMC<12> + Poseidon {}
 
 /// A finite field.
 pub trait Field:
diff --git a/src/fri/commitment.rs b/src/fri/commitment.rs
index 8233a293..1032d2b0 100644
--- a/src/fri/commitment.rs
+++ b/src/fri/commitment.rs
@@ -2,12 +2,13 @@ use rayon::prelude::*;
 
 use crate::field::extension_field::Extendable;
 use crate::field::fft::FftRootTable;
-use crate::field::field_types::{Field, RichField};
+use crate::field::field_types::Field;
 use crate::fri::proof::FriProof;
 use crate::fri::prover::fri_proof;
 use crate::hash::merkle_tree::MerkleTree;
 use crate::iop::challenger::Challenger;
 use crate::plonk::circuit_data::CommonCircuitData;
+use crate::plonk::config::GenericConfig;
 use crate::plonk::plonk_common::PlonkPolynomials;
 use crate::plonk::proof::OpeningSet;
 use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
@@ -20,15 +21,17 @@ use crate::util::{log2_strict, reverse_bits, reverse_index_bits_in_place, transp
 pub const SALT_SIZE: usize = 2;
 
 /// Represents a batch FRI based commitment to a list of polynomials.
-pub struct PolynomialBatchCommitment { +pub struct PolynomialBatchCommitment, C: GenericConfig, const D: usize> { pub polynomials: Vec>, - pub merkle_tree: MerkleTree, + pub merkle_tree: MerkleTree, pub degree_log: usize, pub rate_bits: usize, pub blinding: bool, } -impl PolynomialBatchCommitment { +impl, C: GenericConfig, const D: usize> + PolynomialBatchCommitment +{ /// Creates a list polynomial commitment for the polynomials interpolating the values in `values`. pub(crate) fn from_values( values: Vec>, @@ -122,16 +125,13 @@ impl PolynomialBatchCommitment { /// Takes the commitments to the constants - sigmas - wires - zs - quotient — polynomials, /// and an opening point `zeta` and produces a batched opening proof + opening set. - pub(crate) fn open_plonk( + pub(crate) fn open_plonk( commitments: &[&Self; 4], zeta: F::Extension, - challenger: &mut Challenger, - common_data: &CommonCircuitData, + challenger: &mut Challenger, + common_data: &CommonCircuitData, timing: &mut TimingTree, - ) -> (FriProof, OpeningSet) - where - F: RichField + Extendable, - { + ) -> (FriProof, OpeningSet) { let config = &common_data.config; assert!(D > 1, "Not implemented for D=1."); let degree_log = commitments[0].degree_log; @@ -159,7 +159,7 @@ impl PolynomialBatchCommitment { ); challenger.observe_opening_set(&os); - let alpha = challenger.get_extension_challenge(); + let alpha = challenger.get_extension_challenge::(); let mut alpha = ReducingFactor::new(alpha); // Final low-degree polynomial that goes into FRI. @@ -210,13 +210,13 @@ impl PolynomialBatchCommitment { let fri_proof = fri_proof( &commitments - .par_iter() + .iter() .map(|c| &c.merkle_tree) .collect::>(), lde_final_poly, lde_final_values, challenger, - &common_data, + common_data, timing, ); @@ -225,13 +225,10 @@ impl PolynomialBatchCommitment { /// Given `points=(x_i)`, `evals=(y_i)` and `poly=P` with `P(x_i)=y_i`, computes the polynomial /// `Q=(P-I)/Z` where `I` interpolates `(x_i, y_i)` and `Z` is the vanishing polynomial on `(x_i)`. - fn compute_quotient( + fn compute_quotient( points: [F::Extension; N], poly: PolynomialCoeffs, - ) -> PolynomialCoeffs - where - F: Extendable, - { + ) -> PolynomialCoeffs { let quotient = if N == 1 { poly.divide_by_linear(points[0]).0 } else if N == 2 { diff --git a/src/fri/proof.rs b/src/fri/proof.rs index f6875fcc..c3e69971 100644 --- a/src/fri/proof.rs +++ b/src/fri/proof.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{flatten, unflatten, Extendable}; -use crate::field::field_types::{Field, RichField}; +use crate::field::field_types::RichField; use crate::gadgets::polynomial::PolynomialCoeffsExtTarget; use crate::hash::hash_types::MerkleCapTarget; use crate::hash::merkle_proofs::{MerkleProof, MerkleProofTarget}; @@ -13,6 +13,7 @@ use crate::hash::merkle_tree::MerkleCap; use crate::hash::path_compression::{compress_merkle_proofs, decompress_merkle_proofs}; use crate::iop::target::Target; use crate::plonk::circuit_data::CommonCircuitData; +use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::plonk_common::PolynomialsIndexBlinding; use crate::plonk::proof::{FriInferredElements, ProofChallenges}; use crate::polynomial::polynomial::PolynomialCoeffs; @@ -20,9 +21,9 @@ use crate::polynomial::polynomial::PolynomialCoeffs; /// Evaluations and Merkle proof produced by the prover in a FRI query step. 
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] #[serde(bound = "")] -pub struct FriQueryStep, const D: usize> { +pub struct FriQueryStep, H: Hasher, const D: usize> { pub evals: Vec, - pub merkle_proof: MerkleProof, + pub merkle_proof: MerkleProof, } #[derive(Clone)] @@ -35,11 +36,11 @@ pub struct FriQueryStepTarget { /// before they are combined into a composition polynomial. #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] #[serde(bound = "")] -pub struct FriInitialTreeProof { - pub evals_proofs: Vec<(Vec, MerkleProof)>, +pub struct FriInitialTreeProof> { + pub evals_proofs: Vec<(Vec, MerkleProof)>, } -impl FriInitialTreeProof { +impl> FriInitialTreeProof { pub(crate) fn unsalted_evals( &self, polynomials: PolynomialsIndexBlinding, @@ -69,9 +70,9 @@ impl FriInitialTreeProofTarget { /// Proof for a FRI query round. #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] #[serde(bound = "")] -pub struct FriQueryRound, const D: usize> { - pub initial_trees_proof: FriInitialTreeProof, - pub steps: Vec>, +pub struct FriQueryRound, H: Hasher, const D: usize> { + pub initial_trees_proof: FriInitialTreeProof, + pub steps: Vec>, } #[derive(Clone)] @@ -83,22 +84,22 @@ pub struct FriQueryRoundTarget { /// Compressed proof of the FRI query rounds. #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] #[serde(bound = "")] -pub struct CompressedFriQueryRounds, const D: usize> { +pub struct CompressedFriQueryRounds, H: Hasher, const D: usize> { /// Query indices. pub indices: Vec, /// Map from initial indices `i` to the `FriInitialProof` for the `i`th leaf. - pub initial_trees_proofs: HashMap>, + pub initial_trees_proofs: HashMap>, /// For each FRI query step, a map from indices `i` to the `FriQueryStep` for the `i`th leaf. - pub steps: Vec>>, + pub steps: Vec>>, } #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] #[serde(bound = "")] -pub struct FriProof, const D: usize> { +pub struct FriProof, H: Hasher, const D: usize> { /// A Merkle cap for each reduced polynomial in the commit phase. - pub commit_phase_merkle_caps: Vec>, + pub commit_phase_merkle_caps: Vec>, /// Query rounds proofs - pub query_round_proofs: Vec>, + pub query_round_proofs: Vec>, /// The final polynomial in coefficient form. pub final_poly: PolynomialCoeffs, /// Witness showing that the prover did PoW. @@ -114,24 +115,24 @@ pub struct FriProofTarget { #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] #[serde(bound = "")] -pub struct CompressedFriProof, const D: usize> { +pub struct CompressedFriProof, H: Hasher, const D: usize> { /// A Merkle cap for each reduced polynomial in the commit phase. - pub commit_phase_merkle_caps: Vec>, + pub commit_phase_merkle_caps: Vec>, /// Compressed query rounds proof. - pub query_round_proofs: CompressedFriQueryRounds, + pub query_round_proofs: CompressedFriQueryRounds, /// The final polynomial in coefficient form. pub final_poly: PolynomialCoeffs, /// Witness showing that the prover did PoW. pub pow_witness: F, } -impl, const D: usize> FriProof { +impl, H: Hasher, const D: usize> FriProof { /// Compress all the Merkle paths in the FRI proof and remove duplicate indices. 
- pub fn compress( + pub fn compress>( self, indices: &[usize], - common_data: &CommonCircuitData, - ) -> CompressedFriProof { + common_data: &CommonCircuitData, + ) -> CompressedFriProof { let FriProof { commit_phase_merkle_caps, query_round_proofs, @@ -231,14 +232,14 @@ impl, const D: usize> FriProof { } } -impl, const D: usize> CompressedFriProof { +impl, H: Hasher, const D: usize> CompressedFriProof { /// Decompress all the Merkle paths in the FRI proof and reinsert duplicate indices. - pub(crate) fn decompress( + pub(crate) fn decompress>( self, challenges: &ProofChallenges, fri_inferred_elements: FriInferredElements, - common_data: &CommonCircuitData, - ) -> FriProof { + common_data: &CommonCircuitData, + ) -> FriProof { let CompressedFriProof { commit_phase_merkle_caps, query_round_proofs, diff --git a/src/fri/prover.rs b/src/fri/prover.rs index ec58fc14..56040682 100644 --- a/src/fri/prover.rs +++ b/src/fri/prover.rs @@ -1,14 +1,13 @@ use rayon::prelude::*; use crate::field::extension_field::{flatten, unflatten, Extendable}; -use crate::field::field_types::RichField; use crate::fri::proof::{FriInitialTreeProof, FriProof, FriQueryRound, FriQueryStep}; use crate::fri::FriConfig; use crate::hash::hash_types::HashOut; -use crate::hash::hashing::hash_n_to_1; use crate::hash::merkle_tree::MerkleTree; use crate::iop::challenger::Challenger; use crate::plonk::circuit_data::CommonCircuitData; +use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::plonk_common::reduce_with_powers; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::timed; @@ -16,16 +15,16 @@ use crate::util::reverse_index_bits_in_place; use crate::util::timing::TimingTree; /// Builds a FRI proof. -pub fn fri_proof, const D: usize>( - initial_merkle_trees: &[&MerkleTree], +pub fn fri_proof, C: GenericConfig, const D: usize>( + initial_merkle_trees: &[&MerkleTree], // Coefficients of the polynomial on which the LDT is performed. Only the first `1/rate` coefficients are non-zero. lde_polynomial_coeffs: PolynomialCoeffs, // Evaluation of the polynomial on the large domain. 
lde_polynomial_values: PolynomialValues, - challenger: &mut Challenger, - common_data: &CommonCircuitData, + challenger: &mut Challenger, + common_data: &CommonCircuitData, timing: &mut TimingTree, -) -> FriProof { +) -> FriProof { let n = lde_polynomial_values.values.len(); assert_eq!(lde_polynomial_coeffs.coeffs.len(), n); @@ -42,11 +41,11 @@ pub fn fri_proof, const D: usize>( ); // PoW phase - let current_hash = challenger.get_hash(); + let current_hash = challenger.get_hash::(); let pow_witness = timed!( timing, "find proof-of-work witness", - fri_proof_of_work(current_hash, &common_data.config.fri_config) + fri_proof_of_work::(current_hash, &common_data.config.fri_config) ); // Query phase @@ -61,12 +60,15 @@ pub fn fri_proof, const D: usize>( } } -fn fri_committed_trees, const D: usize>( +fn fri_committed_trees, C: GenericConfig, const D: usize>( mut coeffs: PolynomialCoeffs, mut values: PolynomialValues, - challenger: &mut Challenger, - common_data: &CommonCircuitData, -) -> (Vec>, PolynomialCoeffs) { + challenger: &mut Challenger, + common_data: &CommonCircuitData, +) -> ( + Vec>, + PolynomialCoeffs, +) { let config = &common_data.config; let mut trees = Vec::new(); @@ -81,12 +83,12 @@ fn fri_committed_trees, const D: usize>( .par_chunks(arity) .map(|chunk: &[F::Extension]| flatten(chunk)) .collect(); - let tree = MerkleTree::new(chunked_values, config.cap_height); + let tree = MerkleTree::::new(chunked_values, config.cap_height); challenger.observe_cap(&tree.cap); trees.push(tree); - let beta = challenger.get_extension_challenge(); + let beta = challenger.get_extension_challenge::(); // P(x) = sum_{i, const D: usize>( (trees, coeffs) } -fn fri_proof_of_work(current_hash: HashOut, config: &FriConfig) -> F { +fn fri_proof_of_work, C: GenericConfig, const D: usize>( + current_hash: HashOut, + config: &FriConfig, +) -> F { (0..=F::NEG_ONE.to_canonical_u64()) .into_par_iter() .find_any(|&i| { - hash_n_to_1( + C::InnerHasher::hash( current_hash .elements .iter() @@ -119,35 +124,36 @@ fn fri_proof_of_work(current_hash: HashOut, config: &FriConfig) .collect(), false, ) - .to_canonical_u64() - .leading_zeros() + .elements[0] + .to_canonical_u64() + .leading_zeros() >= config.proof_of_work_bits + (64 - F::order().bits()) as u32 }) .map(F::from_canonical_u64) .expect("Proof of work failed. 
This is highly unlikely!") } -fn fri_prover_query_rounds, const D: usize>( - initial_merkle_trees: &[&MerkleTree], - trees: &[MerkleTree], - challenger: &mut Challenger, +fn fri_prover_query_rounds, C: GenericConfig, const D: usize>( + initial_merkle_trees: &[&MerkleTree], + trees: &[MerkleTree], + challenger: &mut Challenger, n: usize, - common_data: &CommonCircuitData, -) -> Vec> { + common_data: &CommonCircuitData, +) -> Vec> { (0..common_data.config.fri_config.num_query_rounds) .map(|_| fri_prover_query_round(initial_merkle_trees, trees, challenger, n, common_data)) .collect() } -fn fri_prover_query_round, const D: usize>( - initial_merkle_trees: &[&MerkleTree], - trees: &[MerkleTree], - challenger: &mut Challenger, +fn fri_prover_query_round, C: GenericConfig, const D: usize>( + initial_merkle_trees: &[&MerkleTree], + trees: &[MerkleTree], + challenger: &mut Challenger, n: usize, - common_data: &CommonCircuitData, -) -> FriQueryRound { + common_data: &CommonCircuitData, +) -> FriQueryRound { let mut query_steps = Vec::new(); - let x = challenger.get_challenge(); + let x = challenger.get_challenge::(); let mut x_index = x.to_canonical_u64() as usize % n; let initial_proof = initial_merkle_trees .iter() diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index 443bd7c9..14abe937 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -11,6 +11,7 @@ use crate::iop::challenger::RecursiveChallenger; use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CommonCircuitData; +use crate::plonk::config::{AlgebraicConfig, AlgebraicHasher, GenericConfig}; use crate::plonk::plonk_common::PlonkPolynomials; use crate::plonk::proof::OpeningSetTarget; use crate::util::reducing::ReducingFactorTarget; @@ -88,23 +89,23 @@ impl, const D: usize> CircuitBuilder { ); } - fn fri_verify_proof_of_work( + fn fri_verify_proof_of_work>( &mut self, proof: &FriProofTarget, - challenger: &mut RecursiveChallenger, + challenger: &mut RecursiveChallenger, config: &FriConfig, ) { let mut inputs = challenger.get_hash(self).elements.to_vec(); inputs.push(proof.pow_witness); - let hash = self.hash_n_to_m(inputs, 1, false)[0]; + let hash = self.hash_n_to_m::(inputs, 1, false)[0]; self.assert_leading_zeros( hash, config.proof_of_work_bits + (64 - F::order().bits()) as u32, ); } - pub fn verify_fri_proof( + pub fn verify_fri_proof>( &mut self, // Openings of the PLONK polynomials. os: &OpeningSetTarget, @@ -112,8 +113,8 @@ impl, const D: usize> CircuitBuilder { zeta: ExtensionTarget, initial_merkle_caps: &[MerkleCapTarget], proof: &FriProofTarget, - challenger: &mut RecursiveChallenger, - common_data: &CommonCircuitData, + challenger: &mut RecursiveChallenger, + common_data: &CommonCircuitData, ) { let config = &common_data.config; @@ -152,7 +153,7 @@ impl, const D: usize> CircuitBuilder { with_context!( self, "check PoW", - self.fri_verify_proof_of_work(proof, challenger, &config.fri_config) + self.fri_verify_proof_of_work::(proof, challenger, &config.fri_config) ); // Check that parameters are coherent. 
@@ -205,7 +206,7 @@ impl, const D: usize> CircuitBuilder { } } - fn fri_verify_initial_proof( + fn fri_verify_initial_proof>( &mut self, x_index_bits: &[BoolTarget], proof: &FriInitialTreeProofTarget, @@ -221,7 +222,7 @@ impl, const D: usize> CircuitBuilder { with_context!( self, &format!("verify {}'th initial Merkle proof", i), - self.verify_merkle_proof_with_cap_index( + self.verify_merkle_proof_with_cap_index::( evals.clone(), x_index_bits, cap_index, @@ -232,14 +233,14 @@ impl, const D: usize> CircuitBuilder { } } - fn fri_combine_initial( + fn fri_combine_initial>( &mut self, proof: &FriInitialTreeProofTarget, alpha: ExtensionTarget, subgroup_x: Target, vanish_zeta: ExtensionTarget, precomputed_reduced_evals: PrecomputedReducedEvalsTarget, - common_data: &CommonCircuitData, + common_data: &CommonCircuitData, ) -> ExtensionTarget { assert!(D > 1, "Not implemented for D=1."); let config = self.config.clone(); @@ -301,18 +302,18 @@ impl, const D: usize> CircuitBuilder { sum } - fn fri_verifier_query_round( + fn fri_verifier_query_round>( &mut self, zeta: ExtensionTarget, alpha: ExtensionTarget, precomputed_reduced_evals: PrecomputedReducedEvalsTarget, initial_merkle_caps: &[MerkleCapTarget], proof: &FriProofTarget, - challenger: &mut RecursiveChallenger, + challenger: &mut RecursiveChallenger, n: usize, betas: &[ExtensionTarget], round_proof: &FriQueryRoundTarget, - common_data: &CommonCircuitData, + common_data: &CommonCircuitData, ) { let n_log = log2_strict(n); // TODO: Do we need to range check `x_index` to a target smaller than `p`? @@ -323,7 +324,7 @@ impl, const D: usize> CircuitBuilder { with_context!( self, "check FRI initial proof", - self.fri_verify_initial_proof( + self.fri_verify_initial_proof::( &x_index_bits, &round_proof.initial_trees_proof, initial_merkle_caps, @@ -392,7 +393,7 @@ impl, const D: usize> CircuitBuilder { with_context!( self, "verify FRI round Merkle proof.", - self.verify_merkle_proof_with_cap_index( + self.verify_merkle_proof_with_cap_index::( flatten_target(evals), &coset_index_bits, cap_index, diff --git a/src/fri/verifier.rs b/src/fri/verifier.rs index add03a9d..1b067a42 100644 --- a/src/fri/verifier.rs +++ b/src/fri/verifier.rs @@ -8,6 +8,7 @@ use crate::fri::FriConfig; use crate::hash::merkle_proofs::verify_merkle_proof; use crate::hash::merkle_tree::MerkleCap; use crate::plonk::circuit_data::CommonCircuitData; +use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::plonk_common::PlonkPolynomials; use crate::plonk::proof::{OpeningSet, ProofChallenges}; use crate::util::reducing::ReducingFactor; @@ -55,13 +56,17 @@ pub(crate) fn fri_verify_proof_of_work, const D: us Ok(()) } -pub(crate) fn verify_fri_proof, const D: usize>( +pub(crate) fn verify_fri_proof< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( // Openings of the PLONK polynomials. 
os: &OpeningSet, challenges: &ProofChallenges, - initial_merkle_caps: &[MerkleCap], - proof: &FriProof, - common_data: &CommonCircuitData, + initial_merkle_caps: &[MerkleCap], + proof: &FriProof, + common_data: &CommonCircuitData, ) -> Result<()> { let config = &common_data.config; ensure!( @@ -88,7 +93,7 @@ pub(crate) fn verify_fri_proof, const D: usize>( .iter() .zip(&proof.query_round_proofs) { - fri_verifier_query_round( + fri_verifier_query_round::( challenges, precomputed_reduced_evals, initial_merkle_caps, @@ -103,13 +108,13 @@ pub(crate) fn verify_fri_proof, const D: usize>( Ok(()) } -fn fri_verify_initial_proof( +fn fri_verify_initial_proof>( x_index: usize, - proof: &FriInitialTreeProof, - initial_merkle_caps: &[MerkleCap], + proof: &FriInitialTreeProof, + initial_merkle_caps: &[MerkleCap], ) -> Result<()> { for ((evals, merkle_proof), cap) in proof.evals_proofs.iter().zip(initial_merkle_caps) { - verify_merkle_proof(evals.clone(), x_index, cap, merkle_proof)?; + verify_merkle_proof::(evals.clone(), x_index, cap, merkle_proof)?; } Ok(()) @@ -146,13 +151,13 @@ impl, const D: usize> PrecomputedReducedEvals { } } -pub(crate) fn fri_combine_initial, const D: usize>( - proof: &FriInitialTreeProof, +pub(crate) fn fri_combine_initial, C: GenericConfig, const D: usize>( + proof: &FriInitialTreeProof, alpha: F::Extension, zeta: F::Extension, subgroup_x: F, precomputed_reduced_evals: PrecomputedReducedEvals, - common_data: &CommonCircuitData, + common_data: &CommonCircuitData, ) -> F::Extension { let config = &common_data.config; assert!(D > 1, "Not implemented for D=1."); @@ -207,17 +212,17 @@ pub(crate) fn fri_combine_initial, const D: usize>( sum } -fn fri_verifier_query_round, const D: usize>( +fn fri_verifier_query_round, C: GenericConfig, const D: usize>( challenges: &ProofChallenges, precomputed_reduced_evals: PrecomputedReducedEvals, - initial_merkle_caps: &[MerkleCap], - proof: &FriProof, + initial_merkle_caps: &[MerkleCap], + proof: &FriProof, mut x_index: usize, n: usize, - round_proof: &FriQueryRound, - common_data: &CommonCircuitData, + round_proof: &FriQueryRound, + common_data: &CommonCircuitData, ) -> Result<()> { - fri_verify_initial_proof( + fri_verify_initial_proof::( x_index, &round_proof.initial_trees_proof, initial_merkle_caps, @@ -263,7 +268,7 @@ fn fri_verifier_query_round, const D: usize>( challenges.fri_betas[i], ); - verify_merkle_proof( + verify_merkle_proof::( flatten(evals), coset_index, &proof.commit_phase_merkle_caps[i], diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 24499760..8231f204 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -563,19 +563,20 @@ mod tests { use anyhow::Result; use crate::field::extension_field::algebra::ExtensionAlgebra; - use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, KeccakGoldilocksConfig, PoseidonGoldilocksConfig}; use crate::plonk::verifier::verify; #[test] fn test_mul_many() -> Result<()> { - type F = GoldilocksField; - type FF = QuarticExtension; - const D: usize = 4; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; let config = CircuitConfig::standard_recursion_config(); @@ -600,7 +601,7 @@ mod 
tests { builder.connect_extension(mul0, mul1); builder.connect_extension(mul1, mul2); - let data = builder.build(); + let data = builder.build::(); let proof = data.prove(pw)?; verify(proof, &data.verifier_only, &data.common) @@ -608,9 +609,10 @@ mod tests { #[test] fn test_div_extension() -> Result<()> { - type F = GoldilocksField; - type FF = QuarticExtension; - const D: usize = 4; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; let config = CircuitConfig::standard_recursion_zk_config(); @@ -628,7 +630,7 @@ mod tests { builder.connect_extension(zt, comp_zt); builder.connect_extension(zt, comp_zt_unsafe); - let data = builder.build(); + let data = builder.build::(); let proof = data.prove(pw)?; verify(proof, &data.verifier_only, &data.common) @@ -636,17 +638,18 @@ mod tests { #[test] fn test_mul_algebra() -> Result<()> { - type F = GoldilocksField; - type FF = QuarticExtension; - const D: usize = 4; + const D: usize = 2; + type C = KeccakGoldilocksConfig; + type F = >::F; + type FF = >::FE; let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let x = FF::rand_vec(4); - let y = FF::rand_vec(4); + let x = FF::rand_vec(D); + let y = FF::rand_vec(D); let xa = ExtensionAlgebra(x.try_into().unwrap()); let ya = ExtensionAlgebra(y.try_into().unwrap()); let za = xa * ya; @@ -659,7 +662,7 @@ mod tests { builder.connect_extension(zt.0[i], comp_zt.0[i]); } - let data = builder.build(); + let data = builder.build::(); let proof = data.prove(pw)?; verify(proof, &data.verifier_only, &data.common) diff --git a/src/gadgets/hash.rs b/src/gadgets/hash.rs index 99da9e1e..931b0f3a 100644 --- a/src/gadgets/hash.rs +++ b/src/gadgets/hash.rs @@ -5,110 +5,30 @@ use crate::field::field_types::RichField; use crate::gates::gmimc::GMiMCGate; use crate::gates::poseidon::PoseidonGate; use crate::hash::gmimc::GMiMC; -use crate::hash::hashing::{HashFamily, HASH_FAMILY}; +use crate::hash::hashing::SPONGE_WIDTH; use crate::hash::poseidon::Poseidon; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::config::AlgebraicHasher; impl, const D: usize> CircuitBuilder { - pub fn permute(&mut self, inputs: [Target; W]) -> [Target; W] - where - F: GMiMC + Poseidon, - [(); W - 1]: , - { + pub fn permute>( + &mut self, + inputs: [Target; SPONGE_WIDTH], + ) -> [Target; SPONGE_WIDTH] { // We don't want to swap any inputs, so set that wire to 0. let _false = self._false(); - self.permute_swapped(inputs, _false) + self.permute_swapped::(inputs, _false) } /// Conditionally swap two chunks of the inputs (useful in verifying Merkle proofs), then apply /// a cryptographic permutation. - pub(crate) fn permute_swapped( + pub(crate) fn permute_swapped>( &mut self, - inputs: [Target; W], + inputs: [Target; SPONGE_WIDTH], swap: BoolTarget, - ) -> [Target; W] - where - F: GMiMC + Poseidon, - [(); W - 1]: , - { - match HASH_FAMILY { - HashFamily::GMiMC => self.gmimc_permute_swapped(inputs, swap), - HashFamily::Poseidon => self.poseidon_permute_swapped(inputs, swap), - } - } - - /// Conditionally swap two chunks of the inputs (useful in verifying Merkle proofs), then apply - /// the GMiMC permutation. 
- pub(crate) fn gmimc_permute_swapped( - &mut self, - inputs: [Target; W], - swap: BoolTarget, - ) -> [Target; W] - where - F: GMiMC, - { - let gate_type = GMiMCGate::::new(); - let gate = self.add_gate(gate_type, vec![]); - - let swap_wire = GMiMCGate::::WIRE_SWAP; - let swap_wire = Target::wire(gate, swap_wire); - self.connect(swap.target, swap_wire); - - // Route input wires. - for i in 0..W { - let in_wire = GMiMCGate::::wire_input(i); - let in_wire = Target::Wire(Wire { - gate, - input: in_wire, - }); - self.connect(inputs[i], in_wire); - } - - // Collect output wires. - (0..W) - .map(|i| { - Target::Wire(Wire { - gate, - input: GMiMCGate::::wire_output(i), - }) - }) - .collect::>() - .try_into() - .unwrap() - } - - /// Conditionally swap two chunks of the inputs (useful in verifying Merkle proofs), then apply - /// the Poseidon permutation. - pub(crate) fn poseidon_permute_swapped( - &mut self, - inputs: [Target; W], - swap: BoolTarget, - ) -> [Target; W] - where - F: Poseidon, - [(); W - 1]: , - { - let gate_type = PoseidonGate::::new(); - let gate = self.add_gate(gate_type, vec![]); - - let swap_wire = PoseidonGate::::WIRE_SWAP; - let swap_wire = Target::wire(gate, swap_wire); - self.connect(swap.target, swap_wire); - - // Route input wires. - for i in 0..W { - let in_wire = PoseidonGate::::wire_input(i); - let in_wire = Target::wire(gate, in_wire); - self.connect(inputs[i], in_wire); - } - - // Collect output wires. - (0..W) - .map(|i| Target::wire(gate, PoseidonGate::::wire_output(i))) - .collect::>() - .try_into() - .unwrap() + ) -> [Target; SPONGE_WIDTH] { + H::permute_swapped(inputs, swap, self) } } diff --git a/src/gadgets/insert.rs b/src/gadgets/insert.rs index c6b463ea..77ecca3c 100644 --- a/src/gadgets/insert.rs +++ b/src/gadgets/insert.rs @@ -48,6 +48,7 @@ mod tests { use crate::field::goldilocks_field::GoldilocksField; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::verifier::verify; fn real_insert( @@ -61,12 +62,14 @@ mod tests { } fn test_insert_given_len(len_log: usize) -> Result<()> { - type F = GoldilocksField; - type FF = QuadraticExtension; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; let len = 1 << len_log; let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); + let mut builder = CircuitBuilder::::new(config); let v = (0..len - 1) .map(|_| builder.constant_extension(FF::rand())) .collect::>(); @@ -84,7 +87,7 @@ mod tests { } } - let data = builder.build(); + let data = builder.build::(); let proof = data.prove(pw)?; verify(proof, &data.verifier_only, &data.common) diff --git a/src/gadgets/interpolation.rs b/src/gadgets/interpolation.rs index f72ed0fc..2e364da0 100644 --- a/src/gadgets/interpolation.rs +++ b/src/gadgets/interpolation.rs @@ -43,15 +43,18 @@ mod tests { use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::verifier::verify; #[test] fn test_interpolate() -> Result<()> { - type F = GoldilocksField; - type FF = QuarticExtension; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); - let mut 
builder = CircuitBuilder::::new(config); + let mut builder = CircuitBuilder::::new(config); let len = 4; let points = (0..len) @@ -60,7 +63,7 @@ mod tests { let homogeneous_points = points .iter() - .map(|&(a, b)| (>::from_basefield(a), b)) + .map(|&(a, b)| (>::from_basefield(a), b)) .collect::>(); let true_interpolant = interpolant(&homogeneous_points); @@ -79,7 +82,7 @@ mod tests { let true_eval_target = builder.constant_extension(true_eval); builder.connect_extension(eval, true_eval_target); - let data = builder.build(); + let data = builder.build::(); let proof = data.prove(pw)?; verify(proof, &data.verifier_only, &data.common) diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs index fd4a897f..52069bfb 100644 --- a/src/gadgets/permutation.rs +++ b/src/gadgets/permutation.rs @@ -391,11 +391,13 @@ mod tests { use crate::field::goldilocks_field::GoldilocksField; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::verifier::verify; fn test_permutation_good(size: usize) -> Result<()> { - type F = GoldilocksField; - const D: usize = 4; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; let config = CircuitConfig::standard_recursion_config(); @@ -412,15 +414,16 @@ mod tests { builder.assert_permutation(a, b); - let data = builder.build(); + let data = builder.build::(); let proof = data.prove(pw)?; verify(proof, &data.verifier_only, &data.common) } fn test_permutation_duplicates(size: usize) -> Result<()> { - type F = GoldilocksField; - const D: usize = 4; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; let config = CircuitConfig::standard_recursion_config(); @@ -441,15 +444,16 @@ mod tests { builder.assert_permutation(a, b); - let data = builder.build(); + let data = builder.build::(); let proof = data.prove(pw)?; verify(proof, &data.verifier_only, &data.common) } fn test_permutation_bad(size: usize) -> Result<()> { - type F = GoldilocksField; - const D: usize = 4; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; let config = CircuitConfig::standard_recursion_config(); @@ -469,7 +473,7 @@ mod tests { builder.assert_permutation(a, b); - let data = builder.build(); + let data = builder.build::(); data.prove(pw)?; Ok(()) diff --git a/src/gadgets/random_access.rs b/src/gadgets/random_access.rs index 398c516f..c7c7b8bf 100644 --- a/src/gadgets/random_access.rs +++ b/src/gadgets/random_access.rs @@ -92,15 +92,18 @@ mod tests { use crate::field::goldilocks_field::GoldilocksField; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::verifier::verify; fn test_random_access_given_len(len_log: usize) -> Result<()> { - type F = GoldilocksField; - type FF = QuarticExtension; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; let len = 1 << len_log; let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); + let mut builder = CircuitBuilder::::new(config); let vec = FF::rand_vec(len); let v: Vec<_> = vec.iter().map(|x| builder.constant_extension(*x)).collect(); @@ -110,7 +113,7 @@ mod tests { builder.random_access_extension(it, elem, v.clone()); } - let data = builder.build(); + let data = builder.build::(); let proof = data.prove(pw)?; 
verify(proof, &data.verifier_only, &data.common) diff --git a/src/gadgets/select.rs b/src/gadgets/select.rs index 1db41fb8..4fa98347 100644 --- a/src/gadgets/select.rs +++ b/src/gadgets/select.rs @@ -47,15 +47,18 @@ mod tests { use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::verifier::verify; #[test] fn test_select() -> Result<()> { - type F = GoldilocksField; - type FF = QuarticExtension; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; let config = CircuitConfig::standard_recursion_config(); let mut pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); + let mut builder = CircuitBuilder::::new(config); let (x, y) = (FF::rand(), FF::rand()); let xt = builder.add_virtual_extension_target(); @@ -72,7 +75,7 @@ mod tests { builder.connect_extension(should_be_x, xt); builder.connect_extension(should_be_y, yt); - let data = builder.build(); + let data = builder.build::(); let proof = data.prove(pw)?; verify(proof, &data.verifier_only, &data.common) diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 72dcf273..e283d983 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -180,11 +180,13 @@ mod tests { use crate::field::goldilocks_field::GoldilocksField; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::verifier::verify; fn test_sorting(size: usize, address_bits: usize, timestamp_bits: usize) -> Result<()> { - type F = GoldilocksField; - const D: usize = 4; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; let config = CircuitConfig::standard_recursion_config(); @@ -237,7 +239,7 @@ mod tests { pw.set_target(output_ops[i].value, input_ops_sorted[i].3); } - let data = builder.build(); + let data = builder.build::(); let proof = data.prove(pw).unwrap(); verify(proof, &data.verifier_only, &data.common) diff --git a/src/gadgets/split_base.rs b/src/gadgets/split_base.rs index 30bdea6a..524ed46f 100644 --- a/src/gadgets/split_base.rs +++ b/src/gadgets/split_base.rs @@ -107,14 +107,17 @@ mod tests { use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::verifier::verify; #[test] fn test_split_base() -> Result<()> { - type F = GoldilocksField; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); + let mut builder = CircuitBuilder::::new(config); let x = F::from_canonical_usize(0b110100000); // 416 = 1532 in base 6. 
let xt = builder.constant(x); let limbs = builder.split_le_base::<6>(xt, 24); @@ -128,7 +131,7 @@ mod tests { builder.connect(limbs[3], one); builder.assert_leading_zeros(xt, 64 - 9); - let data = builder.build(); + let data = builder.build::(); let proof = data.prove(pw)?; @@ -137,10 +140,12 @@ mod tests { #[test] fn test_base_sum() -> Result<()> { - type F = GoldilocksField; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); + let mut builder = CircuitBuilder::::new(config); let n = thread_rng().gen_range(0..(1 << 10)); let x = builder.constant(F::from_canonical_usize(n)); @@ -161,7 +166,7 @@ mod tests { builder.connect(x, y); - let data = builder.build(); + let data = builder.build::(); let proof = data.prove(pw)?; diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 95b48e2f..f04ef7f7 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -209,6 +209,7 @@ mod tests { use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; #[test] fn low_degree() { @@ -219,8 +220,11 @@ mod tests { #[test] fn eval_fns() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; let gate = ArithmeticExtensionGate::new_from_config(&CircuitConfig::standard_recursion_config()); - test_eval_fns::(gate) + test_eval_fns::(gate) } } diff --git a/src/gates/arithmetic_u32.rs b/src/gates/arithmetic_u32.rs index 6564a876..0aef9feb 100644 --- a/src/gates/arithmetic_u32.rs +++ b/src/gates/arithmetic_u32.rs @@ -330,13 +330,13 @@ mod tests { use anyhow::Result; use rand::Rng; - use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::hash::hash_types::HashOut; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::vars::EvaluationVars; #[test] @@ -348,16 +348,20 @@ mod tests { #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(U32ArithmeticGate:: { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(U32ArithmeticGate:: { _phantom: PhantomData, }) } #[test] fn test_gate_constraint() { - type F = GoldilocksField; - type FF = QuarticExtension; - const D: usize = 4; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; fn get_wires( multiplicands_0: Vec, diff --git a/src/gates/base_sum.rs b/src/gates/base_sum.rs index 99ee05eb..98f60bfa 100644 --- a/src/gates/base_sum.rs +++ b/src/gates/base_sum.rs @@ -179,6 +179,7 @@ mod tests { use crate::field::goldilocks_field::GoldilocksField; use crate::gates::base_sum::BaseSumGate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; #[test] fn low_degree() { @@ -187,6 +188,9 @@ mod tests { #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(BaseSumGate::<6>::new(11)) + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(BaseSumGate::<6>::new(11)) } } diff 
--git a/src/gates/comparison.rs b/src/gates/comparison.rs index 988086d0..a47145e8 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -430,13 +430,13 @@ mod tests { use anyhow::Result; use rand::Rng; - use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::{Field, PrimeField}; use crate::field::goldilocks_field::GoldilocksField; use crate::gates::comparison::ComparisonGate; use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::hash::hash_types::HashOut; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::vars::EvaluationVars; #[test] @@ -478,15 +478,19 @@ mod tests { fn eval_fns() -> Result<()> { let num_bits = 40; let num_chunks = 5; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; - test_eval_fns::(ComparisonGate::<_, 4>::new(num_bits, num_chunks)) + test_eval_fns::(ComparisonGate::<_, 2>::new(num_bits, num_chunks)) } #[test] fn test_gate_constraint() { - type F = GoldilocksField; - type FF = QuarticExtension; - const D: usize = 4; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; let num_bits = 40; let num_chunks = 5; diff --git a/src/gates/constant.rs b/src/gates/constant.rs index ff8ec851..95123be3 100644 --- a/src/gates/constant.rs +++ b/src/gates/constant.rs @@ -120,6 +120,7 @@ mod tests { use crate::gates::constant::ConstantGate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; #[test] fn low_degree() { @@ -130,8 +131,11 @@ mod tests { #[test] fn eval_fns() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; let num_consts = CircuitConfig::standard_recursion_config().constant_gate_size; let gate = ConstantGate { num_consts }; - test_eval_fns::(gate) + test_eval_fns::(gate) } } diff --git a/src/gates/exponentiation.rs b/src/gates/exponentiation.rs index 5087cebd..66ae2f0b 100644 --- a/src/gates/exponentiation.rs +++ b/src/gates/exponentiation.rs @@ -266,7 +266,6 @@ mod tests { use anyhow::Result; use rand::Rng; - use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; use crate::gates::exponentiation::ExponentiationGate; @@ -274,6 +273,7 @@ mod tests { use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::hash::hash_types::HashOut; use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::vars::EvaluationVars; use crate::util::log2_ceil; @@ -307,16 +307,20 @@ mod tests { #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(ExponentiationGate::new_from_config( + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(ExponentiationGate::new_from_config( &CircuitConfig::standard_recursion_config(), )) } #[test] fn test_gate_constraint() { - type F = GoldilocksField; - type FF = QuarticExtension; - const D: usize = 4; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; /// Returns the local wires for an exponentiation gate given the base, power, and power bit /// values. 
diff --git a/src/gates/gate_testing.rs b/src/gates/gate_testing.rs index 9fe8e835..019d0204 100644 --- a/src/gates/gate_testing.rs +++ b/src/gates/gate_testing.rs @@ -7,6 +7,7 @@ use crate::hash::hash_types::HashOut; use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; +use crate::plonk::config::GenericConfig; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use crate::plonk::verifier::verify; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; @@ -84,7 +85,12 @@ fn random_low_degree_values(rate_bits: usize) -> Vec { .values } -pub(crate) fn test_eval_fns, G: Gate, const D: usize>( +pub(crate) fn test_eval_fns< + F: Extendable, + C: GenericConfig, + G: Gate, + const D: usize, +>( gate: G, ) -> Result<()> { // Test that `eval_unfiltered` and `eval_unfiltered_base` are coherent. @@ -151,7 +157,7 @@ pub(crate) fn test_eval_fns, G: Gate, const D let evals_t = gate.eval_unfiltered_recursively(&mut builder, vars_t); pw.set_extension_targets(&evals_t, &evals); - let data = builder.build(); + let data = builder.build::(); let proof = data.prove(pw)?; verify(proof, &data.verifier_only, &data.common) } diff --git a/src/gates/gate_tree.rs b/src/gates/gate_tree.rs index 83a7e2fe..edcae822 100644 --- a/src/gates/gate_tree.rs +++ b/src/gates/gate_tree.rs @@ -229,12 +229,14 @@ mod tests { use crate::gates::gmimc::GMiMCGate; use crate::gates::interpolation::InterpolationGate; use crate::gates::noop::NoopGate; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; #[test] fn test_prefix_generation() { env_logger::init(); - type F = GoldilocksField; - const D: usize = 4; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; let gates = vec![ GateRef::new(NoopGate), diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 8a12df54..7cdd5fe2 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -342,18 +342,21 @@ mod tests { use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; #[test] fn generated_output() { - type F = GoldilocksField; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; const WIDTH: usize = 12; let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::new(config); - type Gate = GMiMCGate; + type Gate = GMiMCGate; let gate = Gate::new(); let gate_index = builder.add_gate(gate, vec![]); - let circuit = builder.build_prover(); + let circuit = builder.build_prover::(); let permutation_inputs = (0..WIDTH).map(F::from_canonical_usize).collect::>(); @@ -398,9 +401,11 @@ mod tests { #[test] fn eval_fns() -> Result<()> { - type F = GoldilocksField; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; const WIDTH: usize = 12; - let gate = GMiMCGate::::new(); - test_eval_fns(gate) + let gate = GMiMCGate::::new(); + test_eval_fns::(gate) } } diff --git a/src/gates/insertion.rs b/src/gates/insertion.rs index dcc79f05..5d49bf5c 100644 --- a/src/gates/insertion.rs +++ b/src/gates/insertion.rs @@ -324,13 +324,13 @@ mod tests { use anyhow::Result; - use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; use crate::gates::gate::Gate; use 
crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::gates::insertion::InsertionGate; use crate::hash::hash_types::HashOut; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::vars::EvaluationVars; #[test] @@ -359,14 +359,18 @@ mod tests { #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(InsertionGate::new(4)) + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(InsertionGate::new(4)) } #[test] fn test_gate_constraint() { - type F = GoldilocksField; - type FF = QuarticExtension; - const D: usize = 4; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; /// Returns the local wires for an insertion gate given the original vector, element to /// insert, and index. diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index 24b755d0..dd672067 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -302,6 +302,7 @@ mod tests { use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::gates::interpolation::InterpolationGate; use crate::hash::hash_types::HashOut; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::vars::EvaluationVars; use crate::polynomial::polynomial::PolynomialCoeffs; @@ -332,14 +333,18 @@ mod tests { #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(InterpolationGate::new(4)) + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(InterpolationGate::new(4)) } #[test] fn test_gate_constraint() { - type F = GoldilocksField; - type FF = QuarticExtension; - const D: usize = 4; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; /// Returns the local wires for an interpolation gate for given coeffs, points and eval point. fn get_wires( diff --git a/src/gates/noop.rs b/src/gates/noop.rs index e615366b..a7851c9d 100644 --- a/src/gates/noop.rs +++ b/src/gates/noop.rs @@ -60,6 +60,7 @@ mod tests { use crate::field::goldilocks_field::GoldilocksField; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::gates::noop::NoopGate; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; #[test] fn low_degree() { @@ -68,6 +69,9 @@ mod tests { #[test] fn eval_fns() -> anyhow::Result<()> { - test_eval_fns::(NoopGate) + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(NoopGate) } } diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 1f5f746d..fc972c83 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -6,6 +6,7 @@ use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::gates::gate::Gate; use crate::gates::poseidon_mds::PoseidonMdsGate; +use crate::hash::hashing::SPONGE_WIDTH; use crate::hash::poseidon; use crate::hash::poseidon::Poseidon; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; @@ -21,21 +22,11 @@ use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; /// It has a flag which can be used to swap the first four inputs with the next four, for ordering /// sibling digests. 
#[derive(Debug)] -pub struct PoseidonGate< - F: RichField + Extendable + Poseidon, - const D: usize, - const WIDTH: usize, -> where - [(); WIDTH - 1]: , -{ +pub struct PoseidonGate, const D: usize> { _phantom: PhantomData, } -impl + Poseidon, const D: usize, const WIDTH: usize> - PoseidonGate -where - [(); WIDTH - 1]: , -{ +impl, const D: usize> PoseidonGate { pub fn new() -> Self { PoseidonGate { _phantom: PhantomData, @@ -49,52 +40,51 @@ where /// The wire index for the `i`th output to the permutation. pub fn wire_output(i: usize) -> usize { - WIDTH + i + SPONGE_WIDTH + i } /// If this is set to 1, the first four inputs will be swapped with the next four inputs. This /// is useful for ordering hashes in Merkle proofs. Otherwise, this should be set to 0. - pub const WIRE_SWAP: usize = 2 * WIDTH; + pub const WIRE_SWAP: usize = 2 * SPONGE_WIDTH; /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the first set /// of full rounds. fn wire_full_sbox_0(round: usize, i: usize) -> usize { debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); - debug_assert!(i < WIDTH); - 2 * WIDTH + 1 + WIDTH * round + i + debug_assert!(i < SPONGE_WIDTH); + 2 * SPONGE_WIDTH + 1 + SPONGE_WIDTH * round + i } /// A wire which stores the input of the S-box of the `round`-th round of the partial rounds. fn wire_partial_sbox(round: usize) -> usize { debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); - 2 * WIDTH + 1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS + round + 2 * SPONGE_WIDTH + 1 + SPONGE_WIDTH * poseidon::HALF_N_FULL_ROUNDS + round } /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the second set /// of full rounds. fn wire_full_sbox_1(round: usize, i: usize) -> usize { debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); - debug_assert!(i < WIDTH); - 2 * WIDTH + debug_assert!(i < SPONGE_WIDTH); + 2 * SPONGE_WIDTH + 1 - + WIDTH * (poseidon::HALF_N_FULL_ROUNDS + round) + + SPONGE_WIDTH * (poseidon::HALF_N_FULL_ROUNDS + round) + poseidon::N_PARTIAL_ROUNDS + i } /// End of wire indices, exclusive. fn end() -> usize { - 2 * WIDTH + 1 + WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + 2 * SPONGE_WIDTH + + 1 + + SPONGE_WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + + poseidon::N_PARTIAL_ROUNDS } } -impl + Poseidon, const D: usize, const WIDTH: usize> Gate - for PoseidonGate -where - [(); WIDTH - 1]: , -{ +impl, const D: usize> Gate for PoseidonGate { fn id(&self) -> String { - format!("{:?}", self, WIDTH) + format!("{:?}", self, SPONGE_WIDTH) } fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { @@ -104,7 +94,7 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(swap * (swap - F::Extension::ONE)); - let mut state = Vec::with_capacity(WIDTH); + let mut state = Vec::with_capacity(SPONGE_WIDTH); for i in 0..4 { let a = vars.local_wires[i]; let b = vars.local_wires[i + 4]; @@ -115,61 +105,58 @@ where let b = vars.local_wires[i]; state.push(a + swap * (b - a)); } - for i in 8..WIDTH { + for i in 8..SPONGE_WIDTH { state.push(vars.local_wires[i]); } - let mut state: [F::Extension; WIDTH] = state.try_into().unwrap(); + let mut state: [F::Extension; SPONGE_WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; // First set of full rounds. 
for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer_field(&mut state, round_ctr); - for i in 0..WIDTH { + ::constant_layer_field(&mut state, round_ctr); + for i in 0..SPONGE_WIDTH { let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; constraints.push(state[i] - sbox_in); state[i] = sbox_in; } - >::sbox_layer_field(&mut state); - state = >::mds_layer_field(&state); + ::sbox_layer_field(&mut state); + state = ::mds_layer_field(&state); round_ctr += 1; } // Partial rounds. - >::partial_first_constant_layer(&mut state); - state = >::mds_partial_layer_init(&mut state); + ::partial_first_constant_layer(&mut state); + state = ::mds_partial_layer_init(&mut state); for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; constraints.push(state[0] - sbox_in); - state[0] = >::sbox_monomial(sbox_in); - state[0] += F::Extension::from_canonical_u64( - >::FAST_PARTIAL_ROUND_CONSTANTS[r], - ); - state = >::mds_partial_layer_fast_field(&state, r); + state[0] = ::sbox_monomial(sbox_in); + state[0] += + F::Extension::from_canonical_u64(::FAST_PARTIAL_ROUND_CONSTANTS[r]); + state = ::mds_partial_layer_fast_field(&state, r); } let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)]; constraints.push(state[0] - sbox_in); - state[0] = >::sbox_monomial(sbox_in); - state = >::mds_partial_layer_fast_field( - &state, - poseidon::N_PARTIAL_ROUNDS - 1, - ); + state[0] = ::sbox_monomial(sbox_in); + state = + ::mds_partial_layer_fast_field(&state, poseidon::N_PARTIAL_ROUNDS - 1); round_ctr += poseidon::N_PARTIAL_ROUNDS; // Second set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer_field(&mut state, round_ctr); - for i in 0..WIDTH { + ::constant_layer_field(&mut state, round_ctr); + for i in 0..SPONGE_WIDTH { let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; constraints.push(state[i] - sbox_in); state[i] = sbox_in; } - >::sbox_layer_field(&mut state); - state = >::mds_layer_field(&state); + ::sbox_layer_field(&mut state); + state = ::mds_layer_field(&state); round_ctr += 1; } - for i in 0..WIDTH { + for i in 0..SPONGE_WIDTH { constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]); } @@ -183,7 +170,7 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(swap * swap.sub_one()); - let mut state = Vec::with_capacity(WIDTH); + let mut state = Vec::with_capacity(SPONGE_WIDTH); for i in 0..4 { let a = vars.local_wires[i]; let b = vars.local_wires[i + 4]; @@ -194,58 +181,56 @@ where let b = vars.local_wires[i]; state.push(a + swap * (b - a)); } - for i in 8..WIDTH { + for i in 8..SPONGE_WIDTH { state.push(vars.local_wires[i]); } - let mut state: [F; WIDTH] = state.try_into().unwrap(); + let mut state: [F; SPONGE_WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; // First set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer(&mut state, round_ctr); - for i in 0..WIDTH { + ::constant_layer(&mut state, round_ctr); + for i in 0..SPONGE_WIDTH { let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; constraints.push(state[i] - sbox_in); state[i] = sbox_in; } - >::sbox_layer(&mut state); - state = >::mds_layer(&state); + ::sbox_layer(&mut state); + state = ::mds_layer(&state); round_ctr += 1; } // Partial rounds. 
- >::partial_first_constant_layer(&mut state); - state = >::mds_partial_layer_init(&mut state); + ::partial_first_constant_layer(&mut state); + state = ::mds_partial_layer_init(&mut state); for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; constraints.push(state[0] - sbox_in); - state[0] = >::sbox_monomial(sbox_in); - state[0] += - F::from_canonical_u64(>::FAST_PARTIAL_ROUND_CONSTANTS[r]); - state = >::mds_partial_layer_fast(&state, r); + state[0] = ::sbox_monomial(sbox_in); + state[0] += F::from_canonical_u64(::FAST_PARTIAL_ROUND_CONSTANTS[r]); + state = ::mds_partial_layer_fast(&state, r); } let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)]; constraints.push(state[0] - sbox_in); - state[0] = >::sbox_monomial(sbox_in); - state = - >::mds_partial_layer_fast(&state, poseidon::N_PARTIAL_ROUNDS - 1); + state[0] = ::sbox_monomial(sbox_in); + state = ::mds_partial_layer_fast(&state, poseidon::N_PARTIAL_ROUNDS - 1); round_ctr += poseidon::N_PARTIAL_ROUNDS; // Second set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer(&mut state, round_ctr); - for i in 0..WIDTH { + ::constant_layer(&mut state, round_ctr); + for i in 0..SPONGE_WIDTH { let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; constraints.push(state[i] - sbox_in); state[i] = sbox_in; } - >::sbox_layer(&mut state); - state = >::mds_layer(&state); + ::sbox_layer(&mut state); + state = ::mds_layer(&state); round_ctr += 1; } - for i in 0..WIDTH { + for i in 0..SPONGE_WIDTH { constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]); } @@ -259,7 +244,7 @@ where ) -> Vec> { // The naive method is more efficient if we have enough routed wires for PoseidonMdsGate. let use_mds_gate = - builder.config.num_routed_wires >= PoseidonMdsGate::::new().num_wires(); + builder.config.num_routed_wires >= PoseidonMdsGate::::new().num_wires(); let mut constraints = Vec::with_capacity(self.num_constraints()); @@ -267,7 +252,7 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(builder.mul_sub_extension(swap, swap, swap)); - let mut state = Vec::with_capacity(WIDTH); + let mut state = Vec::with_capacity(SPONGE_WIDTH); // We need to compute both `if swap {b} else {a}` and `if swap {a} else {b}`. // We will arithmetize them as // swap (b - a) + a @@ -285,54 +270,53 @@ where state.extend(state_first_4); state.extend(state_next_4); - for i in 8..WIDTH { + for i in 8..SPONGE_WIDTH { state.push(vars.local_wires[i]); } - let mut state: [ExtensionTarget; WIDTH] = state.try_into().unwrap(); + let mut state: [ExtensionTarget; SPONGE_WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; // First set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer_recursive(builder, &mut state, round_ctr); - for i in 0..WIDTH { + ::constant_layer_recursive(builder, &mut state, round_ctr); + for i in 0..SPONGE_WIDTH { let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; constraints.push(builder.sub_extension(state[i], sbox_in)); state[i] = sbox_in; } - >::sbox_layer_recursive(builder, &mut state); - state = >::mds_layer_recursive(builder, &state); + ::sbox_layer_recursive(builder, &mut state); + state = ::mds_layer_recursive(builder, &state); round_ctr += 1; } // Partial rounds. 
if use_mds_gate { for r in 0..poseidon::N_PARTIAL_ROUNDS { - >::constant_layer_recursive(builder, &mut state, round_ctr); + ::constant_layer_recursive(builder, &mut state, round_ctr); let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; constraints.push(builder.sub_extension(state[0], sbox_in)); - state[0] = >::sbox_monomial_recursive(builder, sbox_in); - state = >::mds_layer_recursive(builder, &state); + state[0] = ::sbox_monomial_recursive(builder, sbox_in); + state = ::mds_layer_recursive(builder, &state); round_ctr += 1; } } else { - >::partial_first_constant_layer_recursive(builder, &mut state); - state = >::mds_partial_layer_init_recursive(builder, &mut state); + ::partial_first_constant_layer_recursive(builder, &mut state); + state = ::mds_partial_layer_init_recursive(builder, &mut state); for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; constraints.push(builder.sub_extension(state[0], sbox_in)); - state[0] = >::sbox_monomial_recursive(builder, sbox_in); + state[0] = ::sbox_monomial_recursive(builder, sbox_in); state[0] = builder.add_const_extension( state[0], - F::from_canonical_u64(>::FAST_PARTIAL_ROUND_CONSTANTS[r]), + F::from_canonical_u64(::FAST_PARTIAL_ROUND_CONSTANTS[r]), ); - state = - >::mds_partial_layer_fast_recursive(builder, &state, r); + state = ::mds_partial_layer_fast_recursive(builder, &state, r); } let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)]; constraints.push(builder.sub_extension(state[0], sbox_in)); - state[0] = >::sbox_monomial_recursive(builder, sbox_in); - state = >::mds_partial_layer_fast_recursive( + state[0] = ::sbox_monomial_recursive(builder, sbox_in); + state = ::mds_partial_layer_fast_recursive( builder, &state, poseidon::N_PARTIAL_ROUNDS - 1, @@ -342,18 +326,18 @@ where // Second set of full rounds. 
for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer_recursive(builder, &mut state, round_ctr); - for i in 0..WIDTH { + ::constant_layer_recursive(builder, &mut state, round_ctr); + for i in 0..SPONGE_WIDTH { let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; constraints.push(builder.sub_extension(state[i], sbox_in)); state[i] = sbox_in; } - >::sbox_layer_recursive(builder, &mut state); - state = >::mds_layer_recursive(builder, &state); + ::sbox_layer_recursive(builder, &mut state); + state = ::mds_layer_recursive(builder, &state); round_ctr += 1; } - for i in 0..WIDTH { + for i in 0..SPONGE_WIDTH { constraints .push(builder.sub_extension(state[i], vars.local_wires[Self::wire_output(i)])); } @@ -366,7 +350,7 @@ where gate_index: usize, _local_constants: &[F], ) -> Vec>> { - let gen = PoseidonGenerator:: { + let gen = PoseidonGenerator:: { gate_index, _phantom: PhantomData, }; @@ -386,31 +370,23 @@ where } fn num_constraints(&self) -> usize { - WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + WIDTH + 1 + SPONGE_WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + SPONGE_WIDTH + 1 } } #[derive(Debug)] -struct PoseidonGenerator< - F: RichField + Extendable + Poseidon, - const D: usize, - const WIDTH: usize, -> where - [(); WIDTH - 1]: , -{ +struct PoseidonGenerator + Poseidon, const D: usize> { gate_index: usize, _phantom: PhantomData, } -impl + Poseidon, const D: usize, const WIDTH: usize> - SimpleGenerator for PoseidonGenerator -where - [(); WIDTH - 1]: , +impl + Poseidon, const D: usize> SimpleGenerator + for PoseidonGenerator { fn dependencies(&self) -> Vec { - (0..WIDTH) - .map(|i| PoseidonGate::::wire_input(i)) - .chain(Some(PoseidonGate::::WIRE_SWAP)) + (0..SPONGE_WIDTH) + .map(|i| PoseidonGate::::wire_input(i)) + .chain(Some(PoseidonGate::::WIRE_SWAP)) .map(|input| Target::wire(self.gate_index, input)) .collect() } @@ -421,18 +397,18 @@ where input, }; - let mut state = (0..WIDTH) + let mut state = (0..SPONGE_WIDTH) .map(|i| { witness.get_wire(Wire { gate: self.gate_index, - input: PoseidonGate::::wire_input(i), + input: PoseidonGate::::wire_input(i), }) }) .collect::>(); let swap_value = witness.get_wire(Wire { gate: self.gate_index, - input: PoseidonGate::::WIRE_SWAP, + input: PoseidonGate::::WIRE_SWAP, }); debug_assert!(swap_value == F::ZERO || swap_value == F::ONE); if swap_value == F::ONE { @@ -441,65 +417,59 @@ where } } - let mut state: [F; WIDTH] = state.try_into().unwrap(); + let mut state: [F; SPONGE_WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer_field(&mut state, round_ctr); - for i in 0..WIDTH { + ::constant_layer_field(&mut state, round_ctr); + for i in 0..SPONGE_WIDTH { out_buffer.set_wire( - local_wire(PoseidonGate::::wire_full_sbox_0(r, i)), + local_wire(PoseidonGate::::wire_full_sbox_0(r, i)), state[i], ); } - >::sbox_layer_field(&mut state); - state = >::mds_layer_field(&state); + ::sbox_layer_field(&mut state); + state = ::mds_layer_field(&state); round_ctr += 1; } - >::partial_first_constant_layer(&mut state); - state = >::mds_partial_layer_init(&mut state); + ::partial_first_constant_layer(&mut state); + state = ::mds_partial_layer_init(&mut state); for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { out_buffer.set_wire( - local_wire(PoseidonGate::::wire_partial_sbox(r)), + local_wire(PoseidonGate::::wire_partial_sbox(r)), state[0], ); - state[0] = >::sbox_monomial(state[0]); - state[0] += - 
F::from_canonical_u64(>::FAST_PARTIAL_ROUND_CONSTANTS[r]); - state = >::mds_partial_layer_fast_field(&state, r); + state[0] = ::sbox_monomial(state[0]); + state[0] += F::from_canonical_u64(::FAST_PARTIAL_ROUND_CONSTANTS[r]); + state = ::mds_partial_layer_fast_field(&state, r); } out_buffer.set_wire( - local_wire(PoseidonGate::::wire_partial_sbox( + local_wire(PoseidonGate::::wire_partial_sbox( poseidon::N_PARTIAL_ROUNDS - 1, )), state[0], ); - state[0] = >::sbox_monomial(state[0]); - state = >::mds_partial_layer_fast_field( - &state, - poseidon::N_PARTIAL_ROUNDS - 1, - ); + state[0] = ::sbox_monomial(state[0]); + state = + ::mds_partial_layer_fast_field(&state, poseidon::N_PARTIAL_ROUNDS - 1); round_ctr += poseidon::N_PARTIAL_ROUNDS; for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer_field(&mut state, round_ctr); - for i in 0..WIDTH { + ::constant_layer_field(&mut state, round_ctr); + for i in 0..SPONGE_WIDTH { out_buffer.set_wire( - local_wire(PoseidonGate::::wire_full_sbox_1(r, i)), + local_wire(PoseidonGate::::wire_full_sbox_1(r, i)), state[i], ); } - >::sbox_layer_field(&mut state); - state = >::mds_layer_field(&state); + ::sbox_layer_field(&mut state); + state = ::mds_layer_field(&state); round_ctr += 1; } - for i in 0..WIDTH { - out_buffer.set_wire( - local_wire(PoseidonGate::::wire_output(i)), - state[i], - ); + for i in 0..SPONGE_WIDTH { + out_buffer.set_wire(local_wire(PoseidonGate::::wire_output(i)), state[i]); } } } @@ -521,10 +491,13 @@ mod tests { use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; #[test] fn generated_output() { - type F = GoldilocksField; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; const WIDTH: usize = 12; let config = CircuitConfig { @@ -532,10 +505,10 @@ mod tests { ..CircuitConfig::standard_recursion_config() }; let mut builder = CircuitBuilder::new(config); - type Gate = PoseidonGate; + type Gate = PoseidonGate; let gate = Gate::new(); let gate_index = builder.add_gate(gate, vec![]); - let circuit = builder.build_prover(); + let circuit = builder.build_prover::(); let permutation_inputs = (0..WIDTH).map(F::from_canonical_usize).collect::>(); @@ -572,14 +545,16 @@ mod tests { #[test] fn low_degree() { type F = GoldilocksField; - let gate = PoseidonGate::::new(); + let gate = PoseidonGate::::new(); test_low_degree(gate) } #[test] fn eval_fns() -> Result<()> { - type F = GoldilocksField; - let gate = PoseidonGate::::new(); - test_eval_fns(gate) + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let gate = PoseidonGate::::new(); + test_eval_fns::(gate) } } diff --git a/src/gates/poseidon_mds.rs b/src/gates/poseidon_mds.rs index 8a42b588..f75948c8 100644 --- a/src/gates/poseidon_mds.rs +++ b/src/gates/poseidon_mds.rs @@ -8,6 +8,7 @@ use crate::field::extension_field::Extendable; use crate::field::extension_field::FieldExtension; use crate::field::field_types::{Field, RichField}; use crate::gates::gate::Gate; +use crate::hash::hashing::SPONGE_WIDTH; use crate::hash::poseidon::Poseidon; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; @@ -16,21 +17,11 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; #[derive(Debug)] -pub struct PoseidonMdsGate< - F: RichField + Extendable + 
Poseidon, - const D: usize, - const WIDTH: usize, -> where - [(); WIDTH - 1]: , -{ +pub struct PoseidonMdsGate + Poseidon, const D: usize> { _phantom: PhantomData, } -impl + Poseidon, const D: usize, const WIDTH: usize> - PoseidonMdsGate -where - [(); WIDTH - 1]: , -{ +impl + Poseidon, const D: usize> PoseidonMdsGate { pub fn new() -> Self { PoseidonMdsGate { _phantom: PhantomData, @@ -38,13 +29,13 @@ where } pub fn wires_input(i: usize) -> Range { - assert!(i < WIDTH); + assert!(i < SPONGE_WIDTH); i * D..(i + 1) * D } pub fn wires_output(i: usize) -> Range { - assert!(i < WIDTH); - (WIDTH + i) * D..(WIDTH + i + 1) * D + assert!(i < SPONGE_WIDTH); + (SPONGE_WIDTH + i) * D..(SPONGE_WIDTH + i + 1) * D } // Following are methods analogous to ones in `Poseidon`, but for extension algebras. @@ -52,15 +43,14 @@ where /// Same as `mds_row_shf` for an extension algebra of `F`. fn mds_row_shf_algebra( r: usize, - v: &[ExtensionAlgebra; WIDTH], + v: &[ExtensionAlgebra; SPONGE_WIDTH], ) -> ExtensionAlgebra { - debug_assert!(r < WIDTH); + debug_assert!(r < SPONGE_WIDTH); let mut res = ExtensionAlgebra::ZERO; - for i in 0..WIDTH { - let coeff = - F::Extension::from_canonical_u64(1 << >::MDS_MATRIX_EXPS[i]); - res += v[(i + r) % WIDTH].scalar_mul(coeff); + for i in 0..SPONGE_WIDTH { + let coeff = F::Extension::from_canonical_u64(1 << ::MDS_MATRIX_EXPS[i]); + res += v[(i + r) % SPONGE_WIDTH].scalar_mul(coeff); } res @@ -70,16 +60,16 @@ where fn mds_row_shf_algebra_recursive( builder: &mut CircuitBuilder, r: usize, - v: &[ExtensionAlgebraTarget; WIDTH], + v: &[ExtensionAlgebraTarget; SPONGE_WIDTH], ) -> ExtensionAlgebraTarget { - debug_assert!(r < WIDTH); + debug_assert!(r < SPONGE_WIDTH); let mut res = builder.zero_ext_algebra(); - for i in 0..WIDTH { + for i in 0..SPONGE_WIDTH { let coeff = builder.constant_extension(F::Extension::from_canonical_u64( - 1 << >::MDS_MATRIX_EXPS[i], + 1 << ::MDS_MATRIX_EXPS[i], )); - res = builder.scalar_mul_add_ext_algebra(coeff, v[(i + r) % WIDTH], res); + res = builder.scalar_mul_add_ext_algebra(coeff, v[(i + r) % SPONGE_WIDTH], res); } res @@ -87,11 +77,11 @@ where /// Same as `mds_layer` for an extension algebra of `F`. fn mds_layer_algebra( - state: &[ExtensionAlgebra; WIDTH], - ) -> [ExtensionAlgebra; WIDTH] { - let mut result = [ExtensionAlgebra::ZERO; WIDTH]; + state: &[ExtensionAlgebra; SPONGE_WIDTH], + ) -> [ExtensionAlgebra; SPONGE_WIDTH] { + let mut result = [ExtensionAlgebra::ZERO; SPONGE_WIDTH]; - for r in 0..WIDTH { + for r in 0..SPONGE_WIDTH { result[r] = Self::mds_row_shf_algebra(r, state); } @@ -101,11 +91,11 @@ where /// Same as `mds_layer_recursive` for an extension algebra of `F`. 
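// For reference, with the width fixed at SPONGE_WIDTH = 12 the wire helpers above give
// the following layout (shown for extension degree D = 2):
//   wires_input(i)  = i * D .. (i + 1) * D           -> wires 2*i .. 2*i + 2
//   wires_output(i) = (12 + i) * D .. (13 + i) * D   -> wires 24 + 2*i .. 26 + 2*i
//   num_wires       = 2 * D * SPONGE_WIDTH           = 48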
fn mds_layer_algebra_recursive( builder: &mut CircuitBuilder, - state: &[ExtensionAlgebraTarget; WIDTH], - ) -> [ExtensionAlgebraTarget; WIDTH] { - let mut result = [builder.zero_ext_algebra(); WIDTH]; + state: &[ExtensionAlgebraTarget; SPONGE_WIDTH], + ) -> [ExtensionAlgebraTarget; SPONGE_WIDTH] { + let mut result = [builder.zero_ext_algebra(); SPONGE_WIDTH]; - for r in 0..WIDTH { + for r in 0..SPONGE_WIDTH { result[r] = Self::mds_row_shf_algebra_recursive(builder, r, state); } @@ -113,17 +103,13 @@ where } } -impl + Poseidon, const D: usize, const WIDTH: usize> Gate - for PoseidonMdsGate -where - [(); WIDTH - 1]: , -{ +impl + Poseidon, const D: usize> Gate for PoseidonMdsGate { fn id(&self) -> String { - format!("{:?}", self, WIDTH) + format!("{:?}", self, SPONGE_WIDTH) } fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let inputs: [_; WIDTH] = (0..WIDTH) + let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH) .map(|i| vars.get_local_ext_algebra(Self::wires_input(i))) .collect::>() .try_into() @@ -131,7 +117,7 @@ where let computed_outputs = Self::mds_layer_algebra(&inputs); - (0..WIDTH) + (0..SPONGE_WIDTH) .map(|i| vars.get_local_ext_algebra(Self::wires_output(i))) .zip(computed_outputs) .flat_map(|(out, computed_out)| (out - computed_out).to_basefield_array()) @@ -139,7 +125,7 @@ where } fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { - let inputs: [_; WIDTH] = (0..WIDTH) + let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH) .map(|i| vars.get_local_ext(Self::wires_input(i))) .collect::>() .try_into() @@ -147,7 +133,7 @@ where let computed_outputs = F::mds_layer_field(&inputs); - (0..WIDTH) + (0..SPONGE_WIDTH) .map(|i| vars.get_local_ext(Self::wires_output(i))) .zip(computed_outputs) .flat_map(|(out, computed_out)| (out - computed_out).to_basefield_array()) @@ -159,7 +145,7 @@ where builder: &mut CircuitBuilder, vars: EvaluationTargets, ) -> Vec> { - let inputs: [_; WIDTH] = (0..WIDTH) + let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH) .map(|i| vars.get_local_ext_algebra(Self::wires_input(i))) .collect::>() .try_into() @@ -167,7 +153,7 @@ where let computed_outputs = Self::mds_layer_algebra_recursive(builder, &inputs); - (0..WIDTH) + (0..SPONGE_WIDTH) .map(|i| vars.get_local_ext_algebra(Self::wires_output(i))) .zip(computed_outputs) .flat_map(|(out, computed_out)| { @@ -183,12 +169,12 @@ where gate_index: usize, _local_constants: &[F], ) -> Vec>> { - let gen = PoseidonMdsGenerator:: { gate_index }; + let gen = PoseidonMdsGenerator:: { gate_index }; vec![Box::new(gen.adapter())] } fn num_wires(&self) -> usize { - 2 * D * WIDTH + 2 * D * SPONGE_WIDTH } fn num_constants(&self) -> usize { @@ -200,30 +186,22 @@ where } fn num_constraints(&self) -> usize { - WIDTH * D + SPONGE_WIDTH * D } } #[derive(Clone, Debug)] -struct PoseidonMdsGenerator -where - [(); WIDTH - 1]: , -{ +struct PoseidonMdsGenerator { gate_index: usize, } -impl + Poseidon, const D: usize, const WIDTH: usize> - SimpleGenerator for PoseidonMdsGenerator -where - [(); WIDTH - 1]: , +impl + Poseidon, const D: usize> SimpleGenerator + for PoseidonMdsGenerator { fn dependencies(&self) -> Vec { - (0..WIDTH) + (0..SPONGE_WIDTH) .flat_map(|i| { - Target::wires_from_range( - self.gate_index, - PoseidonMdsGate::::wires_input(i), - ) + Target::wires_from_range(self.gate_index, PoseidonMdsGate::::wires_input(i)) }) .collect() } @@ -234,8 +212,8 @@ where let get_local_ext = |wire_range| witness.get_extension_target(get_local_get_target(wire_range)); - let inputs: [_; WIDTH] = (0..WIDTH) - .map(|i| 
get_local_ext(PoseidonMdsGate::::wires_input(i))) + let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH) + .map(|i| get_local_ext(PoseidonMdsGate::::wires_input(i))) .collect::>() .try_into() .unwrap(); @@ -244,7 +222,7 @@ where for (i, &out) in outputs.iter().enumerate() { out_buffer.set_extension_target( - get_local_get_target(PoseidonMdsGate::::wires_output(i)), + get_local_get_target(PoseidonMdsGate::::wires_output(i)), out, ); } @@ -257,18 +235,21 @@ mod tests { use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::gates::poseidon_mds::PoseidonMdsGate; use crate::hash::hashing::SPONGE_WIDTH; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; #[test] fn low_degree() { type F = GoldilocksField; - let gate = PoseidonMdsGate::::new(); + let gate = PoseidonMdsGate::::new(); test_low_degree(gate) } #[test] fn eval_fns() -> anyhow::Result<()> { - type F = GoldilocksField; - let gate = PoseidonMdsGate::::new(); - test_eval_fns(gate) + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let gate = PoseidonMdsGate::::new(); + test_eval_fns::(gate) } } diff --git a/src/gates/public_input.rs b/src/gates/public_input.rs index 4001cf2a..f06df063 100644 --- a/src/gates/public_input.rs +++ b/src/gates/public_input.rs @@ -80,6 +80,7 @@ mod tests { use crate::field::goldilocks_field::GoldilocksField; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::gates::public_input::PublicInputGate; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; #[test] fn low_degree() { @@ -88,6 +89,9 @@ mod tests { #[test] fn eval_fns() -> anyhow::Result<()> { - test_eval_fns::(PublicInputGate) + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(PublicInputGate) } } diff --git a/src/gates/random_access.rs b/src/gates/random_access.rs index bdbff667..452fcf35 100644 --- a/src/gates/random_access.rs +++ b/src/gates/random_access.rs @@ -299,6 +299,7 @@ mod tests { use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::gates::random_access::RandomAccessGate; use crate::hash::hash_types::HashOut; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::vars::EvaluationVars; #[test] @@ -308,14 +309,18 @@ mod tests { #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(RandomAccessGate::new(4, 4)) + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(RandomAccessGate::new(4, 4)) } #[test] fn test_gate_constraint() { - type F = GoldilocksField; - type FF = QuarticExtension; - const D: usize = 4; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; /// Returns the local wires for a random access gate given the vectors, elements to compare, /// and indices. 
diff --git a/src/gates/reducing.rs b/src/gates/reducing.rs index 56d8f590..c9ffce57 100644 --- a/src/gates/reducing.rs +++ b/src/gates/reducing.rs @@ -216,6 +216,7 @@ mod tests { use crate::field::goldilocks_field::GoldilocksField; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::gates::reducing::ReducingGate; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; #[test] fn low_degree() { @@ -224,6 +225,9 @@ mod tests { #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(ReducingGate::new(22)) + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(ReducingGate::new(22)) } } diff --git a/src/gates/switch.rs b/src/gates/switch.rs index a1abd696..78babb00 100644 --- a/src/gates/switch.rs +++ b/src/gates/switch.rs @@ -312,7 +312,6 @@ mod tests { use anyhow::Result; - use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; use crate::gates::gate::Gate; @@ -320,6 +319,7 @@ mod tests { use crate::gates::switch::SwitchGate; use crate::hash::hash_types::HashOut; use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::vars::EvaluationVars; #[test] @@ -359,7 +359,10 @@ mod tests { #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(SwitchGate::<_, 4>::new_from_config( + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(SwitchGate::<_, D>::new_from_config( &CircuitConfig::standard_recursion_config(), 3, )) @@ -367,9 +370,10 @@ mod tests { #[test] fn test_gate_constraint() { - type F = GoldilocksField; - type FF = QuarticExtension; - const D: usize = 4; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; const CHUNK_SIZE: usize = 4; let num_copies = 3; diff --git a/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs b/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs index 1df21550..6f257f56 100644 --- a/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs +++ b/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs @@ -47,7 +47,7 @@ const fn check_mds_matrix() -> bool { let mut i = 0; let wanted_matrix_exps = [0, 0, 1, 0, 3, 5, 1, 8, 12, 3, 16, 10]; while i < WIDTH { - if >::MDS_MATRIX_EXPS[i] != wanted_matrix_exps[i] { + if ::MDS_MATRIX_EXPS[i] != wanted_matrix_exps[i] { return false; } i += 1; @@ -62,7 +62,7 @@ const fn mds_matrix_inf_norm() -> u64 { let mut cumul = 0; let mut i = 0; while i < WIDTH { - cumul += 1 << >::MDS_MATRIX_EXPS[i]; + cumul += 1 << ::MDS_MATRIX_EXPS[i]; i += 1; } cumul diff --git a/src/hash/hash_types.rs b/src/hash/hash_types.rs index eb2f16b0..e0f95c88 100644 --- a/src/hash/hash_types.rs +++ b/src/hash/hash_types.rs @@ -3,7 +3,7 @@ use std::convert::TryInto; use rand::Rng; use serde::{Deserialize, Serialize}; -use crate::field::field_types::Field; +use crate::field::field_types::{Field, PrimeField}; use crate::iop::target::Target; /// Represents a ~256 bit hash output. 
@@ -59,6 +59,41 @@ impl Default for HashOut { } } +impl From> for HashOut { + fn from(v: Vec) -> Self { + HashOut { + elements: v + .chunks(8) + .take(4) + .map(|x| F::from_canonical_u64(u64::from_le_bytes(x.try_into().unwrap()))) + .collect::>() + .try_into() + .unwrap(), + } + } +} + +impl From> for Vec { + fn from(h: HashOut) -> Self { + h.elements + .into_iter() + .flat_map(|x| x.to_canonical_u64().to_le_bytes()) + .collect() + } +} + +impl From> for Vec { + fn from(h: HashOut) -> Self { + h.elements.to_vec() + } +} + +impl From> for u64 { + fn from(h: HashOut) -> Self { + h.elements[0].to_canonical_u64() + } +} + /// Represents a ~256 bit hash output. #[derive(Copy, Clone, Debug)] pub struct HashOutTarget { diff --git a/src/hash/hashing.rs b/src/hash/hashing.rs index d031ebbb..a63106be 100644 --- a/src/hash/hashing.rs +++ b/src/hash/hashing.rs @@ -7,43 +7,48 @@ use crate::field::field_types::RichField; use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::config::AlgebraicHasher; pub(crate) const SPONGE_RATE: usize = 8; pub(crate) const SPONGE_CAPACITY: usize = 4; pub const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; -pub(crate) const HASH_FAMILY: HashFamily = HashFamily::Poseidon; - -pub(crate) enum HashFamily { - GMiMC, - Poseidon, -} +// pub(crate) const HASH_FAMILY: HashFamily = HashFamily::Poseidon; +// +// pub(crate) enum HashFamily { +// GMiMC, +// Poseidon, +// } /// Hash the vector if necessary to reduce its length to ~256 bits. If it already fits, this is a /// no-op. -pub fn hash_or_noop(inputs: Vec) -> HashOut { +pub fn hash_or_noop>(inputs: Vec) -> HashOut { if inputs.len() <= 4 { HashOut::from_partial(inputs) } else { - hash_n_to_hash(inputs, false) + hash_n_to_hash::(inputs, false) } } impl, const D: usize> CircuitBuilder { - pub fn hash_or_noop(&mut self, inputs: Vec) -> HashOutTarget { + pub fn hash_or_noop>(&mut self, inputs: Vec) -> HashOutTarget { let zero = self.zero(); if inputs.len() <= 4 { HashOutTarget::from_partial(inputs, zero) } else { - self.hash_n_to_hash(inputs, false) + self.hash_n_to_hash::(inputs, false) } } - pub fn hash_n_to_hash(&mut self, inputs: Vec, pad: bool) -> HashOutTarget { - HashOutTarget::from_vec(self.hash_n_to_m(inputs, 4, pad)) + pub fn hash_n_to_hash>( + &mut self, + inputs: Vec, + pad: bool, + ) -> HashOutTarget { + HashOutTarget::from_vec(self.hash_n_to_m::(inputs, 4, pad)) } - pub fn hash_n_to_m( + pub fn hash_n_to_m>( &mut self, mut inputs: Vec, num_outputs: usize, @@ -68,7 +73,7 @@ impl, const D: usize> CircuitBuilder { // where we would xor or add in the inputs. This is a well-known variant, though, // sometimes called "overwrite mode". state[..input_chunk.len()].copy_from_slice(input_chunk); - state = self.permute(state); + state = self.permute::(state); } // Squeeze until we have the desired number of outputs. @@ -80,25 +85,46 @@ impl, const D: usize> CircuitBuilder { return outputs; } } - state = self.permute(state); + state = self.permute::(state); } } } /// A one-way compression function which takes two ~256 bit inputs and returns a ~256 bit output. 
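With the compile-time HASH_FAMILY switch gone, callers now name the permutation or hasher explicitly. A rough usage example under the new API (variable names are placeholders; PoseidonPermutation and InnerHasher are the items introduced by this patch):

    // Out of circuit: pick a concrete permutation for the sponge.
    let digest: HashOut<F> = hash_or_noop::<F, PoseidonPermutation>(field_elements);

    // In circuit: hash with the algebraic hasher declared by the config.
    let digest_target = builder.hash_or_noop::<<C as GenericConfig<D>>::InnerHasher>(targets);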
-pub fn compress(x: HashOut, y: HashOut) -> HashOut { +pub fn compress>(x: HashOut, y: HashOut) -> HashOut { let mut perm_inputs = [F::ZERO; SPONGE_WIDTH]; perm_inputs[..4].copy_from_slice(&x.elements); perm_inputs[4..8].copy_from_slice(&y.elements); HashOut { - elements: permute(perm_inputs)[..4].try_into().unwrap(), + elements: P::permute(perm_inputs)[..4].try_into().unwrap(), + } +} + +pub trait PlonkyPermutation { + fn permute(input: [F; SPONGE_WIDTH]) -> [F; SPONGE_WIDTH]; +} + +pub struct PoseidonPermutation; +impl PlonkyPermutation for PoseidonPermutation { + fn permute(input: [F; SPONGE_WIDTH]) -> [F; SPONGE_WIDTH] { + F::poseidon(input) + } +} +pub struct GMiMCPermutation; +impl PlonkyPermutation for GMiMCPermutation { + fn permute(input: [F; SPONGE_WIDTH]) -> [F; SPONGE_WIDTH] { + F::gmimc_permute(input) } } /// If `pad` is enabled, the message is padded using the pad10*1 rule. In general this is required /// for the hash to be secure, but it can safely be disabled in certain cases, like if the input /// length is fixed. -pub fn hash_n_to_m(mut inputs: Vec, num_outputs: usize, pad: bool) -> Vec { +pub fn hash_n_to_m>( + mut inputs: Vec, + num_outputs: usize, + pad: bool, +) -> Vec { if pad { inputs.push(F::ZERO); while (inputs.len() + 1) % SPONGE_WIDTH != 0 { @@ -114,7 +140,7 @@ pub fn hash_n_to_m(mut inputs: Vec, num_outputs: usize, pad: bo for i in 0..input_chunk.len() { state[i] = input_chunk[i]; } - state = permute(state); + state = P::permute(state); } // Squeeze until we have the desired number of outputs. @@ -126,21 +152,17 @@ pub fn hash_n_to_m(mut inputs: Vec, num_outputs: usize, pad: bo return outputs; } } - state = permute(state); + state = P::permute(state); } } -pub fn hash_n_to_hash(inputs: Vec, pad: bool) -> HashOut { - HashOut::from_vec(hash_n_to_m(inputs, 4, pad)) +pub fn hash_n_to_hash>( + inputs: Vec, + pad: bool, +) -> HashOut { + HashOut::from_vec(hash_n_to_m::(inputs, 4, pad)) } -pub fn hash_n_to_1(inputs: Vec, pad: bool) -> F { - hash_n_to_m(inputs, 1, pad)[0] -} - -pub(crate) fn permute(inputs: [F; SPONGE_WIDTH]) -> [F; SPONGE_WIDTH] { - match HASH_FAMILY { - HashFamily::GMiMC => F::gmimc_permute(inputs), - HashFamily::Poseidon => F::poseidon(inputs), - } +pub fn hash_n_to_1>(inputs: Vec, pad: bool) -> F { + hash_n_to_m::(inputs, 1, pad)[0] } diff --git a/src/hash/merkle_proofs.rs b/src/hash/merkle_proofs.rs index 1ba93cf0..38105d2c 100644 --- a/src/hash/merkle_proofs.rs +++ b/src/hash/merkle_proofs.rs @@ -4,18 +4,19 @@ use anyhow::{ensure, Result}; use serde::{Deserialize, Serialize}; use crate::field::extension_field::Extendable; -use crate::field::field_types::{Field, RichField}; -use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget}; -use crate::hash::hashing::{compress, hash_or_noop, SPONGE_WIDTH}; +use crate::field::field_types::RichField; +use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget}; +use crate::hash::hashing::SPONGE_WIDTH; use crate::hash::merkle_tree::MerkleCap; use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::config::{AlgebraicHasher, Hasher}; #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] #[serde(bound = "")] -pub struct MerkleProof { +pub struct MerkleProof> { /// The Merkle digest of each sibling subtree, staying from the bottommost layer. 
- pub siblings: Vec>, + pub siblings: Vec, } #[derive(Clone)] @@ -26,21 +27,21 @@ pub struct MerkleProofTarget { /// Verifies that the given leaf data is present at the given index in the Merkle tree with the /// given cap. -pub(crate) fn verify_merkle_proof( +pub(crate) fn verify_merkle_proof>( leaf_data: Vec, leaf_index: usize, - merkle_cap: &MerkleCap, - proof: &MerkleProof, + merkle_cap: &MerkleCap, + proof: &MerkleProof, ) -> Result<()> { let mut index = leaf_index; - let mut current_digest = hash_or_noop(leaf_data); + let mut current_digest = H::hash(leaf_data, false); for &sibling_digest in proof.siblings.iter() { let bit = index & 1; index >>= 1; current_digest = if bit == 1 { - compress(sibling_digest, current_digest) + H::two_to_one(sibling_digest, current_digest) } else { - compress(current_digest, sibling_digest) + H::two_to_one(current_digest, sibling_digest) } } ensure!( @@ -54,7 +55,7 @@ pub(crate) fn verify_merkle_proof( impl, const D: usize> CircuitBuilder { /// Verifies that the given leaf data is present at the given index in the Merkle tree with the /// given cap. The index is given by it's little-endian bits. - pub(crate) fn verify_merkle_proof( + pub(crate) fn verify_merkle_proof>( &mut self, leaf_data: Vec, leaf_index_bits: &[BoolTarget], @@ -62,13 +63,13 @@ impl, const D: usize> CircuitBuilder { proof: &MerkleProofTarget, ) { let zero = self.zero(); - let mut state: HashOutTarget = self.hash_or_noop(leaf_data); + let mut state: HashOutTarget = self.hash_or_noop::(leaf_data); for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) { let mut perm_inputs = [zero; SPONGE_WIDTH]; perm_inputs[..4].copy_from_slice(&state.elements); perm_inputs[4..8].copy_from_slice(&sibling.elements); - let outputs = self.permute_swapped(perm_inputs, bit); + let outputs = self.permute_swapped::(perm_inputs, bit); state = HashOutTarget::from_vec(outputs[0..4].to_vec()); } @@ -84,7 +85,7 @@ impl, const D: usize> CircuitBuilder { } /// Same as `verify_merkle_proof` but with the final "cap index" as extra parameter. 
- pub(crate) fn verify_merkle_proof_with_cap_index( + pub(crate) fn verify_merkle_proof_with_cap_index>( &mut self, leaf_data: Vec, leaf_index_bits: &[BoolTarget], @@ -93,13 +94,13 @@ impl, const D: usize> CircuitBuilder { proof: &MerkleProofTarget, ) { let zero = self.zero(); - let mut state: HashOutTarget = self.hash_or_noop(leaf_data); + let mut state: HashOutTarget = self.hash_or_noop::(leaf_data); for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) { let mut perm_inputs = [zero; SPONGE_WIDTH]; perm_inputs[..4].copy_from_slice(&state.elements); perm_inputs[4..8].copy_from_slice(&sibling.elements); - let perm_outs = self.permute_swapped(perm_inputs, bit); + let perm_outs = self.permute_swapped::(perm_inputs, bit); let hash_outs = perm_outs[0..4].try_into().unwrap(); state = HashOutTarget { elements: hash_outs, @@ -128,11 +129,13 @@ mod tests { use rand::{thread_rng, Rng}; use super::*; + use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; use crate::hash::merkle_tree::MerkleTree; use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::verifier::verify; fn random_data(n: usize, k: usize) -> Vec> { @@ -141,16 +144,18 @@ mod tests { #[test] fn test_recursive_merkle_proof() -> Result<()> { - type F = GoldilocksField; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; let config = CircuitConfig::standard_recursion_config(); let mut pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); + let mut builder = CircuitBuilder::::new(config); let log_n = 8; let n = 1 << log_n; let cap_height = 1; let leaves = random_data::(n, 7); - let tree = MerkleTree::new(leaves, cap_height); + let tree = MerkleTree::>::Hasher>::new(leaves, cap_height); let i: usize = thread_rng().gen_range(0..n); let proof = tree.prove(i); @@ -172,9 +177,11 @@ mod tests { pw.set_target(data[j], tree.leaves[i][j]); } - builder.verify_merkle_proof(data, &i_bits, &cap_t, &proof_t); + builder.verify_merkle_proof::<>::InnerHasher>( + data, &i_bits, &cap_t, &proof_t, + ); - let data = builder.build(); + let data = builder.build::(); let proof = data.prove(pw)?; verify(proof, &data.verifier_only, &data.common) diff --git a/src/hash/merkle_tree.rs b/src/hash/merkle_tree.rs index 2a33a143..cb80fcdc 100644 --- a/src/hash/merkle_tree.rs +++ b/src/hash/merkle_tree.rs @@ -1,44 +1,49 @@ use rayon::prelude::*; use serde::{Deserialize, Serialize}; -use crate::field::field_types::{Field, RichField}; -use crate::hash::hash_types::HashOut; -use crate::hash::hashing::{compress, hash_or_noop}; +use crate::field::field_types::RichField; use crate::hash::merkle_proofs::MerkleProof; +use crate::plonk::config::Hasher; /// The Merkle cap of height `h` of a Merkle tree is the `h`-th layer (from the root) of the tree. /// It can be used in place of the root to verify Merkle paths, which are `h` elements shorter. 
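Merkle trees and proofs are now generic over the hasher instead of being hard-wired to the Poseidon sponge. A minimal out-of-circuit sketch mirroring the test above (leaves, cap_height and i are placeholders):

    type H = <C as GenericConfig<D>>::Hasher;
    let tree = MerkleTree::<F, H>::new(leaves, cap_height);
    let proof = tree.prove(i);
    verify_merkle_proof(tree.leaves[i].clone(), i, &tree.cap, &proof)?;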
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] #[serde(bound = "")] -pub struct MerkleCap(pub Vec>); +pub struct MerkleCap>(pub Vec); -impl MerkleCap { +impl> MerkleCap { pub fn len(&self) -> usize { self.0.len() } pub fn flatten(&self) -> Vec { - self.0.iter().flat_map(|h| h.elements).collect() + self.0 + .iter() + .flat_map(|&h| { + let felts: Vec = h.into(); + felts + }) + .collect() } } #[derive(Clone, Debug)] -pub struct MerkleTree { +pub struct MerkleTree> { /// The data in the leaves of the Merkle tree. pub leaves: Vec>, /// The layers of hashes in the tree. The first layer is the one at the bottom. - pub layers: Vec>>, + pub layers: Vec>, /// The Merkle cap. - pub cap: MerkleCap, + pub cap: MerkleCap, } -impl MerkleTree { +impl> MerkleTree { pub fn new(leaves: Vec>, cap_height: usize) -> Self { let mut layers = vec![leaves .par_iter() - .map(|l| hash_or_noop(l.clone())) + .map(|l| H::hash(l.clone(), false)) .collect::>()]; while let Some(l) = layers.last() { if l.len() == 1 << cap_height { @@ -46,7 +51,7 @@ impl MerkleTree { } let next_layer = l .par_chunks(2) - .map(|chunk| compress(chunk[0], chunk[1])) + .map(|chunk| H::two_to_one(chunk[0], chunk[1])) .collect::>(); layers.push(next_layer); } @@ -63,7 +68,7 @@ impl MerkleTree { } /// Create a Merkle proof from a leaf index. - pub fn prove(&self, leaf_index: usize) -> MerkleProof { + pub fn prove(&self, leaf_index: usize) -> MerkleProof { MerkleProof { siblings: self .layers @@ -83,15 +88,20 @@ mod tests { use anyhow::Result; use super::*; + use crate::field::extension_field::Extendable; use crate::field::goldilocks_field::GoldilocksField; use crate::hash::merkle_proofs::verify_merkle_proof; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; fn random_data(n: usize, k: usize) -> Vec> { (0..n).map(|_| F::rand_vec(k)).collect() } - fn verify_all_leaves(leaves: Vec>, n: usize) -> Result<()> { - let tree = MerkleTree::new(leaves.clone(), 1); + fn verify_all_leaves, C: GenericConfig, const D: usize>( + leaves: Vec>, + n: usize, + ) -> Result<()> { + let tree = MerkleTree::::new(leaves.clone(), 1); for i in 0..n { let proof = tree.prove(i); verify_merkle_proof(leaves[i].clone(), i, &tree.cap, &proof)?; @@ -101,13 +111,15 @@ mod tests { #[test] fn test_merkle_trees() -> Result<()> { - type F = GoldilocksField; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; let log_n = 8; let n = 1 << log_n; let leaves = random_data::(n, 7); - verify_all_leaves(leaves, n)?; + verify_all_leaves::(leaves, n)?; Ok(()) } diff --git a/src/hash/path_compression.rs b/src/hash/path_compression.rs index 8b86baac..a18fbefb 100644 --- a/src/hash/path_compression.rs +++ b/src/hash/path_compression.rs @@ -2,16 +2,16 @@ use std::collections::HashMap; use num::Integer; -use crate::field::field_types::{Field, RichField}; -use crate::hash::hashing::{compress, hash_or_noop}; +use crate::field::field_types::RichField; use crate::hash::merkle_proofs::MerkleProof; +use crate::plonk::config::Hasher; /// Compress multiple Merkle proofs on the same tree by removing redundancy in the Merkle paths. -pub(crate) fn compress_merkle_proofs( +pub(crate) fn compress_merkle_proofs>( cap_height: usize, indices: &[usize], - proofs: &[MerkleProof], -) -> Vec> { + proofs: &[MerkleProof], +) -> Vec> { assert!(!proofs.is_empty()); let height = cap_height + proofs[0].siblings.len(); let num_leaves = 1 << height; @@ -51,13 +51,13 @@ pub(crate) fn compress_merkle_proofs( /// Decompress compressed Merkle proofs. 
/// Note: The data and indices must be in the same order as in `compress_merkle_proofs`. -pub(crate) fn decompress_merkle_proofs( +pub(crate) fn decompress_merkle_proofs>( leaves_data: &[Vec], leaves_indices: &[usize], - compressed_proofs: &[MerkleProof], + compressed_proofs: &[MerkleProof], height: usize, cap_height: usize, -) -> Vec> { +) -> Vec> { let num_leaves = 1 << height; let compressed_proofs = compressed_proofs.to_vec(); let mut decompressed_proofs = Vec::with_capacity(compressed_proofs.len()); @@ -66,7 +66,7 @@ pub(crate) fn decompress_merkle_proofs( for (&i, v) in leaves_indices.iter().zip(leaves_data) { // Observe the leaves. - seen.insert(i + num_leaves, hash_or_noop(v.to_vec())); + seen.insert(i + num_leaves, H::hash(v.to_vec(), false)); } // Iterators over the siblings. @@ -84,9 +84,9 @@ pub(crate) fn decompress_merkle_proofs( .entry(sibling_index) .or_insert_with(|| *p.next().unwrap()); let parent_hash = if index.is_even() { - compress(current_hash, sibling_hash) + H::two_to_one(current_hash, sibling_hash) } else { - compress(sibling_hash, current_hash) + H::two_to_one(sibling_hash, current_hash) }; seen.insert(index >> 1, parent_hash); } @@ -118,14 +118,17 @@ mod tests { use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; use crate::hash::merkle_tree::MerkleTree; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; #[test] fn test_path_compression() { - type F = GoldilocksField; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; let h = 10; let cap_height = 3; let vs = (0..1 << h).map(|_| vec![F::rand()]).collect::>(); - let mt = MerkleTree::new(vs.clone(), cap_height); + let mt = MerkleTree::>::Hasher>::new(vs.clone(), cap_height); let mut rng = thread_rng(); let k = rng.gen_range(1..=1 << h); diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index 9a52060c..b94cc481 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -144,11 +144,8 @@ pub const ALL_ROUND_CONSTANTS: [u64; MAX_WIDTH * N_ROUNDS] = [ 0x4543d9df72c4831d, 0xf172d73e69f20739, 0xdfd1c4ff1eb3d868, 0xbc8dfb62d26376f7, ]; -pub trait Poseidon: PrimeField -where - // magic to get const generic expressions to work - [(); WIDTH - 1]: , -{ +const WIDTH: usize = 12; +pub trait Poseidon: PrimeField { // Total number of round constants required: width of the input // times number of rounds. const N_ROUND_CONSTANTS: usize = WIDTH * N_ROUNDS; @@ -216,7 +213,7 @@ where let mut res = builder.zero_extension(); for i in 0..WIDTH { - let c = Self::from_canonical_u64(1 << >::MDS_MATRIX_EXPS[i]); + let c = Self::from_canonical_u64(1 << ::MDS_MATRIX_EXPS[i]); res = builder.mul_const_add_extension(c, v[(i + r) % WIDTH], res); } @@ -269,16 +266,16 @@ where Self: RichField + Extendable, { // If we have enough routed wires, we will use PoseidonMdsGate. 
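// PoseidonMdsGate::num_wires() is 2 * D * SPONGE_WIDTH (48 when D = 2): one extension
// target per input and per output of the MDS layer. The dedicated gate is therefore only
// used when the circuit config routes at least that many wires; otherwise the layer is
// expanded into ordinary arithmetic operations.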
- let mds_gate = PoseidonMdsGate::::new(); + let mds_gate = PoseidonMdsGate::::new(); if builder.config.num_routed_wires >= mds_gate.num_wires() { let index = builder.add_gate(mds_gate, vec![]); for i in 0..WIDTH { - let input_wire = PoseidonMdsGate::::wires_input(i); + let input_wire = PoseidonMdsGate::::wires_input(i); builder.connect_extension(state[i], ExtensionTarget::from_range(index, input_wire)); } (0..WIDTH) .map(|i| { - let output_wire = PoseidonMdsGate::::wires_output(i); + let output_wire = PoseidonMdsGate::::wires_output(i); ExtensionTarget::from_range(index, output_wire) }) .collect::>() @@ -316,7 +313,7 @@ where Self: RichField + Extendable, { for i in 0..WIDTH { - let c = >::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]; + let c = ::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]; let c = Self::Extension::from_canonical_u64(c); let c = builder.constant_extension(c); state[i] = builder.add_extension(state[i], c); @@ -369,7 +366,7 @@ where for r in 1..WIDTH { for c in 1..WIDTH { - let t = >::FAST_PARTIAL_ROUND_INITIAL_MATRIX[r - 1][c - 1]; + let t = ::FAST_PARTIAL_ROUND_INITIAL_MATRIX[r - 1][c - 1]; let t = Self::Extension::from_canonical_u64(t); let t = builder.constant_extension(t); result[c] = builder.mul_add_extension(t, state[r], result[c]); @@ -450,11 +447,11 @@ where { let s0 = state[0]; let mut d = builder.mul_const_extension( - Self::from_canonical_u64(1 << >::MDS_MATRIX_EXPS[0]), + Self::from_canonical_u64(1 << ::MDS_MATRIX_EXPS[0]), s0, ); for i in 1..WIDTH { - let t = >::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]; + let t = ::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]; let t = Self::from_canonical_u64(t); d = builder.mul_const_add_extension(t, state[i], d); } @@ -462,7 +459,7 @@ where let mut result = [builder.zero_extension(); WIDTH]; result[0] = d; for i in 1..WIDTH { - let t = >::FAST_PARTIAL_ROUND_VS[r][i - 1]; + let t = ::FAST_PARTIAL_ROUND_VS[r][i - 1]; let t = Self::Extension::from_canonical_u64(t); let t = builder.constant_extension(t); result[i] = builder.mul_add_extension(t, state[0], state[i]); @@ -559,7 +556,7 @@ where Self: RichField + Extendable, { for i in 0..WIDTH { - state[i] = >::sbox_monomial_recursive(builder, state[i]); + state[i] = ::sbox_monomial_recursive(builder, state[i]); } } @@ -628,39 +625,38 @@ where pub(crate) mod test_helpers { use crate::field::field_types::Field; + use crate::hash::hashing::SPONGE_WIDTH; use crate::hash::poseidon::Poseidon; - pub(crate) fn check_test_vectors( - test_vectors: Vec<([u64; WIDTH], [u64; WIDTH])>, + pub(crate) fn check_test_vectors( + test_vectors: Vec<([u64; SPONGE_WIDTH], [u64; SPONGE_WIDTH])>, ) where - F: Poseidon, - [(); WIDTH - 1]: , + F: Poseidon, { for (input_, expected_output_) in test_vectors.into_iter() { - let mut input = [F::ZERO; WIDTH]; - for i in 0..WIDTH { + let mut input = [F::ZERO; SPONGE_WIDTH]; + for i in 0..SPONGE_WIDTH { input[i] = F::from_canonical_u64(input_[i]); } let output = F::poseidon(input); - for i in 0..WIDTH { + for i in 0..SPONGE_WIDTH { let ex_output = F::from_canonical_u64(expected_output_[i]); assert_eq!(output[i], ex_output); } } } - pub(crate) fn check_consistency() + pub(crate) fn check_consistency() where - F: Poseidon, - [(); WIDTH - 1]: , + F: Poseidon, { - let mut input = [F::ZERO; WIDTH]; - for i in 0..WIDTH { + let mut input = [F::ZERO; SPONGE_WIDTH]; + for i in 0..SPONGE_WIDTH { input[i] = F::from_canonical_u64(i as u64); } let output = F::poseidon(input); let output_naive = F::poseidon_naive(input); - for i in 0..WIDTH { + for i in 0..SPONGE_WIDTH { assert_eq!(output[i], 
output_naive[i]); } } diff --git a/src/hash/poseidon_crandall.rs b/src/hash/poseidon_crandall.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/hash/poseidon_crandall.rs @@ -0,0 +1 @@ + diff --git a/src/hash/poseidon_goldilocks.rs b/src/hash/poseidon_goldilocks.rs index 0b1f9a49..743590f0 100644 --- a/src/hash/poseidon_goldilocks.rs +++ b/src/hash/poseidon_goldilocks.rs @@ -8,147 +8,7 @@ use crate::field::goldilocks_field::GoldilocksField; use crate::hash::poseidon::{Poseidon, N_PARTIAL_ROUNDS}; #[rustfmt::skip] -impl Poseidon<8> for GoldilocksField { - // The MDS matrix we use is the circulant matrix with first row given by the vector - // [ 2^x for x in MDS_MATRIX_EXPS] = [1, 1, 2, 1, 8, 32, 4, 256] - // - // WARNING: If the MDS matrix is changed, then the following - // constants need to be updated accordingly: - // - FAST_PARTIAL_ROUND_CONSTANTS - // - FAST_PARTIAL_ROUND_VS - // - FAST_PARTIAL_ROUND_W_HATS - // - FAST_PARTIAL_ROUND_INITIAL_MATRIX - const MDS_MATRIX_EXPS: [u64; 8] = [0, 0, 1, 0, 3, 5, 2, 8]; - - const FAST_PARTIAL_FIRST_ROUND_CONSTANT: [u64; 8] = [ - 0x66bbd30e99d311da, 0xac0494d706139435, 0x7eea5812cb4c5eb2, 0x6061af64681ce880, - 0xfce86220df80ac43, 0x5285da71ebb7b008, 0x8649956f6d44d2a2, 0xcf8c90ab81a0ca0a, - ]; - - const FAST_PARTIAL_ROUND_CONSTANTS: [u64; N_PARTIAL_ROUNDS] = [ - 0xd3e8f03df7f0d35c, 0x3ef0eeeed58f09f7, 0x6b54f9fd0ecdfa58, 0x129f9c79c53051f4, - 0xe0ee72d960a7c705, 0x2dc8a0d0d92c1497, 0x6936412d8980befa, 0x64f44cf4c7211138, - 0xcd28551a527e2472, 0x71c8b45ae08e543e, 0xcbde77e27af5b694, 0xab4d6a7cbb49e2f0, - 0xaaef22c4753df029, 0x4889f5d08dbf0f1f, 0x5fa33b282603eb65, 0x86661e9507022660, - 0x3e31490d4eeb1d9f, 0xc581d1f6d84c6485, 0x77e61c9742a20dd3, 0x9edc0491219ecb5c, - 0x5b846917f2f767eb, 0x0, - ]; - - const FAST_PARTIAL_ROUND_VS: [[u64; 8 - 1]; N_PARTIAL_ROUNDS] = [ - [0xb9af2750293b9624, 0x1148fcc5cbe27c57, 0x174a9735f87d5b66, 0x9ade5dad416cccfa, - 0x191867d7fd58636a, 0x1018a176ac6b8850, 0x6baa69bf6caac2f7, ], - [0x5d3a3be85300d127, 0x602d9345fdb2950b, 0xa71b08e14841259d, 0x8c9e66a88cfc2a2f, - 0xd23f18447b9d6ca6, 0x9c7b63750e75136d, 0xc0036bb483def9f6, ], - [0xd8e171f97120488d, 0x963ace7d45dd3534, 0xe1110876d0920bb1, 0xc2554b2a73562b4d, - 0x25c5559e1da9b854, 0xfd6a3146495a05e8, 0x238d725e9bbea44f, ], - [0xf64bc8099412ee92, 0x43a6897f45dac19e, 0xca7101923a589502, 0x142f002e59b5c266, - 0xf03ceac54cef3438, 0x66b181f8f5003148, 0xa771a1eef052f853, ], - [0x9d4b9376927960be, 0x99543e4c8809ec7d, 0x86b30b2577e74c74, 0x5bc8aeabd7389991, - 0xcb9c2b7e2f4ec665, 0x0de73a3c82e91199, 0x0f2d2370f6bc0228, ], - [0x253dd236fc5e4f15, 0x3ec881b20a588043, 0xbc42663d732126fe, 0xe3e6fa02e77ad144, - 0x04b1e0459ba85bbf, 0x6550e387f467aee7, 0xc34b817494f32dd8, ], - [0xd9423529e3d9b44e, 0x327e2609b24d5a59, 0x9ab352e6581fd735, 0x95a6a4e5dd94aefc, - 0x44f860fc8a140181, 0x10fe3ee72bbaf4bc, 0x41b951dfc4190fe2, ], - [0x931b2f16aae2cb8d, 0xb2cd58604bb14653, 0xe68e709a8bcb1228, 0x286b1cb1bdd94d41, - 0xaf3f0e1f41093ffd, 0xcc00f393df3aef69, 0x68eeb30cca0b90fe, ], - [0xcfbc82fae1248b3c, 0xaea4f7382d6e7d1a, 0xfe46b0ab3d6e3160, 0xa7ee349ec637bfd2, - 0xdf5f1ba6dbafdcba, 0xe8d6bcc2b7545ece, 0xd69b6a4d64cc3850, ], - [0xb3057004d66998c6, 0xb9e5e008d480602e, 0xcb401bc12a68178a, 0x9b0c25e0fec9c9ca, - 0x27903301fe272833, 0x5ab55e67746531c9, 0xa785dc1e593047b1, ], - [0xeba6857b4e021502, 0x44325a11dccd4da2, 0xfe061fabb725e7ed, 0x88ade6bf344c857e, - 0xa576bd9fdcb3b259, 0xedeae5b8be128b60, 0x0557f1891844b88a, ], - [0x94c66397aee8b97f, 0x25ac4cb55737667d, 0xc1f035a5dd2d4cc8, 
0x916533f52e8205d6, - 0xf564f659b15f376a, 0x9f0032cd56a4328f, 0xa4300a553fe15224, ], - [0xe2a4c0486179d0cb, 0x3c92c7272c4536fd, 0xc08233d9a1db1814, 0x774b36b64d2fb890, - 0xf47210158dfda27b, 0xe44f205f72b1572a, 0x93f2ac3eb28af404, ], - [0x2c657b307f0dbbae, 0xbc8c7fbae563049b, 0xb459200f00172a5e, 0x90e04fdc6dfeccda, - 0x2c0369901c0cc5ea, 0xe0ef32f033d13298, 0x2087a2aecd13db2f, ], - [0x0841fbc2bf24a2b1, 0x44eb9cb920d24a43, 0x23c415122043afc5, 0x313ece0eb0f7b6d6, - 0x273938954c49858c, 0x1dcb6a4a6cf06e6d, 0x1cce7720eb4f6f98, ], - [0x0022555dbdafaac1, 0x001a5afeb9fc4888, 0x002b1f1ca992d571, 0x001fee5206bf439e, - 0x0015d27e30a1621e, 0x0015b6f958368106, 0x010a6aef986e23ce, ], - [0x00000de86b7a238e, 0x000028a51289c2f5, 0x00001b440277fe8a, 0x00000e8e3ea5103e, - 0x00000f9bc91bcf75, 0x0001071dda899dbf, 0x00001e48188120d9, ], - [0x000000126ca1da48, 0x00000013b4d8fc12, 0x0000000a11cf6ba0, 0x0000000a092e06b0, - 0x00000104497e1ca3, 0x00000017ca90627c, 0x000000a21fcd4eab, ], - [0x0000000008bc9a2d, 0x00000000070e1ecf, 0x0000000006989bf1, 0x0000000102279912, - 0x0000000012063786, 0x00000000811f1acd, 0x00000000265a4ea2, ], - [0x000000000002bb2f, 0x0000000000042512, 0x0000000001010c47, 0x00000000000ccc46, - 0x0000000000607b8a, 0x00000000001b1d04, 0x00000000000fd612, ], - [0x0000000000000198, 0x0000000000010065, 0x0000000000000834, 0x000000000000401e, - 0x0000000000001105, 0x0000000000000643, 0x0000000000000609, ], - [0x0000000000000100, 0x0000000000000004, 0x0000000000000020, 0x0000000000000008, - 0x0000000000000001, 0x0000000000000002, 0x0000000000000001, ], - ]; - - const FAST_PARTIAL_ROUND_W_HATS: [[u64; 8 - 1]; N_PARTIAL_ROUNDS] = [ - [0x269b1eb39549a1db, 0x9c2f7295da6fe4ed, 0x1cb34e7859012514, 0x28d524012a1c29c2, - 0x40eaef552e8ec873, 0x1ba83ec01c4ad111, 0xb97f43b8c7379659, ], - [0x797db014cbe89c21, 0xcd8cbe2d94b66eea, 0x1feab2f1f7800637, 0x2dfb3dfab42d3c95, - 0x026ae799f7199a65, 0xff13e93bac5ccd21, 0x85c7c686d5e86fa8, ], - [0x63491cb6f6f9b060, 0xb56e5bf1cd5c5985, 0xf617c6646887cd04, 0x82ad2d36291e4b2c, - 0x34be211a42b111f4, 0xe1427b350e8789bb, 0x4e90daa4a7162d86, ], - [0x23ff08f88b78428a, 0x2b9b6a866210f36c, 0x8f1452c156899e05, 0x5c312425f14e4701, - 0xf010bd4be5eb43dd, 0xb6e3d8976c435cd0, 0x07aae99f2fce8073, ], - [0xc89ef5941b95831b, 0x95931df88bb238d9, 0x0de74ab8bc5ec419, 0x4825380b2d936c13, - 0xb88277e244b69fb6, 0x76114374d9652c44, 0x76ed6bba7d8313c1, ], - [0xc000f50a6bd73faf, 0x9dd8304a9bd9f1b6, 0xb58e0b5e3e40bb29, 0x823c1c7be983035e, - 0xe3fa343aae9e7831, 0x7aa8d38188f752cb, 0xea42c23ed57c33c0, ], - [0x24ecf72c180fc92b, 0x33a4dbfddf7e373b, 0x469df558ba1261c2, 0x60ab4f0f3d2ad4c8, - 0xc110cb1c5c7a7a88, 0x4a4baf941ec7cf67, 0x16965340c1d488ef, ], - [0x79a95b95aa2fd971, 0x04419bf145fd6a4a, 0x71d788554e0d115d, 0x4044371afe7450e1, - 0xb00d7baa7ce81dd6, 0xe46a1479821e235b, 0x80edef59f7553c3f, ], - [0xf1dc222706620f79, 0xfc7232469c59f586, 0x028aef7f4ec9d3d4, 0xf12a3b4e5de9facb, - 0x135973e4aa6b1253, 0xcbff3378151eb32e, 0x034c61764a8d260a, ], - [0x00e52733564fcee6, 0x0c5b3ad3251ccdf4, 0xf49fffc683ce919b, 0xd17292effcfbaa02, - 0xa151d073be3aeb67, 0x2faf5b05065f340f, 0x513705952d8185c8, ], - [0x399e416f7506e439, 0xebf6618c65c571f5, 0x7a4348f382135c3a, 0x171cc2b625ec95f9, - 0x63bff2edafa923af, 0x1f0aa3a5b6c61920, 0xc8f889e2c89fc18c, ], - [0xcba09835c5a7c1fc, 0xfe9ca6a5f9cfe7f5, 0xae51732c9ae24e99, 0xfe19c95080c5fed7, - 0x56d181fad0512be3, 0xb74c82e5a32566eb, 0xfdff5523a2096934, ], - [0x4e9d731c839a6384, 0xa6ab3d286a385a74, 0x92c9a99c9c3d66f1, 0xe3e3cd56f3de8405, - 0x51afd4ef5b764ecc, 0x20f06b5b9cc5911a, 
0xd5ab74758e45a1e9, ], - [0x1b40e9633dbe3e6a, 0x61aaf01dddefc2a2, 0xcca587c064e6fa34, 0xfba6904b9a40507b, - 0xbdd6f9280d82b8c2, 0x81ae47de86e77b1a, 0x240a15880d36689b, ], - [0x26136c701690ea6f, 0xfd69557e6072cfb7, 0x58d824017b513eb9, 0x05d7dafb3de8cf5e, - 0xcceb095959c76f7d, 0x83021ef00b804c28, 0x249ac764258cc526, ], - [0xe154d3c75894d969, 0xed0d19dd7a62c62d, 0x33098c41f542ad56, 0x0a00d8de37b9e97e, - 0x4701f379b9cc1b8d, 0xfcf4a08ebee38a80, 0x538455bf65ac55e5, ], - [0xd6bce6dee03ffd40, 0x1b595cc58ad8b6cd, 0x3a57b9cfcbbd1181, 0x5eca20dbf78b6fdf, - 0xf17b83b69550c7ba, 0xa25ad9bb6f6d696f, 0xa7c0a32028a396cd, ], - [0x7074ed0a4493e0cb, 0xaf007f0e547fcdae, 0x1c9a20122a92a480, 0xa394fda7dc2a248c, - 0x9011f48bc126c4ef, 0xfecd3befc1ee4d0b, 0x24b9a7dbf43d5a2b, ], - [0x1ecc6172a78fda5a, 0x654b8deec4e920d2, 0x813eb0e016ae4570, 0x3303807aaa79ad24, - 0xffa5a9ee2ad77929, 0x32ecc1c7d9d0b127, 0x6df4612b0b81b271, ], - [0xdbc7f712822f4575, 0x88e67f35f99b7fe1, 0xf37566abe5e5dbc1, 0xcd8eca65a17c493f, - 0x3568726b02cd955b, 0x1221e6d90b408c61, 0x01c8c201d650b222, ], - [0x02ed134db31e582d, 0x503692ee719f6add, 0xeadaef5785f69755, 0x98ab6d6ac1763ac2, - 0x7a12232114fa6b11, 0x5f1232b59a635f7f, 0x73e5509bf404a257, ], - [0x11c759d7c36ae70a, 0x3f7bfed8879b0281, 0x56127c65148822bd, 0x31f695e2c256d94e, - 0x31da9505206208ba, 0xb9fdbd9aada98a78, 0xc9255cd2a9ee89a3, ], - ]; - - // NB: This is in ROW-major order to support cache-friendly pre-multiplication. - const FAST_PARTIAL_ROUND_INITIAL_MATRIX: [[u64; 8 - 1]; 8 - 1] = [ - [0x44f68560bbf3e205, 0x22f2a0308e9c911f, 0x2cf2fc34afb5e90d, 0xdfd3820dd14dca23, - 0xc8cedeb0115d4cb9, 0xa7e9f1e59b2ace9e, 0x551386ca3a31ccb4, ], - [0xb4257d684cc96d30, 0x6918b8409b32d75b, 0xf42a3433a147167a, 0xaf91167a1880c1b1, - 0xa56b1fba7844632a, 0x27a3a6aa3cd42312, 0xa7e9f1e59b2ace9e, ], - [0xeb1bdec94099409a, 0x8666bcbe8366cb0f, 0x60aa4f11c97e774d, 0x9e0d98f4429fc32b, - 0xb428d8df399e3344, 0xa56b1fba7844632a, 0xc8cedeb0115d4cb9, ], - [0x67ba59d3d88a20df, 0x1d448e0422470936, 0x159c5a4decc6b1f9, 0x3f4325c2395f5587, - 0x9e0d98f4429fc32b, 0xaf91167a1880c1b1, 0xdfd3820dd14dca23, ], - [0x22c4f8e67637ae91, 0x1c0d1308d0a0148d, 0xa0ce3dcce54586f7, 0x159c5a4decc6b1f9, - 0x60aa4f11c97e774d, 0xf42a3433a147167a, 0x2cf2fc34afb5e90d, ], - [0xfb640823e5ee3bac, 0xdb990b6d9cf010db, 0x1c0d1308d0a0148d, 0x1d448e0422470936, - 0x8666bcbe8366cb0f, 0x6918b8409b32d75b, 0x22f2a0308e9c911f, ], - [0x8cf5bd0b11cfcdf1, 0xfb640823e5ee3bac, 0x22c4f8e67637ae91, 0x67ba59d3d88a20df, - 0xeb1bdec94099409a, 0xb4257d684cc96d30, 0x44f68560bbf3e205, ], - ]; -} - -#[rustfmt::skip] -impl Poseidon<12> for GoldilocksField { +impl Poseidon for GoldilocksField { // The MDS matrix we use is the circulant matrix with first row given by the vector // [ 2^x for x in MDS_MATRIX_EXPS] = [1, 1, 2, 1, 8, 32, 2, 256, 4096, 8, 65536, 1024] // @@ -442,8 +302,6 @@ mod tests { 0x37804ed8ca07fcd5, 0xe78ec2f213e28456, 0xecf67d2aacb4dbe3, 0xad14575187c496ca, ]), ]; - check_test_vectors::(test_vectors8); - #[rustfmt::skip] let test_vectors12: Vec<([u64; 12], [u64; 12])> = vec![ ([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], @@ -468,12 +326,11 @@ mod tests { 0x27eca78818ef9c27, 0xf08c93583c24dc47, 0x1c9e1552c07a9f73, 0x7659179192cfdc88, ]), ]; - check_test_vectors::(test_vectors12); + check_test_vectors::(test_vectors12); } #[test] fn consistency() { - check_consistency::(); - check_consistency::(); + check_consistency::(); } } diff --git a/src/iop/challenger.rs b/src/iop/challenger.rs index 87b0512d..80451fe5 100644 --- a/src/iop/challenger.rs +++ 
b/src/iop/challenger.rs @@ -1,21 +1,24 @@ use std::convert::TryInto; +use std::marker::PhantomData; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::RichField; use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget}; -use crate::hash::hashing::{permute, SPONGE_RATE, SPONGE_WIDTH}; +use crate::hash::hashing::{PlonkyPermutation, SPONGE_RATE, SPONGE_WIDTH}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; use crate::plonk::proof::{OpeningSet, OpeningSetTarget}; /// Observes prover messages, and generates challenges by hashing the transcript, a la Fiat-Shamir. #[derive(Clone)] -pub struct Challenger { +pub struct Challenger> { sponge_state: [F; SPONGE_WIDTH], input_buffer: Vec, output_buffer: Vec, + _phantom: PhantomData, } /// Observes prover messages, and generates verifier challenges based on the transcript. @@ -26,12 +29,13 @@ pub struct Challenger { /// design, but it can be viewed as a duplex sponge whose inputs are sometimes zero (when we perform /// multiple squeezes) and whose outputs are sometimes ignored (when we perform multiple /// absorptions). Thus the security properties of a duplex sponge still apply to our design. -impl Challenger { - pub fn new() -> Challenger { +impl> Challenger { + pub fn new() -> Challenger { Challenger { sponge_state: [F::ZERO; SPONGE_WIDTH], input_buffer: Vec::new(), output_buffer: Vec::new(), + _phantom: Default::default(), } } @@ -90,22 +94,26 @@ impl Challenger { } } - pub fn observe_hash(&mut self, hash: &HashOut) { - self.observe_elements(&hash.elements) + pub fn observe_hash>(&mut self, hash: OH::Hash) { + let felts: Vec = hash.into(); + self.observe_elements(&felts) } - pub fn observe_cap(&mut self, cap: &MerkleCap) { - for hash in &cap.0 { - self.observe_elements(&hash.elements) + pub fn observe_cap>(&mut self, cap: &MerkleCap) { + for &hash in &cap.0 { + self.observe_hash::(hash); } } - pub fn get_challenge(&mut self) -> F { - self.absorb_buffered_inputs(); + pub fn get_challenge, const D: usize>(&mut self) -> F { + self.absorb_buffered_inputs::(); if self.output_buffer.is_empty() { // Evaluate the permutation to produce `r` new outputs. 
- self.sponge_state = permute(self.sponge_state); + self.sponge_state = + <>::InnerHasher as AlgebraicHasher>::Permutation::permute( + self.sponge_state, + ); self.output_buffer = self.sponge_state[0..SPONGE_RATE].to_vec(); } @@ -114,39 +122,49 @@ impl Challenger { .expect("Output buffer should be non-empty") } - pub fn get_n_challenges(&mut self, n: usize) -> Vec { - (0..n).map(|_| self.get_challenge()).collect() + pub fn get_n_challenges, const D: usize>( + &mut self, + n: usize, + ) -> Vec { + (0..n).map(|_| self.get_challenge::()).collect() } - pub fn get_hash(&mut self) -> HashOut { + pub fn get_hash, const D: usize>(&mut self) -> HashOut { HashOut { elements: [ - self.get_challenge(), - self.get_challenge(), - self.get_challenge(), - self.get_challenge(), + self.get_challenge::(), + self.get_challenge::(), + self.get_challenge::(), + self.get_challenge::(), ], } } - pub fn get_extension_challenge(&mut self) -> F::Extension + pub fn get_extension_challenge, const D: usize>( + &mut self, + ) -> F::Extension where F: Extendable, { let mut arr = [F::ZERO; D]; - arr.copy_from_slice(&self.get_n_challenges(D)); + arr.copy_from_slice(&self.get_n_challenges::(D)); F::Extension::from_basefield_array(arr) } - pub fn get_n_extension_challenges(&mut self, n: usize) -> Vec + pub fn get_n_extension_challenges, const D: usize>( + &mut self, + n: usize, + ) -> Vec where F: Extendable, { - (0..n).map(|_| self.get_extension_challenge()).collect() + (0..n) + .map(|_| self.get_extension_challenge::()) + .collect() } /// Absorb any buffered inputs. After calling this, the input buffer will be empty. - fn absorb_buffered_inputs(&mut self) { + fn absorb_buffered_inputs, const D: usize>(&mut self) { if self.input_buffer.is_empty() { return; } @@ -160,7 +178,10 @@ impl Challenger { } // Apply the permutation. - self.sponge_state = permute(self.sponge_state); + self.sponge_state = + <>::InnerHasher as AlgebraicHasher>::Permutation::permute( + self.sponge_state, + ); } self.output_buffer = self.sponge_state[0..SPONGE_RATE].to_vec(); @@ -169,23 +190,21 @@ impl Challenger { } } -impl Default for Challenger { +impl> Default for Challenger { fn default() -> Self { Self::new() } } /// A recursive version of `Challenger`. 
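// Usage sketch for the hasher-generic `Challenger` (illustrative only; the
// turbofish parameters are reconstructed from the challenger tests further
// down in this hunk, and the tiny transcript here is arbitrary).
fn challenger_usage_sketch() {
    use crate::field::field_types::Field;
    use crate::iop::challenger::Challenger;
    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};

    const D: usize = 2;
    type C = PoseidonGoldilocksConfig;
    type F = <C as GenericConfig<D>>::F;

    // The sponge state is permuted with `C::InnerHasher`'s permutation (Poseidon here).
    let mut challenger = Challenger::<F, <C as GenericConfig<D>>::InnerHasher>::new();
    challenger.observe_element(F::rand());
    // Squeezes take the config explicitly, so the same transcript code works for any config.
    let _beta: F = challenger.get_challenge::<C, D>();
    let _gammas: Vec<F> = challenger.get_n_challenges::<C, D>(3);
    let _zeta = challenger.get_extension_challenge::<C, D>();
}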
-pub struct RecursiveChallenger { +pub struct RecursiveChallenger, H: AlgebraicHasher, const D: usize> { sponge_state: [Target; SPONGE_WIDTH], input_buffer: Vec, output_buffer: Vec, } -impl RecursiveChallenger { - pub(crate) fn new, const D: usize>( - builder: &mut CircuitBuilder, - ) -> Self { +impl, H: AlgebraicHasher, const D: usize> RecursiveChallenger { + pub(crate) fn new(builder: &mut CircuitBuilder) -> Self { let zero = builder.zero(); RecursiveChallenger { sponge_state: [zero; SPONGE_WIDTH], @@ -207,7 +226,7 @@ impl RecursiveChallenger { } } - pub fn observe_opening_set(&mut self, os: &OpeningSetTarget) { + pub fn observe_opening_set(&mut self, os: &OpeningSetTarget) { let OpeningSetTarget { constants, plonk_sigmas, @@ -240,25 +259,22 @@ impl RecursiveChallenger { } } - pub fn observe_extension_element(&mut self, element: ExtensionTarget) { + pub fn observe_extension_element(&mut self, element: ExtensionTarget) { self.observe_elements(&element.0); } - pub fn observe_extension_elements(&mut self, elements: &[ExtensionTarget]) { + pub fn observe_extension_elements(&mut self, elements: &[ExtensionTarget]) { for &element in elements { self.observe_extension_element(element); } } - pub(crate) fn get_challenge, const D: usize>( - &mut self, - builder: &mut CircuitBuilder, - ) -> Target { + pub(crate) fn get_challenge(&mut self, builder: &mut CircuitBuilder) -> Target { self.absorb_buffered_inputs(builder); if self.output_buffer.is_empty() { // Evaluate the permutation to produce `r` new outputs. - self.sponge_state = builder.permute(self.sponge_state); + self.sponge_state = builder.permute::(self.sponge_state); self.output_buffer = self.sponge_state[0..SPONGE_RATE].to_vec(); } @@ -267,7 +283,7 @@ impl RecursiveChallenger { .expect("Output buffer should be non-empty") } - pub(crate) fn get_n_challenges, const D: usize>( + pub(crate) fn get_n_challenges( &mut self, builder: &mut CircuitBuilder, n: usize, @@ -275,10 +291,7 @@ impl RecursiveChallenger { (0..n).map(|_| self.get_challenge(builder)).collect() } - pub fn get_hash, const D: usize>( - &mut self, - builder: &mut CircuitBuilder, - ) -> HashOutTarget { + pub fn get_hash(&mut self, builder: &mut CircuitBuilder) -> HashOutTarget { HashOutTarget { elements: [ self.get_challenge(builder), @@ -289,7 +302,7 @@ impl RecursiveChallenger { } } - pub fn get_extension_challenge, const D: usize>( + pub fn get_extension_challenge( &mut self, builder: &mut CircuitBuilder, ) -> ExtensionTarget { @@ -297,10 +310,7 @@ impl RecursiveChallenger { } /// Absorb any buffered inputs. After calling this, the input buffer will be empty. - fn absorb_buffered_inputs, const D: usize>( - &mut self, - builder: &mut CircuitBuilder, - ) { + fn absorb_buffered_inputs(&mut self, builder: &mut CircuitBuilder) { if self.input_buffer.is_empty() { return; } @@ -314,7 +324,7 @@ impl RecursiveChallenger { } // Apply the permutation. 
- self.sponge_state = builder.permute(self.sponge_state); + self.sponge_state = builder.permute::(self.sponge_state); } self.output_buffer = self.sponge_state[0..SPONGE_RATE].to_vec(); @@ -333,15 +343,18 @@ mod tests { use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; #[test] fn no_duplicate_challenges() { - type F = GoldilocksField; - let mut challenger = Challenger::new(); + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let mut challenger = Challenger::>::InnerHasher>::new(); let mut challenges = Vec::new(); for i in 1..10 { - challenges.extend(challenger.get_n_challenges(i)); + challenges.extend(challenger.get_n_challenges::(i)); challenger.observe_element(F::rand()); } @@ -356,7 +369,9 @@ mod tests { /// Tests for consistency between `Challenger` and `RecursiveChallenger`. #[test] fn test_consistency() { - type F = GoldilocksField; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; // These are mostly arbitrary, but we want to test some rounds with enough inputs/outputs to // trigger multiple absorptions/squeezes. @@ -369,16 +384,17 @@ mod tests { .map(|&n| F::rand_vec(n)) .collect(); - let mut challenger = Challenger::new(); + let mut challenger = Challenger::>::InnerHasher>::new(); let mut outputs_per_round: Vec> = Vec::new(); for (r, inputs) in inputs_per_round.iter().enumerate() { challenger.observe_elements(inputs); - outputs_per_round.push(challenger.get_n_challenges(num_outputs_per_round[r])); + outputs_per_round.push(challenger.get_n_challenges::(num_outputs_per_round[r])); } let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); - let mut recursive_challenger = RecursiveChallenger::new(&mut builder); + let mut builder = CircuitBuilder::::new(config); + let mut recursive_challenger = + RecursiveChallenger::>::InnerHasher, D>::new(&mut builder); let mut recursive_outputs_per_round: Vec> = Vec::new(); for (r, inputs) in inputs_per_round.iter().enumerate() { recursive_challenger.observe_elements(&builder.constants(inputs)); @@ -386,7 +402,7 @@ mod tests { recursive_challenger.get_n_challenges(&mut builder, num_outputs_per_round[r]), ); } - let circuit = builder.build(); + let circuit = builder.build::(); let inputs = PartialWitness::new(); let witness = generate_partial_witness(inputs, &circuit.prover_only, &circuit.common); let recursive_output_values_per_round: Vec> = recursive_outputs_per_round diff --git a/src/iop/generator.rs b/src/iop/generator.rs index eb2c95f7..e71ef8cd 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -3,19 +3,25 @@ use std::marker::PhantomData; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::field::field_types::{Field, RichField}; +use crate::field::field_types::Field; use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::iop::target::Target; use crate::iop::wire::Wire; use crate::iop::witness::{PartialWitness, PartitionWitness, Witness}; use crate::plonk::circuit_data::{CommonCircuitData, ProverOnlyCircuitData}; +use crate::plonk::config::GenericConfig; /// Given a `PartitionWitness` that has only inputs set, populates the rest of the witness using the /// given set of generators. 
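// Sketch of running witness generation on its own (no proof), following its
// use in the challenger consistency test above. `generate_partial_witness`
// and the `prover_only`/`common` fields are crate-internal, so this only
// makes sense from inside the crate; the circuit is illustrative.
fn witness_generation_sketch() {
    use crate::field::field_types::Field;
    use crate::iop::generator::generate_partial_witness;
    use crate::iop::witness::{PartialWitness, Witness};
    use crate::plonk::circuit_builder::CircuitBuilder;
    use crate::plonk::circuit_data::CircuitConfig;
    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};

    const D: usize = 2;
    type C = PoseidonGoldilocksConfig;
    type F = <C as GenericConfig<D>>::F;

    let config = CircuitConfig::standard_recursion_config();
    let mut builder = CircuitBuilder::<F, D>::new(config);
    let x = builder.add_virtual_target();
    let y = builder.mul(x, x);
    let circuit = builder.build::<C>();

    let mut inputs = PartialWitness::new();
    inputs.set_target(x, F::ONE);
    // Run every generator; the resulting witness holds values for all populated targets.
    let witness = generate_partial_witness(inputs, &circuit.prover_only, &circuit.common);
    let _y_value = witness.get_targets(&[y]);
}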
-pub(crate) fn generate_partial_witness<'a, F: RichField + Extendable, const D: usize>( +pub(crate) fn generate_partial_witness< + 'a, + F: Extendable, + C: GenericConfig, + const D: usize, +>( inputs: PartialWitness, - prover_data: &'a ProverOnlyCircuitData, - common_data: &'a CommonCircuitData, + prover_data: &'a ProverOnlyCircuitData, + common_data: &'a CommonCircuitData, ) -> PartitionWitness<'a, F> { let config = &common_data.config; let generators = &prover_data.generators; diff --git a/src/iop/witness.rs b/src/iop/witness.rs index 858bacd9..09e0a73e 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -3,12 +3,13 @@ use std::convert::TryInto; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::field::field_types::Field; +use crate::field::field_types::{Field, RichField}; use crate::hash::hash_types::HashOutTarget; use crate::hash::hash_types::{HashOut, MerkleCapTarget}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; +use crate::plonk::config::AlgebraicHasher; /// A witness holds information on the values of targets in a circuit. pub trait Witness { @@ -82,7 +83,13 @@ pub trait Witness { .for_each(|(&t, x)| self.set_target(t, x)); } - fn set_cap_target(&mut self, ct: &MerkleCapTarget, value: &MerkleCap) { + fn set_cap_target>( + &mut self, + ct: &MerkleCapTarget, + value: &MerkleCap, + ) where + F: RichField, + { for (ht, h) in ct.0.iter().zip(&value.0) { self.set_hash_target(*ht, *h); } diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 5dcde1e0..888ee616 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -22,7 +22,6 @@ use crate::gates::public_input::PublicInputGate; use crate::gates::random_access::RandomAccessGate; use crate::gates::switch::SwitchGate; use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget}; -use crate::hash::hashing::hash_n_to_hash; use crate::iop::generator::{ CopyGenerator, RandomValueGenerator, SimpleGenerator, WitnessGenerator, }; @@ -32,6 +31,7 @@ use crate::plonk::circuit_data::{ CircuitConfig, CircuitData, CommonCircuitData, ProverCircuitData, ProverOnlyCircuitData, VerifierCircuitData, VerifierOnlyCircuitData, }; +use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::copy_constraint::CopyConstraint; use crate::plonk::permutation_argument::Forest; use crate::plonk::plonk_common::PlonkPolynomials; @@ -655,7 +655,7 @@ impl, const D: usize> CircuitBuilder { } /// Builds a "full circuit", with both prover and verifier data. - pub fn build(mut self) -> CircuitData { + pub fn build>(mut self) -> CircuitData { let mut timing = TimingTree::new("preprocess", Level::Trace); let start = Instant::now(); @@ -665,7 +665,8 @@ impl, const D: usize> CircuitBuilder { // Hash the public inputs, and route them to a `PublicInputGate` which will enforce that // those hash wires match the claimed public inputs. 
- let public_inputs_hash = self.hash_n_to_hash(self.public_inputs.clone(), true); + let public_inputs_hash = + self.hash_n_to_hash::(self.public_inputs.clone(), true); let pi_gate = self.add_gate(PublicInputGate, vec![]); for (&hash_part, wire) in public_inputs_hash .elements @@ -784,7 +785,7 @@ impl, const D: usize> CircuitBuilder { constants_sigmas_cap.flatten(), vec![/* Add other circuit data here */], ]; - let circuit_digest = hash_n_to_hash(circuit_digest_parts.concat(), false); + let circuit_digest = C::Hasher::hash(circuit_digest_parts.concat(), false); let common = CommonCircuitData { config: self.config, @@ -809,7 +810,7 @@ impl, const D: usize> CircuitBuilder { } /// Builds a "prover circuit", with data needed to generate proofs but not verify them. - pub fn build_prover(self) -> ProverCircuitData { + pub fn build_prover>(self) -> ProverCircuitData { // TODO: Can skip parts of this. let CircuitData { prover_only, @@ -823,7 +824,7 @@ impl, const D: usize> CircuitBuilder { } /// Builds a "verifier circuit", with data needed to verify proofs but not generate them. - pub fn build_verifier(self) -> VerifierCircuitData { + pub fn build_verifier>(self) -> VerifierCircuitData { // TODO: Can skip parts of this. let CircuitData { verifier_only, diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 2f98d0a6..8419b7b9 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -5,16 +5,16 @@ use anyhow::Result; use crate::field::extension_field::Extendable; use crate::field::fft::FftRootTable; -use crate::field::field_types::{Field, RichField}; use crate::fri::commitment::PolynomialBatchCommitment; use crate::fri::reduction_strategies::FriReductionStrategy; use crate::fri::{FriConfig, FriParams}; use crate::gates::gate::PrefixedGate; -use crate::hash::hash_types::{HashOut, MerkleCapTarget}; +use crate::hash::hash_types::MerkleCapTarget; use crate::hash::merkle_tree::MerkleCap; use crate::iop::generator::WitnessGenerator; use crate::iop::target::Target; use crate::iop::witness::PartialWitness; +use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::proof::ProofWithPublicInputs; use crate::plonk::prover::prove; use crate::plonk::verifier::verify; @@ -77,14 +77,14 @@ impl CircuitConfig { } /// Circuit data required by the prover or the verifier. -pub struct CircuitData, const D: usize> { - pub(crate) prover_only: ProverOnlyCircuitData, - pub(crate) verifier_only: VerifierOnlyCircuitData, - pub(crate) common: CommonCircuitData, +pub struct CircuitData, C: GenericConfig, const D: usize> { + pub(crate) prover_only: ProverOnlyCircuitData, + pub(crate) verifier_only: VerifierOnlyCircuitData, + pub(crate) common: CommonCircuitData, } -impl, const D: usize> CircuitData { - pub fn prove(&self, inputs: PartialWitness) -> Result> { +impl, C: GenericConfig, const D: usize> CircuitData { + pub fn prove(&self, inputs: PartialWitness) -> Result> { prove( &self.prover_only, &self.common, @@ -93,7 +93,7 @@ impl, const D: usize> CircuitData { ) } - pub fn verify(&self, proof_with_pis: ProofWithPublicInputs) -> Result<()> { + pub fn verify(&self, proof_with_pis: ProofWithPublicInputs) -> Result<()> { verify(proof_with_pis, &self.verifier_only, &self.common) } } @@ -105,13 +105,13 @@ impl, const D: usize> CircuitData { /// structure as succinct as we can. Thus we include various precomputed data which isn't strictly /// required, like LDEs of preprocessed polynomials. 
If more succinctness was desired, we could /// construct a more minimal prover structure and convert back and forth. -pub struct ProverCircuitData, const D: usize> { - pub(crate) prover_only: ProverOnlyCircuitData, - pub(crate) common: CommonCircuitData, +pub struct ProverCircuitData, C: GenericConfig, const D: usize> { + pub(crate) prover_only: ProverOnlyCircuitData, + pub(crate) common: CommonCircuitData, } -impl, const D: usize> ProverCircuitData { - pub fn prove(&self, inputs: PartialWitness) -> Result> { +impl, C: GenericConfig, const D: usize> ProverCircuitData { + pub fn prove(&self, inputs: PartialWitness) -> Result> { prove( &self.prover_only, &self.common, @@ -123,25 +123,29 @@ impl, const D: usize> ProverCircuitData { /// Circuit data required by the prover. #[derive(Debug)] -pub struct VerifierCircuitData, const D: usize> { - pub(crate) verifier_only: VerifierOnlyCircuitData, - pub(crate) common: CommonCircuitData, +pub struct VerifierCircuitData, C: GenericConfig, const D: usize> { + pub(crate) verifier_only: VerifierOnlyCircuitData, + pub(crate) common: CommonCircuitData, } -impl, const D: usize> VerifierCircuitData { - pub fn verify(&self, proof_with_pis: ProofWithPublicInputs) -> Result<()> { +impl, C: GenericConfig, const D: usize> VerifierCircuitData { + pub fn verify(&self, proof_with_pis: ProofWithPublicInputs) -> Result<()> { verify(proof_with_pis, &self.verifier_only, &self.common) } } /// Circuit data required by the prover, but not the verifier. -pub(crate) struct ProverOnlyCircuitData, const D: usize> { +pub(crate) struct ProverOnlyCircuitData< + F: Extendable, + C: GenericConfig, + const D: usize, +> { pub generators: Vec>>, /// Generator indices (within the `Vec` above), indexed by the representative of each target /// they watch. pub generator_indices_by_watches: BTreeMap>, /// Commitments to the constants polynomials and sigma polynomials. - pub constants_sigmas_commitment: PolynomialBatchCommitment, + pub constants_sigmas_commitment: PolynomialBatchCommitment, /// The transpose of the list of sigma polynomials. pub sigmas: Vec>, /// Subgroup of order `degree`. @@ -159,14 +163,14 @@ pub(crate) struct ProverOnlyCircuitData, const D: u /// Circuit data required by the verifier, but not the prover. #[derive(Debug)] -pub(crate) struct VerifierOnlyCircuitData { +pub(crate) struct VerifierOnlyCircuitData, const D: usize> { /// A commitment to each constant polynomial and each permutation polynomial. - pub(crate) constants_sigmas_cap: MerkleCap, + pub(crate) constants_sigmas_cap: MerkleCap, } /// Circuit data required by both the prover and the verifier. #[derive(Debug)] -pub struct CommonCircuitData, const D: usize> { +pub struct CommonCircuitData, C: GenericConfig, const D: usize> { pub(crate) config: CircuitConfig, pub(crate) fri_params: FriParams, @@ -196,10 +200,10 @@ pub struct CommonCircuitData, const D: usize> { /// A digest of the "circuit" (i.e. the instance, minus public inputs), which can be used to /// seed Fiat-Shamir. 
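// End-to-end sketch of the two-step instantiation that `CircuitData` now
// supports: the builder is generic over the field only, and the
// `GenericConfig` (hence the hashers) is picked at `build` time. This mirrors
// the proof-compression test further down; the x*y circuit is illustrative.
fn build_prove_verify_sketch() -> anyhow::Result<()> {
    use crate::field::field_types::Field;
    use crate::iop::witness::PartialWitness;
    use crate::plonk::circuit_builder::CircuitBuilder;
    use crate::plonk::circuit_data::CircuitConfig;
    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};

    const D: usize = 2;
    type C = PoseidonGoldilocksConfig;
    type F = <C as GenericConfig<D>>::F;

    let config = CircuitConfig::standard_recursion_config();
    let mut builder = CircuitBuilder::<F, D>::new(config);

    // Constrain z = x * y over constants baked into the circuit.
    let (x, y) = (F::rand(), F::rand());
    let xt = builder.constant(x);
    let yt = builder.constant(y);
    let zt = builder.constant(x * y);
    let comp_zt = builder.mul(xt, yt);
    builder.connect(zt, comp_zt);

    // Only here is the config fixed (Poseidon for both Merkle trees and the transcript).
    let data = builder.build::<C>();
    let proof = data.prove(PartialWitness::new())?;
    data.verify(proof)
}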
- pub(crate) circuit_digest: HashOut, + pub(crate) circuit_digest: <>::Hasher as Hasher>::Hash, } -impl, const D: usize> CommonCircuitData { +impl, C: GenericConfig, const D: usize> CommonCircuitData { pub fn degree(&self) -> usize { 1 << self.degree_bits } diff --git a/src/plonk/config.rs b/src/plonk/config.rs new file mode 100644 index 00000000..1eec858d --- /dev/null +++ b/src/plonk/config.rs @@ -0,0 +1,267 @@ +use std::convert::TryInto; +use std::fmt::Debug; + +use keccak_hash::keccak; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use crate::field::extension_field::quadratic::QuadraticExtension; +use crate::field::extension_field::{Extendable, FieldExtension}; +use crate::field::field_types::RichField; +use crate::field::goldilocks_field::GoldilocksField; +use crate::gates::poseidon::PoseidonGate; +use crate::hash::hash_types::HashOut; +use crate::hash::hashing::{ + compress, hash_n_to_hash, PlonkyPermutation, PoseidonPermutation, SPONGE_WIDTH, +}; +use crate::iop::target::{BoolTarget, Target}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::util::ceil_div_usize; +use crate::util::serialization::Buffer; + +pub trait Hasher: Sized + Clone + Debug + Eq + PartialEq { + /// Size of `Hash` in bytes. + const HASH_SIZE: usize; + type Hash: From> + + Into> + + Into> + + Into + + Copy + + Clone + + Debug + + Eq + + PartialEq + + Send + + Sync + + Serialize + + DeserializeOwned; + + fn hash(input: Vec, pad: bool) -> Self::Hash; + fn two_to_one(left: Self::Hash, right: Self::Hash) -> Self::Hash; +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub struct PoseidonHash; +impl Hasher for PoseidonHash { + const HASH_SIZE: usize = 4 * 8; + type Hash = HashOut; + + fn hash(input: Vec, pad: bool) -> Self::Hash { + hash_n_to_hash::>::Permutation>(input, pad) + } + + fn two_to_one(left: Self::Hash, right: Self::Hash) -> Self::Hash { + compress::>::Permutation>(left, right) + } +} + +impl AlgebraicHasher for PoseidonHash { + type Permutation = PoseidonPermutation; + + fn permute_swapped( + inputs: [Target; SPONGE_WIDTH], + swap: BoolTarget, + builder: &mut CircuitBuilder, + ) -> [Target; SPONGE_WIDTH] + where + F: Extendable, + { + let gate_type = PoseidonGate::::new(); + let gate = builder.add_gate(gate_type, vec![]); + + let swap_wire = PoseidonGate::::WIRE_SWAP; + let swap_wire = Target::wire(gate, swap_wire); + builder.connect(swap.target, swap_wire); + + // Route input wires. + for i in 0..SPONGE_WIDTH { + let in_wire = PoseidonGate::::wire_input(i); + let in_wire = Target::wire(gate, in_wire); + builder.connect(inputs[i], in_wire); + } + + // Collect output wires. + (0..SPONGE_WIDTH) + .map(|i| Target::wire(gate, PoseidonGate::::wire_output(i))) + .collect::>() + .try_into() + .unwrap() + } +} + +// TODO: Remove width from `GMiMCGate` to make this work. 
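// Small sketch of the `Hasher` interface defined above: `hash` absorbs a
// field-element message (optionally padded), and `two_to_one` is the
// compression function used for interior Merkle-tree nodes. Inputs are
// illustrative.
fn hasher_sketch() {
    use crate::field::field_types::Field;
    use crate::field::goldilocks_field::GoldilocksField;
    use crate::plonk::config::{Hasher, PoseidonHash};

    type F = GoldilocksField;

    let leaf_a = PoseidonHash::hash(vec![F::ONE, F::ZERO, F::ONE], true);
    let leaf_b = PoseidonHash::hash(vec![F::ZERO; 5], true);
    // Parent of two children, as a Merkle tree built over `PoseidonHash` would compute it.
    let _parent = <PoseidonHash as Hasher<F>>::two_to_one(leaf_a, leaf_b);
}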
+// #[derive(Copy, Clone, Debug, Eq, PartialEq)] +// pub struct GMiMCHash; +// impl Hasher for GMiMCHash { +// const HASH_SIZE: usize = 4 * 8; +// type Hash = HashOut; +// +// fn hash(input: Vec, pad: bool) -> Self::Hash { +// hash_n_to_hash::>::Permutation>(input, pad) +// } +// +// fn two_to_one(left: Self::Hash, right: Self::Hash) -> Self::Hash { +// compress::>::Permutation>(left, right) +// } +// } +// +// impl AlgebraicHasher for GMiMCHash { +// type Permutation = GMiMCPermutation; +// +// fn permute_swapped( +// inputs: [Target; WIDTH], +// swap: BoolTarget, +// builder: &mut CircuitBuilder, +// ) -> [Target; WIDTH] +// where +// F: Extendable, +// { +// let gate_type = GMiMCGate::::new(); +// let gate = builder.add_gate(gate_type, vec![]); +// +// let swap_wire = GMiMCGate::::WIRE_SWAP; +// let swap_wire = Target::wire(gate, swap_wire); +// builder.connect(swap.target, swap_wire); +// +// // Route input wires. +// for i in 0..W { +// let in_wire = GMiMCGate::::wire_input(i); +// let in_wire = Target::wire(gate, in_wire); +// builder.connect(inputs[i], in_wire); +// } +// +// // Collect output wires. +// (0..W) +// .map(|i| Target::wire(gate, input: GMiMCGate::wire_output(i))) +// .collect::>() +// .try_into() +// .unwrap() +// } +// } + +pub trait AlgebraicHasher: Hasher> { + // TODO: Adding a `const WIDTH: usize` here yields a compiler error down the line. + // Maybe try again in a while. + type Permutation: PlonkyPermutation; + fn permute_swapped( + inputs: [Target; SPONGE_WIDTH], + swap: BoolTarget, + builder: &mut CircuitBuilder, + ) -> [Target; SPONGE_WIDTH] + where + F: Extendable; +} + +pub trait GenericConfig: + Debug + Clone + Sync + Sized + Send + Eq + PartialEq +{ + type F: RichField + Extendable; + type FE: FieldExtension; + type Hasher: Hasher; + type InnerHasher: AlgebraicHasher; +} + +pub trait AlgebraicConfig: + Debug + Clone + Sync + Sized + Send + Eq + PartialEq +{ + type F: RichField + Extendable; + type FE: FieldExtension; + type Hasher: AlgebraicHasher; + type InnerHasher: AlgebraicHasher; +} + +impl, const D: usize> GenericConfig for A { + type F = >::F; + type FE = >::FE; + type Hasher = >::Hasher; + type InnerHasher = >::InnerHasher; +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct PoseidonGoldilocksConfig; +impl AlgebraicConfig<2> for PoseidonGoldilocksConfig { + type F = GoldilocksField; + type FE = QuadraticExtension; + type Hasher = PoseidonHash; + type InnerHasher = PoseidonHash; +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub struct KeccakHash; +impl Hasher for KeccakHash { + const HASH_SIZE: usize = N; + type Hash = BytesHash; + + fn hash(input: Vec, _pad: bool) -> Self::Hash { + let mut buffer = Buffer::new(Vec::new()); + buffer.write_field_vec(&input).unwrap(); + let mut arr = [0; N]; + arr.copy_from_slice(&keccak(buffer.bytes()).0[..N]); + BytesHash(arr) + } + + fn two_to_one(left: Self::Hash, right: Self::Hash) -> Self::Hash { + let mut v = vec![0; N * 2]; + v[0..N].copy_from_slice(&left.0); + v[N..].copy_from_slice(&right.0); + let mut arr = [0; N]; + arr.copy_from_slice(&keccak(v).0[..N]); + BytesHash(arr) + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct KeccakGoldilocksConfig; +impl GenericConfig<2> for KeccakGoldilocksConfig { + type F = GoldilocksField; + type FE = QuadraticExtension; + type Hasher = KeccakHash<25>; + type InnerHasher = PoseidonHash; +} + +#[derive(Eq, PartialEq, Copy, Clone, Debug)] +pub struct BytesHash([u8; N]); +impl Serialize for BytesHash { + fn serialize(&self, serializer: S) -> 
Result + where + S: Serializer, + { + todo!() + } +} +impl<'de, const N: usize> Deserialize<'de> for BytesHash { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + todo!() + } +} + +impl From> for BytesHash { + fn from(v: Vec) -> Self { + Self(v.try_into().unwrap()) + } +} + +impl From> for Vec { + fn from(hash: BytesHash) -> Self { + hash.0.to_vec() + } +} + +impl From> for u64 { + fn from(hash: BytesHash) -> Self { + u64::from_le_bytes(hash.0[..8].try_into().unwrap()) + } +} + +impl From> for Vec { + fn from(hash: BytesHash) -> Self { + let n = hash.0.len(); + let mut v = hash.0.to_vec(); + v.resize(ceil_div_usize(n, 8) * 8, 0); + let mut buffer = Buffer::new(v); + buffer.read_field_vec(buffer.len() / 8).unwrap() + } +} diff --git a/src/plonk/get_challenges.rs b/src/plonk/get_challenges.rs index ab361c76..99436512 100644 --- a/src/plonk/get_challenges.rs +++ b/src/plonk/get_challenges.rs @@ -1,14 +1,12 @@ use std::collections::HashSet; use crate::field::extension_field::Extendable; -use crate::field::field_types::RichField; use crate::fri::proof::{CompressedFriProof, FriProof}; use crate::fri::verifier::{compute_evaluation, fri_combine_initial, PrecomputedReducedEvals}; -use crate::hash::hash_types::HashOut; -use crate::hash::hashing::hash_n_to_1; use crate::hash::merkle_tree::MerkleCap; use crate::iop::challenger::Challenger; use crate::plonk::circuit_data::CommonCircuitData; +use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::proof::{ CompressedProof, CompressedProofWithPublicInputs, FriInferredElements, OpeningSet, Proof, ProofChallenges, ProofWithPublicInputs, @@ -16,67 +14,68 @@ use crate::plonk::proof::{ use crate::polynomial::polynomial::PolynomialCoeffs; use crate::util::reverse_bits; -fn get_challenges, const D: usize>( - public_inputs_hash: HashOut, - wires_cap: &MerkleCap, - plonk_zs_partial_products_cap: &MerkleCap, - quotient_polys_cap: &MerkleCap, +fn get_challenges, C: GenericConfig, const D: usize>( + public_inputs_hash: <>::InnerHasher as Hasher>::Hash, + wires_cap: &MerkleCap, + plonk_zs_partial_products_cap: &MerkleCap, + quotient_polys_cap: &MerkleCap, openings: &OpeningSet, - commit_phase_merkle_caps: &[MerkleCap], + commit_phase_merkle_caps: &[MerkleCap], final_poly: &PolynomialCoeffs, pow_witness: F, - common_data: &CommonCircuitData, + common_data: &CommonCircuitData, ) -> anyhow::Result> { let config = &common_data.config; let num_challenges = config.num_challenges; let num_fri_queries = config.fri_config.num_query_rounds; let lde_size = common_data.lde_size(); - let mut challenger = Challenger::new(); + let mut challenger = Challenger::::new(); // Observe the instance. 
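// Config-selection sketch: a `GenericConfig` bundles the base field, its
// degree-D extension, the "outer" hasher used for Merkle caps and the
// Fiat-Shamir transcript, and the "inner" algebraic hasher used in-circuit.
// The comparison below is illustrative.
fn config_selection_sketch() {
    use crate::field::field_types::Field;
    use crate::plonk::config::{
        GenericConfig, Hasher, KeccakGoldilocksConfig, PoseidonGoldilocksConfig,
    };

    const D: usize = 2;
    // All-Poseidon: both hashers are algebraic, so this proof stays cheap to verify recursively.
    type C = PoseidonGoldilocksConfig;
    // Keccak outer hasher, Poseidon inner hasher: better suited as the final,
    // externally verified layer of a recursion chain.
    type KC = KeccakGoldilocksConfig;
    type F = <C as GenericConfig<D>>::F;

    // Downstream code only ever goes through the associated types.
    let _poseidon_digest = <C as GenericConfig<D>>::Hasher::hash(vec![F::ONE], false);
    let _keccak_digest = <KC as GenericConfig<D>>::Hasher::hash(vec![F::ONE], false);
}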
- challenger.observe_hash(&common_data.circuit_digest); - challenger.observe_hash(&public_inputs_hash); + challenger.observe_hash::(common_data.circuit_digest); + challenger.observe_hash::(public_inputs_hash); challenger.observe_cap(wires_cap); - let plonk_betas = challenger.get_n_challenges(num_challenges); - let plonk_gammas = challenger.get_n_challenges(num_challenges); + let plonk_betas = challenger.get_n_challenges::(num_challenges); + let plonk_gammas = challenger.get_n_challenges::(num_challenges); challenger.observe_cap(plonk_zs_partial_products_cap); - let plonk_alphas = challenger.get_n_challenges(num_challenges); + let plonk_alphas = challenger.get_n_challenges::(num_challenges); challenger.observe_cap(quotient_polys_cap); - let plonk_zeta = challenger.get_extension_challenge(); + let plonk_zeta = challenger.get_extension_challenge::(); challenger.observe_opening_set(openings); // Scaling factor to combine polynomials. - let fri_alpha = challenger.get_extension_challenge(); + let fri_alpha = challenger.get_extension_challenge::(); // Recover the random betas used in the FRI reductions. let fri_betas = commit_phase_merkle_caps .iter() .map(|cap| { challenger.observe_cap(cap); - challenger.get_extension_challenge() + challenger.get_extension_challenge::() }) .collect(); challenger.observe_extension_elements(&final_poly.coeffs); - let fri_pow_response = hash_n_to_1( + let fri_pow_response = C::InnerHasher::hash( challenger - .get_hash() + .get_hash::() .elements .iter() .copied() .chain(Some(pow_witness)) .collect(), false, - ); + ) + .elements[0]; let fri_query_indices = (0..num_fri_queries) - .map(|_| challenger.get_challenge().to_canonical_u64() as usize % lde_size) + .map(|_| challenger.get_challenge::().to_canonical_u64() as usize % lde_size) .collect(); Ok(ProofChallenges { @@ -91,10 +90,10 @@ fn get_challenges, const D: usize>( }) } -impl, const D: usize> ProofWithPublicInputs { +impl, C: GenericConfig, const D: usize> ProofWithPublicInputs { pub(crate) fn fri_query_indices( &self, - common_data: &CommonCircuitData, + common_data: &CommonCircuitData, ) -> anyhow::Result> { Ok(self.get_challenges(common_data)?.fri_query_indices) } @@ -102,7 +101,7 @@ impl, const D: usize> ProofWithPublicInputs { /// Computes all Fiat-Shamir challenges used in the Plonk proof. pub(crate) fn get_challenges( &self, - common_data: &CommonCircuitData, + common_data: &CommonCircuitData, ) -> anyhow::Result> { let Proof { wires_cap, @@ -132,11 +131,13 @@ impl, const D: usize> ProofWithPublicInputs { } } -impl, const D: usize> CompressedProofWithPublicInputs { +impl, C: GenericConfig, const D: usize> + CompressedProofWithPublicInputs +{ /// Computes all Fiat-Shamir challenges used in the Plonk proof. 
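// Transcript-order sketch: `get_challenges` above replays exactly the
// observe/squeeze sequence the prover performed, so changing any observed
// value (or its position) changes every later challenge. Minimal illustration
// with two transcripts that diverge on one element (values are arbitrary).
fn transcript_divergence_sketch() {
    use crate::field::field_types::Field;
    use crate::iop::challenger::Challenger;
    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};

    const D: usize = 2;
    type C = PoseidonGoldilocksConfig;
    type F = <C as GenericConfig<D>>::F;
    type H = <C as GenericConfig<D>>::InnerHasher;

    let mut prover_view = Challenger::<F, H>::new();
    let mut verifier_view = Challenger::<F, H>::new();
    prover_view.observe_element(F::ONE);
    verifier_view.observe_element(F::ZERO);
    // Same squeeze position, different transcripts: the two challenges disagree
    // (except with negligible probability), so the proof would be rejected.
    let _c1 = prover_view.get_challenge::<C, D>();
    let _c2 = verifier_view.get_challenge::<C, D>();
}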
pub(crate) fn get_challenges( &self, - common_data: &CommonCircuitData, + common_data: &CommonCircuitData, ) -> anyhow::Result> { let CompressedProof { wires_cap, @@ -169,7 +170,7 @@ impl, const D: usize> CompressedProofWithPublicInpu pub(crate) fn get_inferred_elements( &self, challenges: &ProofChallenges, - common_data: &CommonCircuitData, + common_data: &CommonCircuitData, ) -> FriInferredElements { let ProofChallenges { plonk_zeta, diff --git a/src/plonk/mod.rs b/src/plonk/mod.rs index 3b8fdd6b..b2d1ed03 100644 --- a/src/plonk/mod.rs +++ b/src/plonk/mod.rs @@ -1,5 +1,6 @@ pub mod circuit_builder; pub mod circuit_data; +pub mod config; pub(crate) mod copy_constraint; mod get_challenges; pub(crate) mod permutation_argument; diff --git a/src/plonk/proof.rs b/src/plonk/proof.rs index ce1207cd..0dff5a06 100644 --- a/src/plonk/proof.rs +++ b/src/plonk/proof.rs @@ -6,27 +6,27 @@ use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; use crate::fri::commitment::PolynomialBatchCommitment; use crate::fri::proof::{CompressedFriProof, FriProof, FriProofTarget}; -use crate::hash::hash_types::{HashOut, MerkleCapTarget}; -use crate::hash::hashing::hash_n_to_hash; +use crate::hash::hash_types::MerkleCapTarget; use crate::hash::merkle_tree::MerkleCap; use crate::iop::target::Target; use crate::plonk::circuit_data::{CommonCircuitData, VerifierOnlyCircuitData}; +use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::verifier::verify_with_challenges; use crate::util::serialization::Buffer; #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] #[serde(bound = "")] -pub struct Proof, const D: usize> { +pub struct Proof, C: GenericConfig, const D: usize> { /// Merkle cap of LDEs of wire values. - pub wires_cap: MerkleCap, + pub wires_cap: MerkleCap, /// Merkle cap of LDEs of Z, in the context of Plonk's permutation argument. - pub plonk_zs_partial_products_cap: MerkleCap, + pub plonk_zs_partial_products_cap: MerkleCap, /// Merkle cap of LDEs of the quotient polynomial components. - pub quotient_polys_cap: MerkleCap, + pub quotient_polys_cap: MerkleCap, /// Purported values of each polynomial at the challenge point. pub openings: OpeningSet, /// A batch FRI argument for all openings. - pub opening_proof: FriProof, + pub opening_proof: FriProof, } pub struct ProofTarget { @@ -37,13 +37,13 @@ pub struct ProofTarget { pub opening_proof: FriProofTarget, } -impl, const D: usize> Proof { +impl, C: GenericConfig, const D: usize> Proof { /// Compress the proof. 
pub fn compress( self, indices: &[usize], - common_data: &CommonCircuitData, - ) -> CompressedProof { + common_data: &CommonCircuitData, + ) -> CompressedProof { let Proof { wires_cap, plonk_zs_partial_products_cap, @@ -64,16 +64,16 @@ impl, const D: usize> Proof { #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] #[serde(bound = "")] -pub struct ProofWithPublicInputs, const D: usize> { - pub proof: Proof, +pub struct ProofWithPublicInputs, C: GenericConfig, const D: usize> { + pub proof: Proof, pub public_inputs: Vec, } -impl, const D: usize> ProofWithPublicInputs { +impl, C: GenericConfig, const D: usize> ProofWithPublicInputs { pub fn compress( self, - common_data: &CommonCircuitData, - ) -> anyhow::Result> { + common_data: &CommonCircuitData, + ) -> anyhow::Result> { let indices = self.fri_query_indices(common_data)?; let compressed_proof = self.proof.compress(&indices, common_data); Ok(CompressedProofWithPublicInputs { @@ -82,8 +82,10 @@ impl, const D: usize> ProofWithPublicInputs { }) } - pub(crate) fn get_public_inputs_hash(&self) -> HashOut { - hash_n_to_hash(self.public_inputs.clone(), true) + pub(crate) fn get_public_inputs_hash( + &self, + ) -> <>::InnerHasher as Hasher>::Hash { + C::InnerHasher::hash(self.public_inputs.clone(), true) } pub fn to_bytes(&self) -> anyhow::Result> { @@ -94,7 +96,7 @@ impl, const D: usize> ProofWithPublicInputs { pub fn from_bytes( bytes: Vec, - common_data: &CommonCircuitData, + common_data: &CommonCircuitData, ) -> anyhow::Result { let mut buffer = Buffer::new(bytes); let proof = buffer.read_proof_with_public_inputs(common_data)?; @@ -104,27 +106,27 @@ impl, const D: usize> ProofWithPublicInputs { #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] #[serde(bound = "")] -pub struct CompressedProof, const D: usize> { +pub struct CompressedProof, C: GenericConfig, const D: usize> { /// Merkle cap of LDEs of wire values. - pub wires_cap: MerkleCap, + pub wires_cap: MerkleCap, /// Merkle cap of LDEs of Z, in the context of Plonk's permutation argument. - pub plonk_zs_partial_products_cap: MerkleCap, + pub plonk_zs_partial_products_cap: MerkleCap, /// Merkle cap of LDEs of the quotient polynomial components. - pub quotient_polys_cap: MerkleCap, + pub quotient_polys_cap: MerkleCap, /// Purported values of each polynomial at the challenge point. pub openings: OpeningSet, /// A compressed batch FRI argument for all openings. - pub opening_proof: CompressedFriProof, + pub opening_proof: CompressedFriProof, } -impl, const D: usize> CompressedProof { +impl, C: GenericConfig, const D: usize> CompressedProof { /// Decompress the proof. 
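// Compression round-trip sketch for the config-generic proof types, following
// the `compress`/`decompress`/`to_bytes`/`from_bytes` signatures in this hunk;
// `data` and `proof` are assumed to come from a circuit built as in the
// earlier examples (the `common` field is crate-visible).
use crate::field::extension_field::Extendable;
use crate::plonk::circuit_data::CircuitData;
use crate::plonk::config::GenericConfig;
use crate::plonk::proof::{CompressedProofWithPublicInputs, ProofWithPublicInputs};

fn compression_roundtrip_sketch<F, C, const D: usize>(
    data: &CircuitData<F, C, D>,
    proof: ProofWithPublicInputs<F, C, D>,
) -> anyhow::Result<()>
where
    F: Extendable<D>,
    C: GenericConfig<D, F = F>,
{
    // Compress (drops Merkle-path redundancy determined by the FRI query indices)...
    let compressed = proof.compress(&data.common)?;
    // ...serialize and deserialize...
    let bytes = compressed.to_bytes()?;
    let restored = CompressedProofWithPublicInputs::from_bytes(bytes, &data.common)?;
    // ...then decompress back into a regular proof, which still verifies.
    data.verify(restored.decompress(&data.common)?)
}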
pub(crate) fn decompress( self, challenges: &ProofChallenges, fri_inferred_elements: FriInferredElements, - common_data: &CommonCircuitData, - ) -> Proof { + common_data: &CommonCircuitData, + ) -> Proof { let CompressedProof { wires_cap, plonk_zs_partial_products_cap, @@ -149,16 +151,22 @@ impl, const D: usize> CompressedProof { #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] #[serde(bound = "")] -pub struct CompressedProofWithPublicInputs, const D: usize> { - pub proof: CompressedProof, +pub struct CompressedProofWithPublicInputs< + F: Extendable, + C: GenericConfig, + const D: usize, +> { + pub proof: CompressedProof, pub public_inputs: Vec, } -impl, const D: usize> CompressedProofWithPublicInputs { +impl, C: GenericConfig, const D: usize> + CompressedProofWithPublicInputs +{ pub fn decompress( self, - common_data: &CommonCircuitData, - ) -> anyhow::Result> { + common_data: &CommonCircuitData, + ) -> anyhow::Result> { let challenges = self.get_challenges(common_data)?; let fri_inferred_elements = self.get_inferred_elements(&challenges, common_data); let compressed_proof = @@ -172,8 +180,8 @@ impl, const D: usize> CompressedProofWithPublicInpu pub(crate) fn verify( self, - verifier_data: &VerifierOnlyCircuitData, - common_data: &CommonCircuitData, + verifier_data: &VerifierOnlyCircuitData, + common_data: &CommonCircuitData, ) -> anyhow::Result<()> { let challenges = self.get_challenges(common_data)?; let fri_inferred_elements = self.get_inferred_elements(&challenges, common_data); @@ -191,8 +199,10 @@ impl, const D: usize> CompressedProofWithPublicInpu ) } - pub(crate) fn get_public_inputs_hash(&self) -> HashOut { - hash_n_to_hash(self.public_inputs.clone(), true) + pub(crate) fn get_public_inputs_hash( + &self, + ) -> <>::InnerHasher as Hasher>::Hash { + C::InnerHasher::hash(self.public_inputs.clone(), true) } pub fn to_bytes(&self) -> anyhow::Result> { @@ -203,7 +213,7 @@ impl, const D: usize> CompressedProofWithPublicInpu pub fn from_bytes( bytes: Vec, - common_data: &CommonCircuitData, + common_data: &CommonCircuitData, ) -> anyhow::Result { let mut buffer = Buffer::new(bytes); let proof = buffer.read_compressed_proof_with_public_inputs(common_data)?; @@ -258,17 +268,17 @@ pub struct OpeningSet, const D: usize> { pub quotient_polys: Vec, } -impl, const D: usize> OpeningSet { - pub fn new( +impl, const D: usize> OpeningSet { + pub fn new>( z: F::Extension, g: F::Extension, - constants_sigmas_commitment: &PolynomialBatchCommitment, - wires_commitment: &PolynomialBatchCommitment, - zs_partial_products_commitment: &PolynomialBatchCommitment, - quotient_polys_commitment: &PolynomialBatchCommitment, - common_data: &CommonCircuitData, + constants_sigmas_commitment: &PolynomialBatchCommitment, + wires_commitment: &PolynomialBatchCommitment, + zs_partial_products_commitment: &PolynomialBatchCommitment, + quotient_polys_commitment: &PolynomialBatchCommitment, + common_data: &CommonCircuitData, ) -> Self { - let eval_commitment = |z: F::Extension, c: &PolynomialBatchCommitment| { + let eval_commitment = |z: F::Extension, c: &PolynomialBatchCommitment| { c.polynomials .par_iter() .map(|p| p.to_extension().eval(z)) @@ -313,12 +323,14 @@ mod tests { use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::verifier::verify; #[test] fn test_proof_compression() -> Result<()> { - type F = GoldilocksField; - const D: 
usize = 4; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; let mut config = CircuitConfig::standard_recursion_config(); config.fri_config.reduction_strategy = FriReductionStrategy::Fixed(vec![1, 1]); @@ -336,7 +348,7 @@ mod tests { let zt = builder.constant(z); let comp_zt = builder.mul(xt, yt); builder.connect(zt, comp_zt); - let data = builder.build(); + let data = builder.build::(); let proof = data.prove(pw)?; verify(proof.clone(), &data.verifier_only, &data.common)?; diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 031354fa..69f89dea 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -2,14 +2,12 @@ use anyhow::Result; use rayon::prelude::*; use crate::field::extension_field::Extendable; -use crate::field::field_types::RichField; use crate::fri::commitment::PolynomialBatchCommitment; -use crate::hash::hash_types::HashOut; -use crate::hash::hashing::hash_n_to_hash; use crate::iop::challenger::Challenger; use crate::iop::generator::generate_partial_witness; use crate::iop::witness::{MatrixWitness, PartialWitness, Witness}; use crate::plonk::circuit_data::{CommonCircuitData, ProverOnlyCircuitData}; +use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::plonk_common::PlonkPolynomials; use crate::plonk::plonk_common::ZeroPolyOnCoset; use crate::plonk::proof::{Proof, ProofWithPublicInputs}; @@ -21,12 +19,12 @@ use crate::util::partial_products::partial_products; use crate::util::timing::TimingTree; use crate::util::{log2_ceil, transpose}; -pub(crate) fn prove, const D: usize>( - prover_data: &ProverOnlyCircuitData, - common_data: &CommonCircuitData, +pub(crate) fn prove, C: GenericConfig, const D: usize>( + prover_data: &ProverOnlyCircuitData, + common_data: &CommonCircuitData, inputs: PartialWitness, timing: &mut TimingTree, -) -> Result> { +) -> Result> { let config = &common_data.config; let num_challenges = config.num_challenges; let quotient_degree = common_data.quotient_degree(); @@ -39,7 +37,7 @@ pub(crate) fn prove, const D: usize>( ); let public_inputs = partition_witness.get_targets(&prover_data.public_inputs); - let public_inputs_hash = hash_n_to_hash(public_inputs.clone(), true); + let public_inputs_hash = C::InnerHasher::hash(public_inputs.clone(), true); if cfg!(debug_assertions) { // Display the marked targets for debugging purposes. @@ -80,12 +78,12 @@ pub(crate) fn prove, const D: usize>( let mut challenger = Challenger::new(); // Observe the instance. 
- challenger.observe_hash(&common_data.circuit_digest); - challenger.observe_hash(&public_inputs_hash); + challenger.observe_hash::(common_data.circuit_digest); + challenger.observe_hash::(public_inputs_hash); challenger.observe_cap(&wires_commitment.merkle_tree.cap); - let betas = challenger.get_n_challenges(num_challenges); - let gammas = challenger.get_n_challenges(num_challenges); + let betas = challenger.get_n_challenges::(num_challenges); + let gammas = challenger.get_n_challenges::(num_challenges); assert!( common_data.quotient_degree_factor < common_data.config.num_routed_wires, @@ -125,7 +123,7 @@ pub(crate) fn prove, const D: usize>( challenger.observe_cap(&zs_partial_products_commitment.merkle_tree.cap); - let alphas = challenger.get_n_challenges(num_challenges); + let alphas = challenger.get_n_challenges::(num_challenges); let quotient_polys = timed!( timing, @@ -175,7 +173,7 @@ pub(crate) fn prove, const D: usize>( challenger.observe_cap("ient_polys_commitment.merkle_tree.cap); - let zeta = challenger.get_extension_challenge(); + let zeta = challenger.get_extension_challenge::(); let (opening_proof, openings) = timed!( timing, @@ -208,12 +206,16 @@ pub(crate) fn prove, const D: usize>( } /// Compute the partial products used in the `Z` polynomials. -fn all_wires_permutation_partial_products, const D: usize>( +fn all_wires_permutation_partial_products< + F: Extendable, + C: GenericConfig, + const D: usize, +>( witness: &MatrixWitness, betas: &[F], gammas: &[F], - prover_data: &ProverOnlyCircuitData, - common_data: &CommonCircuitData, + prover_data: &ProverOnlyCircuitData, + common_data: &CommonCircuitData, ) -> Vec>> { (0..common_data.config.num_challenges) .map(|i| { @@ -231,12 +233,16 @@ fn all_wires_permutation_partial_products, const D: /// Compute the partial products used in the `Z` polynomial. /// Returns the polynomials interpolating `partial_products(f / g)` /// where `f, g` are the products in the definition of `Z`: `Z(g^i) = f / g`. -fn wires_permutation_partial_products, const D: usize>( +fn wires_permutation_partial_products< + F: Extendable, + C: GenericConfig, + const D: usize, +>( witness: &MatrixWitness, beta: F, gamma: F, - prover_data: &ProverOnlyCircuitData, - common_data: &CommonCircuitData, + prover_data: &ProverOnlyCircuitData, + common_data: &CommonCircuitData, ) -> Vec> { let degree = common_data.quotient_degree_factor; let subgroup = &prover_data.subgroup; @@ -285,9 +291,9 @@ fn wires_permutation_partial_products, const D: usi .collect() } -fn compute_zs, const D: usize>( +fn compute_zs, C: GenericConfig, const D: usize>( partial_products: &[Vec>], - common_data: &CommonCircuitData, + common_data: &CommonCircuitData, ) -> Vec> { (0..common_data.config.num_challenges) .map(|i| compute_z(&partial_products[i], common_data)) @@ -295,9 +301,9 @@ fn compute_zs, const D: usize>( } /// Compute the `Z` polynomial by reusing the computations done in `wires_permutation_partial_products`. 
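// Conceptual sketch of the recurrence that `compute_z` below evaluates over
// the subgroup: Z(g^0) = 1 and Z(g^{i+1}) = Z(g^i) * (f/g)(g^i), where f and g
// are the permutation-argument products mentioned above. The quotient closure
// is a hypothetical stand-in, not the real partial-product computation.
fn z_recurrence_sketch<F: crate::field::field_types::Field>(
    degree: usize,
    f_over_g_at: impl Fn(usize) -> F,
) -> Vec<F> {
    // Point evaluations of Z over the subgroup, in order g^0, g^1, ...
    let mut z_points = vec![F::ONE];
    for i in 1..degree {
        let prev = *z_points.last().unwrap();
        // Fold in the quotient at the previous point.
        z_points.push(prev * f_over_g_at(i - 1));
    }
    z_points
}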
-fn compute_z, const D: usize>( +fn compute_z, C: GenericConfig, const D: usize>( partial_products: &[PolynomialValues], - common_data: &CommonCircuitData, + common_data: &CommonCircuitData, ) -> PolynomialValues { let mut plonk_z_points = vec![F::ONE]; for i in 1..common_data.degree() { @@ -310,12 +316,12 @@ fn compute_z, const D: usize>( const BATCH_SIZE: usize = 32; -fn compute_quotient_polys<'a, F: RichField + Extendable, const D: usize>( - common_data: &CommonCircuitData, - prover_data: &'a ProverOnlyCircuitData, - public_inputs_hash: &HashOut, - wires_commitment: &'a PolynomialBatchCommitment, - zs_partial_products_commitment: &'a PolynomialBatchCommitment, +fn compute_quotient_polys<'a, F: Extendable, C: GenericConfig, const D: usize>( + common_data: &CommonCircuitData, + prover_data: &'a ProverOnlyCircuitData, + public_inputs_hash: &<>::InnerHasher as Hasher>::Hash, + wires_commitment: &'a PolynomialBatchCommitment, + zs_partial_products_commitment: &'a PolynomialBatchCommitment, betas: &[F], gammas: &[F], alphas: &[F], @@ -339,7 +345,7 @@ fn compute_quotient_polys<'a, F: RichField + Extendable, const D: usize>( let lde_size = points.len(); // Retrieve the LDE values at index `i`. - let get_at_index = |comm: &'a PolynomialBatchCommitment, i: usize| -> &'a [F] { + let get_at_index = |comm: &'a PolynomialBatchCommitment, i: usize| -> &'a [F] { comm.get_lde_values(i * step) }; diff --git a/src/plonk/recursive_verifier.rs b/src/plonk/recursive_verifier.rs index ad049275..ab43d256 100644 --- a/src/plonk/recursive_verifier.rs +++ b/src/plonk/recursive_verifier.rs @@ -1,23 +1,23 @@ use crate::field::extension_field::Extendable; -use crate::field::field_types::RichField; use crate::hash::hash_types::HashOutTarget; use crate::iop::challenger::RecursiveChallenger; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData, VerifierCircuitTarget}; +use crate::plonk::config::AlgebraicConfig; use crate::plonk::proof::ProofWithPublicInputsTarget; use crate::plonk::vanishing_poly::eval_vanishing_poly_recursively; use crate::plonk::vars::EvaluationTargets; use crate::util::reducing::ReducingFactorTarget; use crate::with_context; -impl, const D: usize> CircuitBuilder { +impl, const D: usize> CircuitBuilder { /// Recursively verifies an inner proof. 
- pub fn add_recursive_verifier( + pub fn add_recursive_verifier>( &mut self, proof_with_pis: ProofWithPublicInputsTarget, inner_config: &CircuitConfig, inner_verifier_data: &VerifierCircuitTarget, - inner_common_data: &CommonCircuitData, + inner_common_data: &CommonCircuitData, ) { let ProofWithPublicInputsTarget { proof, @@ -27,7 +27,7 @@ impl, const D: usize> CircuitBuilder { let num_challenges = inner_config.num_challenges; - let public_inputs_hash = &self.hash_n_to_hash(public_inputs, true); + let public_inputs_hash = &self.hash_n_to_hash::(public_inputs, true); let mut challenger = RecursiveChallenger::new(self); @@ -127,7 +127,6 @@ mod tests { use log::{info, Level}; use super::*; - use crate::field::goldilocks_field::GoldilocksField; use crate::fri::proof::{ FriInitialTreeProofTarget, FriProofTarget, FriQueryRoundTarget, FriQueryStepTarget, }; @@ -137,6 +136,7 @@ mod tests { use crate::hash::merkle_proofs::MerkleProofTarget; use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_data::VerifierOnlyCircuitData; + use crate::plonk::config::{GenericConfig, KeccakGoldilocksConfig, PoseidonGoldilocksConfig}; use crate::plonk::proof::{ CompressedProofWithPublicInputs, OpeningSetTarget, Proof, ProofTarget, ProofWithPublicInputs, @@ -146,8 +146,8 @@ mod tests { use crate::util::timing::TimingTree; // Construct a `FriQueryRoundTarget` with the same dimensions as the ones in `proof`. - fn get_fri_query_round, const D: usize>( - proof: &Proof, + fn get_fri_query_round, C: GenericConfig, const D: usize>( + proof: &Proof, builder: &mut CircuitBuilder, ) -> FriQueryRoundTarget { let mut query_round = FriQueryRoundTarget { @@ -179,8 +179,8 @@ mod tests { } // Construct a `ProofTarget` with the same dimensions as `proof`. - fn proof_to_proof_target, const D: usize>( - proof_with_pis: &ProofWithPublicInputs, + fn proof_to_proof_target, C: GenericConfig, const D: usize>( + proof_with_pis: &ProofWithPublicInputs, builder: &mut CircuitBuilder, ) -> ProofWithPublicInputsTarget { let ProofWithPublicInputs { @@ -240,8 +240,8 @@ mod tests { } // Set the targets in a `ProofTarget` to their corresponding values in a `Proof`. 
- fn set_proof_target, const D: usize>( - proof: &ProofWithPublicInputs, + fn set_proof_target, C: AlgebraicConfig, const D: usize>( + proof: &ProofWithPublicInputs, pt: &ProofWithPublicInputsTarget, pw: &mut PartialWitness, ) { @@ -364,12 +364,14 @@ mod tests { #[ignore] fn test_recursive_verifier() -> Result<()> { init_logger(); - type F = GoldilocksField; const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; let config = CircuitConfig::standard_recursion_config(); - let (proof, vd, cd) = dummy_proof::(&config, 8_000)?; - let (proof, _vd, cd) = recursive_proof(proof, vd, cd, &config, &config, true, true)?; + let (proof, vd, cd) = dummy_proof::(&config, 8_000)?; + let (proof, _vd, cd) = + recursive_proof::(proof, vd, cd, &config, &config, true, true)?; test_serialization(&proof, &cd)?; Ok(()) @@ -379,14 +381,18 @@ mod tests { #[ignore] fn test_recursive_recursive_verifier() -> Result<()> { init_logger(); - type F = GoldilocksField; const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type KC = KeccakGoldilocksConfig; + type F = >::F; let config = CircuitConfig::standard_recursion_config(); - let (proof, vd, cd) = dummy_proof::(&config, 8_000)?; - let (proof, vd, cd) = recursive_proof(proof, vd, cd, &config, &config, false, false)?; - let (proof, _vd, cd) = recursive_proof(proof, vd, cd, &config, &config, true, true)?; + let (proof, vd, cd) = dummy_proof::(&config, 8_000)?; + let (proof, vd, cd) = + recursive_proof::(proof, vd, cd, &config, &config, false, false)?; + let (proof, _vd, cd) = + recursive_proof::(proof, vd, cd, &config, &config, true, true)?; test_serialization(&proof, &cd)?; @@ -399,13 +405,15 @@ mod tests { #[ignore] fn test_size_optimized_recursion() -> Result<()> { init_logger(); - type F = GoldilocksField; const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type KC = KeccakGoldilocksConfig; + type F = >::F; let standard_config = CircuitConfig::standard_recursion_config(); // A dummy proof with degree 2^13. - let (proof, vd, cd) = dummy_proof::(&standard_config, 8_000)?; + let (proof, vd, cd) = dummy_proof::(&standard_config, 8_000)?; assert_eq!(cd.degree_bits, 13); // A standard recursive proof with degree 2^13. @@ -431,7 +439,7 @@ mod tests { }, ..standard_config }; - let (proof, vd, cd) = recursive_proof( + let (proof, vd, cd) = recursive_proof::( proof, vd, cd, @@ -454,7 +462,7 @@ mod tests { }, ..high_rate_config.clone() }; - let (proof, vd, cd) = recursive_proof( + let (proof, vd, cd) = recursive_proof::( proof, vd, cd, @@ -475,7 +483,7 @@ mod tests { }, ..higher_rate_more_routing_config }; - let (proof, _vd, cd) = recursive_proof( + let (proof, _vd, cd) = recursive_proof::( proof, vd, cd, @@ -492,13 +500,13 @@ mod tests { } /// Creates a dummy proof which should have roughly `num_dummy_gates` gates. 
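// Composition sketch using the test helpers defined below (`dummy_proof`,
// `recursive_proof`). The turbofish order <F, C, InnerC, D> is reconstructed
// from `recursive_proof`'s parameter list: `InnerC` (the config of the proof
// being verified) must be algebraic, while the outermost layer can switch to
// the Keccak config for cheaper verification outside a circuit.
fn recursion_chain_sketch() -> anyhow::Result<()> {
    use crate::plonk::circuit_data::CircuitConfig;
    use crate::plonk::config::{GenericConfig, KeccakGoldilocksConfig, PoseidonGoldilocksConfig};

    const D: usize = 2;
    type C = PoseidonGoldilocksConfig;
    type KC = KeccakGoldilocksConfig;
    type F = <C as GenericConfig<D>>::F;

    let config = CircuitConfig::standard_recursion_config();
    // Base proof, all-Poseidon.
    let (proof, vd, cd) = dummy_proof::<F, C, D>(&config, 8_000)?;
    // Verify it inside another Poseidon-config circuit.
    let (proof, vd, cd) =
        recursive_proof::<F, C, C, D>(proof, vd, cd, &config, &config, false, false)?;
    // Wrap once more, switching the outer proof to the Keccak config.
    let (_proof, _vd, _cd) =
        recursive_proof::<F, KC, C, D>(proof, vd, cd, &config, &config, true, true)?;
    Ok(())
}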
- fn dummy_proof, const D: usize>( + fn dummy_proof, C: GenericConfig, const D: usize>( config: &CircuitConfig, num_dummy_gates: u64, ) -> Result<( - ProofWithPublicInputs, - VerifierOnlyCircuitData, - CommonCircuitData, + ProofWithPublicInputs, + VerifierOnlyCircuitData, + CommonCircuitData, )> { let mut builder = CircuitBuilder::::new(config.clone()); let input = builder.add_virtual_target(); @@ -508,7 +516,7 @@ mod tests { builder.arithmetic(i_f, i_f, input, input, input); } - let data = builder.build(); + let data = builder.build::(); let mut inputs = PartialWitness::new(); inputs.set_target(input, F::ZERO); let proof = data.prove(inputs)?; @@ -517,18 +525,23 @@ mod tests { Ok((proof, data.verifier_only, data.common)) } - fn recursive_proof, const D: usize>( - inner_proof: ProofWithPublicInputs, - inner_vd: VerifierOnlyCircuitData, - inner_cd: CommonCircuitData, + fn recursive_proof< + F: Extendable, + C: GenericConfig, + InnerC: AlgebraicConfig, + const D: usize, + >( + inner_proof: ProofWithPublicInputs, + inner_vd: VerifierOnlyCircuitData, + inner_cd: CommonCircuitData, inner_config: &CircuitConfig, config: &CircuitConfig, print_gate_counts: bool, print_timing: bool, ) -> Result<( - ProofWithPublicInputs, - VerifierOnlyCircuitData, - CommonCircuitData, + ProofWithPublicInputs, + VerifierOnlyCircuitData, + CommonCircuitData, )> { let mut builder = CircuitBuilder::::new(config.clone()); let mut pw = PartialWitness::new(); @@ -549,7 +562,7 @@ mod tests { builder.print_gate_counts(0); } - let data = builder.build(); + let data = builder.build::(); let mut timing = TimingTree::new("prove", Level::Debug); let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?; @@ -563,9 +576,9 @@ mod tests { } /// Test serialization and print some size info. - fn test_serialization, const D: usize>( - proof: &ProofWithPublicInputs, - cd: &CommonCircuitData, + fn test_serialization, C: GenericConfig, const D: usize>( + proof: &ProofWithPublicInputs, + cd: &CommonCircuitData, ) -> Result<()> { let proof_bytes = proof.to_bytes()?; info!("Proof length: {} bytes", proof_bytes.len()); diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 0ebdbd22..7e4af676 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -5,6 +5,7 @@ use crate::gates::gate::PrefixedGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CommonCircuitData; +use crate::plonk::config::GenericConfig; use crate::plonk::plonk_common; use crate::plonk::plonk_common::{eval_l_1_recursively, ZeroPolyOnCoset}; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; @@ -15,8 +16,8 @@ use crate::with_context; /// Evaluate the vanishing polynomial at `x`. In this context, the vanishing polynomial is a random /// linear combination of gate constraints, plus some other terms relating to the permutation /// argument. All such terms should vanish on `H`. -pub(crate) fn eval_vanishing_poly, const D: usize>( - common_data: &CommonCircuitData, +pub(crate) fn eval_vanishing_poly, C: GenericConfig, const D: usize>( + common_data: &CommonCircuitData, x: F::Extension, vars: EvaluationVars, local_zs: &[F::Extension], @@ -102,8 +103,12 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( } /// Like `eval_vanishing_poly`, but specialized for base field points. Batched. 
-pub(crate) fn eval_vanishing_poly_base_batch, const D: usize>( - common_data: &CommonCircuitData, +pub(crate) fn eval_vanishing_poly_base_batch< + F: Extendable, + C: GenericConfig, + const D: usize, +>( + common_data: &CommonCircuitData, indices_batch: &[usize], xs_batch: &[F], vars_batch: &[EvaluationVarsBase], @@ -314,9 +319,13 @@ pub fn evaluate_gate_constraints_recursively, const /// /// Assumes `x != 1`; if `x` could be 1 then this is unsound. This is fine if `x` is a random /// variable drawn from a sufficiently large domain. -pub(crate) fn eval_vanishing_poly_recursively, const D: usize>( +pub(crate) fn eval_vanishing_poly_recursively< + F: Extendable, + C: GenericConfig, + const D: usize, +>( builder: &mut CircuitBuilder, - common_data: &CommonCircuitData, + common_data: &CommonCircuitData, x: ExtensionTarget, x_pow_deg: ExtensionTarget, vars: EvaluationTargets, diff --git a/src/plonk/verifier.rs b/src/plonk/verifier.rs index e21fe328..cf27d26f 100644 --- a/src/plonk/verifier.rs +++ b/src/plonk/verifier.rs @@ -1,28 +1,33 @@ use anyhow::{ensure, Result}; use crate::field::extension_field::Extendable; -use crate::field::field_types::{Field, RichField}; +use crate::field::field_types::Field; use crate::fri::verifier::verify_fri_proof; use crate::plonk::circuit_data::{CommonCircuitData, VerifierOnlyCircuitData}; +use crate::plonk::config::GenericConfig; use crate::plonk::plonk_common::reduce_with_powers; use crate::plonk::proof::{ProofChallenges, ProofWithPublicInputs}; use crate::plonk::vanishing_poly::eval_vanishing_poly; use crate::plonk::vars::EvaluationVars; -pub(crate) fn verify, const D: usize>( - proof_with_pis: ProofWithPublicInputs, - verifier_data: &VerifierOnlyCircuitData, - common_data: &CommonCircuitData, +pub(crate) fn verify, C: GenericConfig, const D: usize>( + proof_with_pis: ProofWithPublicInputs, + verifier_data: &VerifierOnlyCircuitData, + common_data: &CommonCircuitData, ) -> Result<()> { let challenges = proof_with_pis.get_challenges(common_data)?; verify_with_challenges(proof_with_pis, challenges, verifier_data, common_data) } -pub(crate) fn verify_with_challenges, const D: usize>( - proof_with_pis: ProofWithPublicInputs, +pub(crate) fn verify_with_challenges< + F: Extendable, + C: GenericConfig, + const D: usize, +>( + proof_with_pis: ProofWithPublicInputs, challenges: ProofChallenges, - verifier_data: &VerifierOnlyCircuitData, - common_data: &CommonCircuitData, + verifier_data: &VerifierOnlyCircuitData, + common_data: &CommonCircuitData, ) -> Result<()> { let public_inputs_hash = &proof_with_pis.get_public_inputs_hash(); diff --git a/src/util/reducing.rs b/src/util/reducing.rs index 3e00602c..89c02b51 100644 --- a/src/util/reducing.rs +++ b/src/util/reducing.rs @@ -185,12 +185,14 @@ mod tests { use crate::field::goldilocks_field::GoldilocksField; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::verifier::verify; fn test_reduce_gadget_base(n: usize) -> Result<()> { - type F = GoldilocksField; - type FF = QuarticExtension; - const D: usize = 4; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; let config = CircuitConfig::standard_recursion_config(); @@ -209,16 +211,17 @@ mod tests { builder.connect_extension(manual_reduce, circuit_reduce); - let data = builder.build(); + let data = builder.build::(); let proof = data.prove(pw)?; verify(proof, &data.verifier_only, &data.common) } fn 
test_reduce_gadget(n: usize) -> Result<()> { - type F = GoldilocksField; - type FF = QuarticExtension; - const D: usize = 4; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; let config = CircuitConfig::standard_recursion_config(); @@ -240,7 +243,7 @@ mod tests { builder.connect_extension(manual_reduce, circuit_reduce); - let data = builder.build(); + let data = builder.build::(); let proof = data.prove(pw)?; verify(proof, &data.verifier_only, &data.common) diff --git a/src/util/serialization.rs b/src/util/serialization.rs index 172b4d67..053b8c68 100644 --- a/src/util/serialization.rs +++ b/src/util/serialization.rs @@ -10,10 +10,10 @@ use crate::fri::proof::{ CompressedFriProof, CompressedFriQueryRounds, FriInitialTreeProof, FriProof, FriQueryRound, FriQueryStep, }; -use crate::hash::hash_types::HashOut; use crate::hash::merkle_proofs::MerkleProof; use crate::hash::merkle_tree::MerkleCap; use crate::plonk::circuit_data::CommonCircuitData; +use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::proof::{ CompressedProof, CompressedProofWithPublicInputs, OpeningSet, Proof, ProofWithPublicInputs, }; @@ -80,42 +80,45 @@ impl Buffer { )) } - fn write_hash(&mut self, h: HashOut) -> Result<()> { - for &a in &h.elements { - self.write_field(a)?; - } - Ok(()) - } - fn read_hash(&mut self) -> Result> { - let mut elements = [F::ZERO; 4]; - for a in elements.iter_mut() { - *a = self.read_field()?; - } - Ok(HashOut { elements }) + fn write_hash>(&mut self, h: H::Hash) -> Result<()> { + let bytes: Vec = h.into(); + self.0.write_all(&bytes) } - fn write_merkle_cap(&mut self, cap: &MerkleCap) -> Result<()> { + fn read_hash>(&mut self) -> Result { + let mut buf = vec![0; H::HASH_SIZE]; + self.0.read_exact(&mut buf)?; + Ok(H::Hash::from(buf.to_vec())) + } + + fn write_merkle_cap>( + &mut self, + cap: &MerkleCap, + ) -> Result<()> { for &a in &cap.0 { - self.write_hash(a)?; + self.write_hash::(a)?; } Ok(()) } - fn read_merkle_cap(&mut self, cap_height: usize) -> Result> { + fn read_merkle_cap>( + &mut self, + cap_height: usize, + ) -> Result> { let cap_length = 1 << cap_height; Ok(MerkleCap( (0..cap_length) - .map(|_| self.read_hash()) + .map(|_| self.read_hash::()) .collect::>>()?, )) } - fn write_field_vec(&mut self, v: &[F]) -> Result<()> { + pub fn write_field_vec(&mut self, v: &[F]) -> Result<()> { for &a in v { self.write_field(a)?; } Ok(()) } - fn read_field_vec(&mut self, length: usize) -> Result> { + pub fn read_field_vec(&mut self, length: usize) -> Result> { (0..length) .map(|_| self.read_field()) .collect::>>() @@ -151,9 +154,9 @@ impl Buffer { self.write_field_ext_vec::(&os.partial_products)?; self.write_field_ext_vec::(&os.quotient_polys) } - fn read_opening_set, const D: usize>( + fn read_opening_set, C: GenericConfig, const D: usize>( &mut self, - common_data: &CommonCircuitData, + common_data: &CommonCircuitData, ) -> Result> { let config = &common_data.config; let constants = self.read_field_ext_vec::(common_data.num_constants)?; @@ -178,7 +181,10 @@ impl Buffer { }) } - fn write_merkle_proof(&mut self, p: &MerkleProof) -> Result<()> { + fn write_merkle_proof>( + &mut self, + p: &MerkleProof, + ) -> Result<()> { let length = p.siblings.len(); self.write_u8( length @@ -186,22 +192,22 @@ impl Buffer { .expect("Merkle proof length must fit in u8."), )?; for &h in &p.siblings { - self.write_hash(h)?; + self.write_hash::(h)?; } Ok(()) } - fn read_merkle_proof(&mut self) -> Result> { + fn read_merkle_proof>(&mut self) -> Result> { let 
length = self.read_u8()?; Ok(MerkleProof { siblings: (0..length) - .map(|_| self.read_hash()) + .map(|_| self.read_hash::()) .collect::>>()?, }) } - fn write_fri_initial_proof( + fn write_fri_initial_proof, C: GenericConfig, const D: usize>( &mut self, - fitp: &FriInitialTreeProof, + fitp: &FriInitialTreeProof, ) -> Result<()> { for (v, p) in &fitp.evals_proofs { self.write_field_vec(v)?; @@ -209,10 +215,10 @@ impl Buffer { } Ok(()) } - fn read_fri_initial_proof, const D: usize>( + fn read_fri_initial_proof, C: GenericConfig, const D: usize>( &mut self, - common_data: &CommonCircuitData, - ) -> Result> { + common_data: &CommonCircuitData, + ) -> Result> { let config = &common_data.config; let mut evals_proofs = Vec::with_capacity(4); @@ -238,18 +244,18 @@ impl Buffer { Ok(FriInitialTreeProof { evals_proofs }) } - fn write_fri_query_step, const D: usize>( + fn write_fri_query_step, C: GenericConfig, const D: usize>( &mut self, - fqs: &FriQueryStep, + fqs: &FriQueryStep, ) -> Result<()> { self.write_field_ext_vec::(&fqs.evals)?; self.write_merkle_proof(&fqs.merkle_proof) } - fn read_fri_query_step, const D: usize>( + fn read_fri_query_step, C: GenericConfig, const D: usize>( &mut self, arity: usize, compressed: bool, - ) -> Result> { + ) -> Result> { let evals = self.read_field_ext_vec::(arity - if compressed { 1 } else { 0 })?; let merkle_proof = self.read_merkle_proof()?; Ok(FriQueryStep { @@ -258,22 +264,22 @@ impl Buffer { }) } - fn write_fri_query_rounds, const D: usize>( + fn write_fri_query_rounds, C: GenericConfig, const D: usize>( &mut self, - fqrs: &[FriQueryRound], + fqrs: &[FriQueryRound], ) -> Result<()> { for fqr in fqrs { - self.write_fri_initial_proof(&fqr.initial_trees_proof)?; + self.write_fri_initial_proof::(&fqr.initial_trees_proof)?; for fqs in &fqr.steps { - self.write_fri_query_step(fqs)?; + self.write_fri_query_step::(fqs)?; } } Ok(()) } - fn read_fri_query_rounds, const D: usize>( + fn read_fri_query_rounds, C: GenericConfig, const D: usize>( &mut self, - common_data: &CommonCircuitData, - ) -> Result>> { + common_data: &CommonCircuitData, + ) -> Result>> { let config = &common_data.config; let mut fqrs = Vec::with_capacity(config.fri_config.num_query_rounds); for _ in 0..config.fri_config.num_query_rounds { @@ -282,7 +288,7 @@ impl Buffer { .fri_params .reduction_arity_bits .iter() - .map(|&ar| self.read_fri_query_step(1 << ar, false)) + .map(|&ar| self.read_fri_query_step::(1 << ar, false)) .collect::>()?; fqrs.push(FriQueryRound { initial_trees_proof, @@ -292,21 +298,21 @@ impl Buffer { Ok(fqrs) } - fn write_fri_proof, const D: usize>( + fn write_fri_proof, C: GenericConfig, const D: usize>( &mut self, - fp: &FriProof, + fp: &FriProof, ) -> Result<()> { for cap in &fp.commit_phase_merkle_caps { self.write_merkle_cap(cap)?; } - self.write_fri_query_rounds(&fp.query_round_proofs)?; + self.write_fri_query_rounds::(&fp.query_round_proofs)?; self.write_field_ext_vec::(&fp.final_poly.coeffs)?; self.write_field(fp.pow_witness) } - fn read_fri_proof, const D: usize>( + fn read_fri_proof, C: GenericConfig, const D: usize>( &mut self, - common_data: &CommonCircuitData, - ) -> Result> { + common_data: &CommonCircuitData, + ) -> Result> { let config = &common_data.config; let commit_phase_merkle_caps = (0..common_data.fri_params.reduction_arity_bits.len()) .map(|_| self.read_merkle_cap(config.cap_height)) @@ -323,20 +329,20 @@ impl Buffer { }) } - pub fn write_proof, const D: usize>( + pub fn write_proof, C: GenericConfig, const D: usize>( &mut self, - proof: &Proof, + proof: 
&Proof, ) -> Result<()> { self.write_merkle_cap(&proof.wires_cap)?; self.write_merkle_cap(&proof.plonk_zs_partial_products_cap)?; self.write_merkle_cap(&proof.quotient_polys_cap)?; self.write_opening_set(&proof.openings)?; - self.write_fri_proof(&proof.opening_proof) + self.write_fri_proof::(&proof.opening_proof) } - pub fn read_proof, const D: usize>( + pub fn read_proof, C: GenericConfig, const D: usize>( &mut self, - common_data: &CommonCircuitData, - ) -> Result> { + common_data: &CommonCircuitData, + ) -> Result> { let config = &common_data.config; let wires_cap = self.read_merkle_cap(config.cap_height)?; let plonk_zs_partial_products_cap = self.read_merkle_cap(config.cap_height)?; @@ -353,9 +359,13 @@ impl Buffer { }) } - pub fn write_proof_with_public_inputs, const D: usize>( + pub fn write_proof_with_public_inputs< + F: Extendable, + C: GenericConfig, + const D: usize, + >( &mut self, - proof_with_pis: &ProofWithPublicInputs, + proof_with_pis: &ProofWithPublicInputs, ) -> Result<()> { let ProofWithPublicInputs { proof, @@ -364,10 +374,14 @@ impl Buffer { self.write_proof(proof)?; self.write_field_vec(public_inputs) } - pub fn read_proof_with_public_inputs, const D: usize>( + pub fn read_proof_with_public_inputs< + F: Extendable, + C: GenericConfig, + const D: usize, + >( &mut self, - common_data: &CommonCircuitData, - ) -> Result> { + common_data: &CommonCircuitData, + ) -> Result> { let proof = self.read_proof(common_data)?; let public_inputs = self.read_field_vec( (self.len() - self.0.position() as usize) / std::mem::size_of::(), @@ -379,9 +393,13 @@ impl Buffer { }) } - fn write_compressed_fri_query_rounds, const D: usize>( + fn write_compressed_fri_query_rounds< + F: Extendable, + C: GenericConfig, + const D: usize, + >( &mut self, - cfqrs: &CompressedFriQueryRounds, + cfqrs: &CompressedFriQueryRounds, ) -> Result<()> { for &i in &cfqrs.indices { self.write_u32(i as u32)?; @@ -390,21 +408,25 @@ impl Buffer { let mut initial_trees_proofs = cfqrs.initial_trees_proofs.iter().collect::>(); initial_trees_proofs.sort_by_key(|&x| x.0); for (_, itp) in initial_trees_proofs { - self.write_fri_initial_proof(itp)?; + self.write_fri_initial_proof::(itp)?; } for h in &cfqrs.steps { let mut fri_query_steps = h.iter().collect::>(); fri_query_steps.sort_by_key(|&x| x.0); for (_, fqs) in fri_query_steps { - self.write_fri_query_step(fqs)?; + self.write_fri_query_step::(fqs)?; } } Ok(()) } - fn read_compressed_fri_query_rounds, const D: usize>( + fn read_compressed_fri_query_rounds< + F: Extendable, + C: GenericConfig, + const D: usize, + >( &mut self, - common_data: &CommonCircuitData, - ) -> Result> { + common_data: &CommonCircuitData, + ) -> Result> { let config = &common_data.config; let original_indices = (0..config.fri_config.num_query_rounds) .map(|_| self.read_u32().map(|i| i as usize)) @@ -425,7 +447,7 @@ impl Buffer { }); indices.dedup(); let query_steps = (0..indices.len()) - .map(|_| self.read_fri_query_step(1 << a, true)) + .map(|_| self.read_fri_query_step::(1 << a, true)) .collect::>>()?; steps.push( indices @@ -443,21 +465,21 @@ impl Buffer { }) } - fn write_compressed_fri_proof, const D: usize>( + fn write_compressed_fri_proof, C: GenericConfig, const D: usize>( &mut self, - fp: &CompressedFriProof, + fp: &CompressedFriProof, ) -> Result<()> { for cap in &fp.commit_phase_merkle_caps { self.write_merkle_cap(cap)?; } - self.write_compressed_fri_query_rounds(&fp.query_round_proofs)?; + self.write_compressed_fri_query_rounds::(&fp.query_round_proofs)?; 
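
        // Editor's note (illustrative aside, not part of the patch): the generic hash
        // serialization introduced above relies on each `Hasher::Hash` having a fixed-width
        // byte encoding. A stand-alone sketch of the round-trip invariant the readers and
        // writers assume; the exact bounds on `Hasher` are a guess here, only `H::HASH_SIZE`
        // and the `Into<Vec<u8>>` / `From<Vec<u8>>` conversions are taken from the hunks.
        fn roundtrip_hash<F, H>(h: H::Hash) -> H::Hash
        where
            F: crate::field::field_types::RichField,
            H: crate::plonk::config::Hasher<F>,
            H::Hash: Into<Vec<u8>> + From<Vec<u8>>,
        {
            // `write_hash` serializes to exactly `H::HASH_SIZE` bytes...
            let bytes: Vec<u8> = h.into();
            debug_assert_eq!(bytes.len(), H::HASH_SIZE);
            // ...and `read_hash` reconstructs the hash from that fixed-width buffer.
            H::Hash::from(bytes)
        }
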
self.write_field_ext_vec::(&fp.final_poly.coeffs)?; self.write_field(fp.pow_witness) } - fn read_compressed_fri_proof, const D: usize>( + fn read_compressed_fri_proof, C: GenericConfig, const D: usize>( &mut self, - common_data: &CommonCircuitData, - ) -> Result> { + common_data: &CommonCircuitData, + ) -> Result> { let config = &common_data.config; let commit_phase_merkle_caps = (0..common_data.fri_params.reduction_arity_bits.len()) .map(|_| self.read_merkle_cap(config.cap_height)) @@ -474,20 +496,20 @@ impl Buffer { }) } - pub fn write_compressed_proof, const D: usize>( + pub fn write_compressed_proof, C: GenericConfig, const D: usize>( &mut self, - proof: &CompressedProof, + proof: &CompressedProof, ) -> Result<()> { self.write_merkle_cap(&proof.wires_cap)?; self.write_merkle_cap(&proof.plonk_zs_partial_products_cap)?; self.write_merkle_cap(&proof.quotient_polys_cap)?; self.write_opening_set(&proof.openings)?; - self.write_compressed_fri_proof(&proof.opening_proof) + self.write_compressed_fri_proof::(&proof.opening_proof) } - pub fn read_compressed_proof, const D: usize>( + pub fn read_compressed_proof, C: GenericConfig, const D: usize>( &mut self, - common_data: &CommonCircuitData, - ) -> Result> { + common_data: &CommonCircuitData, + ) -> Result> { let config = &common_data.config; let wires_cap = self.read_merkle_cap(config.cap_height)?; let plonk_zs_partial_products_cap = self.read_merkle_cap(config.cap_height)?; @@ -505,11 +527,12 @@ impl Buffer { } pub fn write_compressed_proof_with_public_inputs< - F: RichField + Extendable, + F: Extendable, + C: GenericConfig, const D: usize, >( &mut self, - proof_with_pis: &CompressedProofWithPublicInputs, + proof_with_pis: &CompressedProofWithPublicInputs, ) -> Result<()> { let CompressedProofWithPublicInputs { proof, @@ -519,12 +542,13 @@ impl Buffer { self.write_field_vec(public_inputs) } pub fn read_compressed_proof_with_public_inputs< - F: RichField + Extendable, + F: Extendable, + C: GenericConfig, const D: usize, >( &mut self, - common_data: &CommonCircuitData, - ) -> Result> { + common_data: &CommonCircuitData, + ) -> Result> { let proof = self.read_compressed_proof(common_data)?; let public_inputs = self.read_field_vec( (self.len() - self.0.position() as usize) / std::mem::size_of::(),
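
// Editor's note (illustrative, not part of the patch): from the calls visible in these hunks
// (`self.0.position()`, `self.0.write_all(..)`, `self.0.read_exact(..)`, `self.len()`),
// `Buffer` is presumably a thin newtype over an in-memory cursor. A rough, hypothetical
// sketch of that wrapper and its primitive byte accessors, for orientation only; the real
// definitions live elsewhere in this file and are not shown in the patch.
mod buffer_sketch {
    use std::io::{Cursor, Read, Result, Write};

    pub struct Buffer(Cursor<Vec<u8>>);

    impl Buffer {
        pub fn new(bytes: Vec<u8>) -> Self {
            Self(Cursor::new(bytes))
        }

        /// Total number of bytes held, independent of the current read/write position.
        pub fn len(&self) -> usize {
            self.0.get_ref().len()
        }

        pub fn write_u8(&mut self, x: u8) -> Result<()> {
            self.0.write_all(&[x])
        }

        pub fn read_u8(&mut self) -> Result<u8> {
            let mut buf = [0; 1];
            self.0.read_exact(&mut buf)?;
            Ok(buf[0])
        }
    }
}
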