diff --git a/src/bin/bench_recursion.rs b/src/bin/bench_recursion.rs index 8867415d..8e798f72 100644 --- a/src/bin/bench_recursion.rs +++ b/src/bin/bench_recursion.rs @@ -50,8 +50,8 @@ fn bench_prove, const D: usize>() -> Result<()> { let circuit = builder.build(); let inputs = PartialWitness::new(); - let proof = circuit.prove(inputs)?; - let proof_bytes = serde_cbor::to_vec(&proof).unwrap(); + let proof_with_pis = circuit.prove(inputs)?; + let proof_bytes = serde_cbor::to_vec(&proof_with_pis).unwrap(); info!("Proof length: {} bytes", proof_bytes.len()); - circuit.verify(proof) + circuit.verify(proof_with_pis) } diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index 2cfeb720..2e12ce0f 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -17,6 +17,7 @@ use crate::gates::constant::ConstantGate; use crate::gates::gate::{GateInstance, GateRef, PrefixedGate}; use crate::gates::gate_tree::Tree; use crate::gates::noop::NoopGate; +use crate::gates::public_input::PublicInputGate; use crate::generator::{CopyGenerator, RandomValueGenerator, WitnessGenerator}; use crate::hash::hash_n_to_hash; use crate::permutation_argument::TargetPartition; @@ -39,8 +40,8 @@ pub struct CircuitBuilder, const D: usize> { /// The concrete placement of each gate. gate_instances: Vec>, - /// The next available index for a public input. - public_input_index: usize, + /// Targets to be made public. + public_inputs: Vec, /// The next available index for a `VirtualTarget`. virtual_target_index: usize, @@ -66,7 +67,7 @@ impl, const D: usize> CircuitBuilder { config, gates: HashSet::new(), gate_instances: Vec::new(), - public_input_index: 0, + public_inputs: Vec::new(), virtual_target_index: 0, copy_constraints: Vec::new(), context_log: ContextTree::new(), @@ -81,14 +82,14 @@ impl, const D: usize> CircuitBuilder { self.gate_instances.len() } - pub fn add_public_input(&mut self) -> Target { - let index = self.public_input_index; - self.public_input_index += 1; - Target::PublicInput { index } + /// Registers the given target as a public input. + pub fn register_public_input(&mut self, target: Target) { + self.public_inputs.push(target); } - pub fn add_public_inputs(&mut self, n: usize) -> Vec { - (0..n).map(|_i| self.add_public_input()).collect() + /// Registers the given targets as public inputs. + pub fn register_public_inputs(&mut self, targets: &[Target]) { + targets.iter().for_each(|&t| self.register_public_input(t)); } /// Adds a new "virtual" target. This is not an actual wire in the witness, but just a target @@ -462,10 +463,7 @@ impl, const D: usize> CircuitBuilder { let degree_log = log2_strict(degree); let mut target_partition = TargetPartition::new(|t| match t { Target::Wire(Wire { gate, input }) => gate * self.config.num_routed_wires + input, - Target::PublicInput { index } => degree * self.config.num_routed_wires + index, - Target::VirtualTarget { index } => { - degree * self.config.num_routed_wires + self.public_input_index + index - } + Target::VirtualTarget { index } => degree * self.config.num_routed_wires + index, }); for gate in 0..degree { @@ -474,10 +472,6 @@ impl, const D: usize> CircuitBuilder { } } - for index in 0..self.public_input_index { - target_partition.add(Target::PublicInput { index }); - } - for index in 0..self.virtual_target_index { target_partition.add(Target::VirtualTarget { index }); } @@ -500,6 +494,19 @@ impl, const D: usize> CircuitBuilder { pub fn build(mut self) -> CircuitData { let quotient_degree_factor = 7; // TODO: add this as a parameter. let start = Instant::now(); + + // Hash the public inputs, and route them to a `PublicInputGate` which will enforce that + // those hash wires match the claimed public inputs. + let public_inputs_hash = self.hash_n_to_hash(self.public_inputs.clone(), true); + let pi_gate = self.add_gate_no_constants(PublicInputGate::get()); + for (&hash_part, wire) in public_inputs_hash + .elements + .iter() + .zip(PublicInputGate::wires_public_inputs_hash()) + { + self.route(hash_part, Target::wire(pi_gate, wire)) + } + info!( "Degree before blinding & padding: {}", self.gate_instances.len() @@ -552,6 +559,7 @@ impl, const D: usize> CircuitBuilder { subgroup, copy_constraints: self.copy_constraints, gate_instances: self.gate_instances, + public_inputs: self.public_inputs, marked_targets: self.marked_targets, }; diff --git a/src/circuit_data.rs b/src/circuit_data.rs index 9c68ca46..912f05cc 100644 --- a/src/circuit_data.rs +++ b/src/circuit_data.rs @@ -9,8 +9,9 @@ use crate::fri::FriConfig; use crate::gates::gate::{GateInstance, PrefixedGate}; use crate::generator::WitnessGenerator; use crate::polynomial::commitment::ListPolynomialCommitment; -use crate::proof::{Hash, HashTarget, Proof}; +use crate::proof::{Hash, HashTarget, ProofWithPublicInputs}; use crate::prover::prove; +use crate::target::Target; use crate::util::marking::MarkedTargets; use crate::verifier::verify; use crate::witness::PartialWitness; @@ -78,12 +79,12 @@ pub struct CircuitData, const D: usize> { } impl, const D: usize> CircuitData { - pub fn prove(&self, inputs: PartialWitness) -> Result> { + pub fn prove(&self, inputs: PartialWitness) -> Result> { prove(&self.prover_only, &self.common, inputs) } - pub fn verify(&self, proof: Proof) -> Result<()> { - verify(proof, &self.verifier_only, &self.common) + pub fn verify(&self, proof_with_pis: ProofWithPublicInputs) -> Result<()> { + verify(proof_with_pis, &self.verifier_only, &self.common) } } @@ -100,7 +101,7 @@ pub struct ProverCircuitData, const D: usize> { } impl, const D: usize> ProverCircuitData { - pub fn prove(&self, inputs: PartialWitness) -> Result> { + pub fn prove(&self, inputs: PartialWitness) -> Result> { prove(&self.prover_only, &self.common, inputs) } } @@ -112,8 +113,8 @@ pub struct VerifierCircuitData, const D: usize> { } impl, const D: usize> VerifierCircuitData { - pub fn verify(&self, proof: Proof) -> Result<()> { - verify(proof, &self.verifier_only, &self.common) + pub fn verify(&self, proof_with_pis: ProofWithPublicInputs) -> Result<()> { + verify(proof_with_pis, &self.verifier_only, &self.common) } } @@ -130,6 +131,8 @@ pub(crate) struct ProverOnlyCircuitData, const D: usize> { pub copy_constraints: Vec, /// The concrete placement of each gate in the circuit. pub gate_instances: Vec>, + /// Targets to be made public. + pub public_inputs: Vec, /// A vector of marked targets. The values assigned to these targets will be displayed by the prover. pub marked_targets: Vec>, } diff --git a/src/field/fft.rs b/src/field/fft.rs index 35b11852..3fd37b1c 100644 --- a/src/field/fft.rs +++ b/src/field/fft.rs @@ -67,7 +67,7 @@ fn fft_unrolled_root_table(n: usize) -> FftRootTable { #[inline] fn fft_dispatch( - input: Vec, + input: &[F], zero_factor: Option, root_table: Option>, ) -> Vec { @@ -87,13 +87,13 @@ fn fft_dispatch( } #[inline] -pub fn fft(poly: PolynomialCoeffs) -> PolynomialValues { +pub fn fft(poly: &PolynomialCoeffs) -> PolynomialValues { fft_with_options(poly, None, None) } #[inline] pub fn fft_with_options( - poly: PolynomialCoeffs, + poly: &PolynomialCoeffs, zero_factor: Option, root_table: Option>, ) -> PolynomialValues { @@ -104,12 +104,12 @@ pub fn fft_with_options( } #[inline] -pub fn ifft(poly: PolynomialValues) -> PolynomialCoeffs { +pub fn ifft(poly: &PolynomialValues) -> PolynomialCoeffs { ifft_with_options(poly, None, None) } pub fn ifft_with_options( - poly: PolynomialValues, + poly: &PolynomialValues, zero_factor: Option, root_table: Option>, ) -> PolynomialCoeffs { @@ -139,11 +139,7 @@ pub fn ifft_with_options( /// The parameter r signifies that the first 1/2^r of the entries of /// input may be non-zero, but the last 1 - 1/2^r entries are /// definitely zero. -pub(crate) fn fft_classic( - input: Vec, - r: usize, - root_table: FftRootTable, -) -> Vec { +pub(crate) fn fft_classic(input: &[F], r: usize, root_table: FftRootTable) -> Vec { let mut values = reverse_index_bits(input); let n = values.len(); @@ -196,7 +192,7 @@ pub(crate) fn fft_classic( /// The parameter r signifies that the first 1/2^r of the entries of /// input may be non-zero, but the last 1 - 1/2^r entries are /// definitely zero. -fn fft_unrolled(input: Vec, r_orig: usize, root_table: FftRootTable) -> Vec { +fn fft_unrolled(input: &[F], r_orig: usize, root_table: FftRootTable) -> Vec { let n = input.len(); let lg_n = log2_strict(input.len()); @@ -325,10 +321,10 @@ mod tests { } let coefficients = PolynomialCoeffs::new_padded(coefficients); - let points = fft(coefficients.clone()); + let points = fft(&coefficients); assert_eq!(points, evaluate_naive(&coefficients)); - let interpolated_coefficients = ifft(points); + let interpolated_coefficients = ifft(&points); for i in 0..degree { assert_eq!(interpolated_coefficients.coeffs[i], coefficients.coeffs[i]); } @@ -337,12 +333,9 @@ mod tests { } for r in 0..4 { - // expand ceofficients by factor 2^r by filling with zeros - let zero_tail = coefficients.clone().lde(r); - assert_eq!( - fft(zero_tail.clone()), - fft_with_options(zero_tail, Some(r), None) - ); + // expand coefficients by factor 2^r by filling with zeros + let zero_tail = coefficients.lde(r); + assert_eq!(fft(&zero_tail), fft_with_options(&zero_tail, Some(r), None)); } } @@ -350,10 +343,7 @@ mod tests { let degree = coefficients.len(); let degree_padded = 1 << log2_ceil(degree); - let mut coefficients_padded = coefficients.clone(); - for _i in degree..degree_padded { - coefficients_padded.coeffs.push(F::ZERO); - } + let coefficients_padded = coefficients.padded(degree_padded); evaluate_naive_power_of_2(&coefficients_padded) } diff --git a/src/field/interpolation.rs b/src/field/interpolation.rs index 3d5e609c..83414f1b 100644 --- a/src/field/interpolation.rs +++ b/src/field/interpolation.rs @@ -18,7 +18,7 @@ pub(crate) fn interpolant(points: &[(F, F)]) -> PolynomialCoeffs { .map(|x| interpolate(points, x, &barycentric_weights)) .collect(); - let mut coeffs = ifft(PolynomialValues { + let mut coeffs = ifft(&PolynomialValues { values: subgroup_evals, }); coeffs.trim(); diff --git a/src/fri/prover.rs b/src/fri/prover.rs index 1c642d17..5a8f09e5 100644 --- a/src/fri/prover.rs +++ b/src/fri/prover.rs @@ -16,9 +16,9 @@ use crate::util::reverse_index_bits_in_place; pub fn fri_proof, const D: usize>( initial_merkle_trees: &[&MerkleTree], // Coefficients of the polynomial on which the LDT is performed. Only the first `1/rate` coefficients are non-zero. - lde_polynomial_coeffs: &PolynomialCoeffs, + lde_polynomial_coeffs: PolynomialCoeffs, // Evaluation of the polynomial on the large domain. - lde_polynomial_values: &PolynomialValues, + lde_polynomial_values: PolynomialValues, challenger: &mut Challenger, config: &FriConfig, ) -> FriProof { @@ -53,14 +53,11 @@ pub fn fri_proof, const D: usize>( } fn fri_committed_trees, const D: usize>( - polynomial_coeffs: &PolynomialCoeffs, - polynomial_values: &PolynomialValues, + mut coeffs: PolynomialCoeffs, + mut values: PolynomialValues, challenger: &mut Challenger, config: &FriConfig, ) -> (Vec>, PolynomialCoeffs) { - let mut values = polynomial_values.clone(); - let mut coeffs = polynomial_coeffs.clone(); - let mut trees = Vec::new(); let mut shift = F::MULTIPLICATIVE_GROUP_GENERATOR; @@ -91,8 +88,7 @@ fn fri_committed_trees, const D: usize>( .collect::>(), ); shift = shift.exp_u32(arity as u32); - // TODO: Is it faster to interpolate? - values = coeffs.clone().coset_fft(shift.into()) + values = coeffs.coset_fft(shift.into()) } coeffs.trim(); diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index 74c781a7..206a87e5 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -125,14 +125,16 @@ impl, const D: usize> CircuitBuilder { "Number of reductions should be non-zero." ); + let precomputed_reduced_evals = + PrecomputedReducedEvalsTarget::from_os_and_alpha(os, alpha, self); for (i, round_proof) in proof.query_round_proofs.iter().enumerate() { context!( self, &format!("verify {}'th FRI query", i), self.fri_verifier_query_round( - os, zeta, alpha, + precomputed_reduced_evals, initial_merkle_roots, proof, challenger, @@ -169,9 +171,9 @@ impl, const D: usize> CircuitBuilder { &mut self, proof: &FriInitialTreeProofTarget, alpha: ExtensionTarget, - os: &OpeningSetTarget, zeta: ExtensionTarget, subgroup_x: Target, + precomputed_reduced_evals: PrecomputedReducedEvalsTarget, common_data: &CommonCircuitData, ) -> ExtensionTarget { assert!(D > 1, "Not implemented for D=1."); @@ -199,19 +201,9 @@ impl, const D: usize> CircuitBuilder { ) .map(|&e| self.convert_to_ext(e)) .collect::>(); - let single_openings = os - .constants - .iter() - .chain(&os.plonk_sigmas) - .chain(&os.wires) - .chain(&os.quotient_polys) - .chain(&os.partial_products) - .copied() - .collect::>(); - let mut single_numerator = alpha.reduce(&single_evals, self); - // TODO: Precompute the rhs as it is the same in all FRI rounds. - let rhs = alpha.reduce(&single_openings, self); - single_numerator = self.sub_extension(single_numerator, rhs); + let single_composition_eval = alpha.reduce(&single_evals, self); + let single_numerator = + self.sub_extension(single_composition_eval, precomputed_reduced_evals.single); let single_denominator = self.sub_extension(subgroup_x, zeta); let quotient = self.div_unsafe_extension(single_numerator, single_denominator); sum = self.add_extension(sum, quotient); @@ -224,14 +216,15 @@ impl, const D: usize> CircuitBuilder { .take(common_data.zs_range().end) .map(|&e| self.convert_to_ext(e)) .collect::>(); - let zs_composition_eval = alpha.clone().reduce(&zs_evals, self); + let zs_composition_eval = alpha.reduce(&zs_evals, self); let g = self.constant_extension(F::Extension::primitive_root_of_unity(degree_log)); let zeta_right = self.mul_extension(g, zeta); - let zs_ev_zeta = alpha.clone().reduce(&os.plonk_zs, self); - let zs_ev_zeta_right = alpha.reduce(&os.plonk_zs_right, self); let interpol_val = self.interpolate2( - [(zeta, zs_ev_zeta), (zeta_right, zs_ev_zeta_right)], + [ + (zeta, precomputed_reduced_evals.zs), + (zeta_right, precomputed_reduced_evals.zs_right), + ], subgroup_x, ); let zs_numerator = self.sub_extension(zs_composition_eval, interpol_val); @@ -247,9 +240,9 @@ impl, const D: usize> CircuitBuilder { fn fri_verifier_query_round( &mut self, - os: &OpeningSetTarget, zeta: ExtensionTarget, alpha: ExtensionTarget, + precomputed_reduced_evals: PrecomputedReducedEvalsTarget, initial_merkle_roots: &[HashTarget], proof: &FriProofTarget, challenger: &mut RecursiveChallenger, @@ -260,7 +253,6 @@ impl, const D: usize> CircuitBuilder { ) { let config = &common_data.config.fri_config; let n_log = log2_strict(n); - let mut evaluations: Vec>> = Vec::new(); // TODO: Do we need to range check `x_index` to a target smaller than `p`? let mut x_index = challenger.get_challenge(self); x_index = self.split_low_high(x_index, n_log, 64).0; @@ -287,6 +279,7 @@ impl, const D: usize> CircuitBuilder { self.mul(g, phi) }); + let mut evaluations: Vec>> = Vec::new(); for (i, &arity_bits) in config.reduction_arity_bits.iter().enumerate() { let next_domain_size = domain_size >> arity_bits; let e_x = if i == 0 { @@ -296,9 +289,9 @@ impl, const D: usize> CircuitBuilder { self.fri_combine_initial( &round_proof.initial_trees_proof, alpha, - os, zeta, subgroup_x, + precomputed_reduced_evals, common_data, ) ) @@ -322,23 +315,21 @@ impl, const D: usize> CircuitBuilder { let (low_x_index, high_x_index) = self.split_low_high(x_index, arity_bits, x_index_num_bits); evals = self.insert(low_x_index, e_x, evals); - evaluations.push(evals); context!( self, "verify FRI round Merkle proof.", self.verify_merkle_proof( - flatten_target(&evaluations[i]), + flatten_target(&evals), high_x_index, proof.commit_phase_merkle_roots[i], &round_proof.steps[i].merkle_proof, ) ); + evaluations.push(evals); if i > 0 { // Update the point x to x^arity. - for _ in 0..config.reduction_arity_bits[i - 1] { - subgroup_x = self.square(subgroup_x); - } + subgroup_x = self.exp_power_of_2(subgroup_x, config.reduction_arity_bits[i - 1]); } domain_size = next_domain_size; old_x_index = low_x_index; @@ -359,9 +350,7 @@ impl, const D: usize> CircuitBuilder { *betas.last().unwrap(), ) ); - for _ in 0..final_arity_bits { - subgroup_x = self.square(subgroup_x); - } + subgroup_x = self.exp_power_of_2(subgroup_x, final_arity_bits); // Final check of FRI. After all the reductions, we check that the final polynomial is equal // to the one sent by the prover. @@ -373,3 +362,39 @@ impl, const D: usize> CircuitBuilder { self.assert_equal_extension(eval, purported_eval); } } + +#[derive(Copy, Clone)] +struct PrecomputedReducedEvalsTarget { + pub single: ExtensionTarget, + pub zs: ExtensionTarget, + pub zs_right: ExtensionTarget, +} + +impl PrecomputedReducedEvalsTarget { + fn from_os_and_alpha>( + os: &OpeningSetTarget, + alpha: ExtensionTarget, + builder: &mut CircuitBuilder, + ) -> Self { + let mut alpha = ReducingFactorTarget::new(alpha); + let single = alpha.reduce( + &os.constants + .iter() + .chain(&os.plonk_sigmas) + .chain(&os.wires) + .chain(&os.quotient_polys) + .chain(&os.partial_products) + .copied() + .collect::>(), + builder, + ); + let zs = alpha.reduce(&os.plonk_zs, builder); + let zs_right = alpha.reduce(&os.plonk_zs_right, builder); + + Self { + single, + zs, + zs_right, + } + } +} diff --git a/src/fri/verifier.rs b/src/fri/verifier.rs index 89b60d40..ff561680 100644 --- a/src/fri/verifier.rs +++ b/src/fri/verifier.rs @@ -113,11 +113,12 @@ pub fn verify_fri_proof, const D: usize>( "Number of reductions should be non-zero." ); + let precomputed_reduced_evals = PrecomputedReducedEvals::from_os_and_alpha(os, alpha); for round_proof in &proof.query_round_proofs { fri_verifier_query_round( - os, zeta, alpha, + precomputed_reduced_evals, initial_merkle_roots, &proof, challenger, @@ -143,12 +144,43 @@ fn fri_verify_initial_proof( Ok(()) } +/// Holds the reduced (by `alpha`) evaluations at `zeta` for the polynomial opened just at +/// zeta, for `Z` at zeta and for `Z` at `g*zeta`. +#[derive(Copy, Clone)] +struct PrecomputedReducedEvals, const D: usize> { + pub single: F::Extension, + pub zs: F::Extension, + pub zs_right: F::Extension, +} + +impl, const D: usize> PrecomputedReducedEvals { + fn from_os_and_alpha(os: &OpeningSet, alpha: F::Extension) -> Self { + let mut alpha = ReducingFactor::new(alpha); + let single = alpha.reduce( + os.constants + .iter() + .chain(&os.plonk_sigmas) + .chain(&os.wires) + .chain(&os.quotient_polys) + .chain(&os.partial_products), + ); + let zs = alpha.reduce(os.plonk_zs.iter()); + let zs_right = alpha.reduce(os.plonk_zs_right.iter()); + + Self { + single, + zs, + zs_right, + } + } +} + fn fri_combine_initial, const D: usize>( proof: &FriInitialTreeProof, alpha: F::Extension, - os: &OpeningSet, zeta: F::Extension, subgroup_x: F, + precomputed_reduced_evals: PrecomputedReducedEvals, common_data: &CommonCircuitData, ) -> F::Extension { let config = &common_data.config; @@ -175,19 +207,8 @@ fn fri_combine_initial, const D: usize>( [common_data.partial_products_range()], ) .map(|&e| F::Extension::from_basefield(e)); - let single_openings = os - .constants - .iter() - .chain(&os.plonk_sigmas) - .chain(&os.wires) - .chain(&os.quotient_polys) - .chain(&os.partial_products); - let single_diffs = single_evals - .into_iter() - .zip(single_openings) - .map(|(e, &o)| e - o) - .collect::>(); - let single_numerator = alpha.reduce(single_diffs.iter()); + let single_composition_eval = alpha.reduce(single_evals); + let single_numerator = single_composition_eval - precomputed_reduced_evals.single; let single_denominator = subgroup_x - zeta; sum += single_numerator / single_denominator; alpha.reset(); @@ -198,12 +219,12 @@ fn fri_combine_initial, const D: usize>( .iter() .map(|&e| F::Extension::from_basefield(e)) .take(common_data.zs_range().end); - let zs_composition_eval = alpha.clone().reduce(zs_evals); + let zs_composition_eval = alpha.reduce(zs_evals); let zeta_right = F::Extension::primitive_root_of_unity(degree_log) * zeta; let zs_interpol = interpolate2( [ - (zeta, alpha.clone().reduce(os.plonk_zs.iter())), - (zeta_right, alpha.reduce(os.plonk_zs_right.iter())), + (zeta, precomputed_reduced_evals.zs), + (zeta_right, precomputed_reduced_evals.zs_right), ], subgroup_x, ); @@ -216,9 +237,9 @@ fn fri_combine_initial, const D: usize>( } fn fri_verifier_query_round, const D: usize>( - os: &OpeningSet, zeta: F::Extension, alpha: F::Extension, + precomputed_reduced_evals: PrecomputedReducedEvals, initial_merkle_roots: &[Hash], proof: &FriProof, challenger: &mut Challenger, @@ -228,7 +249,6 @@ fn fri_verifier_query_round, const D: usize>( common_data: &CommonCircuitData, ) -> Result<()> { let config = &common_data.config.fri_config; - let mut evaluations: Vec> = Vec::new(); let x = challenger.get_challenge(); let mut domain_size = n; let mut x_index = x.to_canonical_u64() as usize % n; @@ -242,6 +262,8 @@ fn fri_verifier_query_round, const D: usize>( let log_n = log2_strict(n); let mut subgroup_x = F::MULTIPLICATIVE_GROUP_GENERATOR * F::primitive_root_of_unity(log_n).exp(reverse_bits(x_index, log_n) as u64); + + let mut evaluations: Vec> = Vec::new(); for (i, &arity_bits) in config.reduction_arity_bits.iter().enumerate() { let arity = 1 << arity_bits; let next_domain_size = domain_size >> arity_bits; @@ -249,9 +271,9 @@ fn fri_verifier_query_round, const D: usize>( fri_combine_initial( &round_proof.initial_trees_proof, alpha, - os, zeta, subgroup_x, + precomputed_reduced_evals, common_data, ) } else { @@ -268,20 +290,18 @@ fn fri_verifier_query_round, const D: usize>( let mut evals = round_proof.steps[i].evals.clone(); // Insert P(y) into the evaluation vector, since it wasn't included by the prover. evals.insert(x_index & (arity - 1), e_x); - evaluations.push(evals); verify_merkle_proof( - flatten(&evaluations[i]), + flatten(&evals), x_index >> arity_bits, proof.commit_phase_merkle_roots[i], &round_proof.steps[i].merkle_proof, false, )?; + evaluations.push(evals); if i > 0 { // Update the point x to x^arity. - for _ in 0..config.reduction_arity_bits[i - 1] { - subgroup_x = subgroup_x.square(); - } + subgroup_x = subgroup_x.exp_power_of_2(config.reduction_arity_bits[i - 1]); } domain_size = next_domain_size; old_x_index = x_index & (arity - 1); @@ -297,9 +317,7 @@ fn fri_verifier_query_round, const D: usize>( last_evals, *betas.last().unwrap(), ); - for _ in 0..final_arity_bits { - subgroup_x = subgroup_x.square(); - } + subgroup_x = subgroup_x.exp_power_of_2(final_arity_bits); // Final check of FRI. After all the reductions, we check that the final polynomial is equal // to the one sent by the prover. diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 605061ee..88004460 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -16,7 +16,8 @@ impl, const D: usize> CircuitBuilder { /// Computes `x^3`. pub fn cube(&mut self, x: Target) -> Target { - self.mul_many(&[x, x, x]) + let xe = self.convert_to_ext(x); + self.mul_three_extension(xe, xe, xe).to_target_array()[0] } /// Computes `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`. @@ -123,13 +124,14 @@ impl, const D: usize> CircuitBuilder { self.arithmetic(F::ONE, x, one, F::ONE, y) } + /// Add `n` `Target`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s. // TODO: Can be made `2*D` times more efficient by using all wires of an `ArithmeticExtensionGate`. pub fn add_many(&mut self, terms: &[Target]) -> Target { - let mut sum = self.zero(); - for term in terms { - sum = self.add(sum, *term); - } - sum + let terms_ext = terms + .iter() + .map(|&t| self.convert_to_ext(t)) + .collect::>(); + self.add_many_extension(&terms_ext).to_target_array()[0] } /// Computes `x - y`. @@ -145,12 +147,22 @@ impl, const D: usize> CircuitBuilder { self.arithmetic(F::ONE, x, y, F::ZERO, x) } + /// Multiply `n` `Target`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s. pub fn mul_many(&mut self, terms: &[Target]) -> Target { - let mut product = self.one(); - for term in terms { - product = self.mul(product, *term); + let terms_ext = terms + .iter() + .map(|&t| self.convert_to_ext(t)) + .collect::>(); + self.mul_many_extension(&terms_ext).to_target_array()[0] + } + + /// Exponentiate `base` to the power of `2^power_log`. + // TODO: Test + pub fn exp_power_of_2(&mut self, mut base: Target, power_log: usize) -> Target { + for _ in 0..power_log { + base = self.square(base); } - product + base } // TODO: Optimize this, maybe with a new gate. diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 10b60dcd..e6efb451 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -7,7 +7,7 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::field::extension_field::{Extendable, OEF}; use crate::gates::arithmetic::ArithmeticExtensionGate; -use crate::generator::SimpleGenerator; +use crate::generator::{GeneratedValues, SimpleGenerator}; use crate::target::Target; use crate::util::bits_u64; use crate::witness::PartialWitness; @@ -17,37 +17,47 @@ impl, const D: usize> CircuitBuilder { &mut self, const_0: F, const_1: F, - fixed_multiplicand: ExtensionTarget, - multiplicand_0: ExtensionTarget, - addend_0: ExtensionTarget, - multiplicand_1: ExtensionTarget, - addend_1: ExtensionTarget, + first_multiplicand_0: ExtensionTarget, + first_multiplicand_1: ExtensionTarget, + first_addend: ExtensionTarget, + second_multiplicand_0: ExtensionTarget, + second_multiplicand_1: ExtensionTarget, + second_addend: ExtensionTarget, ) -> (ExtensionTarget, ExtensionTarget) { let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![const_0, const_1]); - let wire_fixed_multiplicand = ExtensionTarget::from_range( + let wire_first_multiplicand_0 = ExtensionTarget::from_range( gate, - ArithmeticExtensionGate::::wires_fixed_multiplicand(), + ArithmeticExtensionGate::::wires_first_multiplicand_0(), ); - let wire_multiplicand_0 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_0()); - let wire_addend_0 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_addend_0()); - let wire_multiplicand_1 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_1()); - let wire_addend_1 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_addend_1()); - let wire_output_0 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); - let wire_output_1 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_1()); + let wire_first_multiplicand_1 = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_first_multiplicand_1(), + ); + let wire_first_addend = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_addend()); + let wire_second_multiplicand_0 = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_second_multiplicand_0(), + ); + let wire_second_multiplicand_1 = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_second_multiplicand_1(), + ); + let wire_second_addend = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_second_addend()); + let wire_first_output = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_output()); + let wire_second_output = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_second_output()); - self.route_extension(fixed_multiplicand, wire_fixed_multiplicand); - self.route_extension(multiplicand_0, wire_multiplicand_0); - self.route_extension(addend_0, wire_addend_0); - self.route_extension(multiplicand_1, wire_multiplicand_1); - self.route_extension(addend_1, wire_addend_1); - (wire_output_0, wire_output_1) + self.route_extension(first_multiplicand_0, wire_first_multiplicand_0); + self.route_extension(first_multiplicand_1, wire_first_multiplicand_1); + self.route_extension(first_addend, wire_first_addend); + self.route_extension(second_multiplicand_0, wire_second_multiplicand_0); + self.route_extension(second_multiplicand_1, wire_second_multiplicand_1); + self.route_extension(second_addend, wire_second_addend); + (wire_first_output, wire_second_output) } pub fn arithmetic_extension( @@ -67,6 +77,7 @@ impl, const D: usize> CircuitBuilder { addend, zero, zero, + zero, ) .0 } @@ -80,6 +91,7 @@ impl, const D: usize> CircuitBuilder { self.arithmetic_extension(F::ONE, F::ONE, one, a, b) } + /// Returns `(a0+b0, a1+b1)`. pub fn add_two_extension( &mut self, a0: ExtensionTarget, @@ -88,7 +100,7 @@ impl, const D: usize> CircuitBuilder { b1: ExtensionTarget, ) -> (ExtensionTarget, ExtensionTarget) { let one = self.one_extension(); - self.double_arithmetic_extension(F::ONE, F::ONE, one, a0, b0, a1, b1) + self.double_arithmetic_extension(F::ONE, F::ONE, one, a0, b0, one, a1, b1) } pub fn add_ext_algebra( @@ -113,20 +125,39 @@ impl, const D: usize> CircuitBuilder { ExtensionAlgebraTarget(res.try_into().unwrap()) } + /// Add 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s. + pub fn add_three_extension( + &mut self, + a: ExtensionTarget, + b: ExtensionTarget, + c: ExtensionTarget, + ) -> ExtensionTarget { + let one = self.one_extension(); + let gate = self.num_gates(); + let first_out = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_output()); + self.double_arithmetic_extension(F::ONE, F::ONE, one, a, b, one, c, first_out) + .1 + } + + /// Add `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s. pub fn add_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { let zero = self.zero_extension(); let mut terms = terms.to_vec(); - if terms.len().is_odd() { + if terms.is_empty() { + return zero; + } else if terms.len() < 3 { + terms.resize(3, zero); + } else if terms.len().is_even() { terms.push(zero); } - // We maintain two accumulators, one for the sum of even elements, and one for odd elements. - let mut acc0 = zero; - let mut acc1 = zero; + + let mut acc = self.add_three_extension(terms[0], terms[1], terms[2]); + terms.drain(0..3); for chunk in terms.chunks_exact(2) { - (acc0, acc1) = self.add_two_extension(acc0, chunk[0], acc1, chunk[1]); + acc = self.add_three_extension(acc, chunk[0], chunk[1]); } - // We sum both accumulators to get the final result. - self.add_extension(acc0, acc1) + acc } pub fn sub_extension( @@ -146,7 +177,7 @@ impl, const D: usize> CircuitBuilder { b1: ExtensionTarget, ) -> (ExtensionTarget, ExtensionTarget) { let one = self.one_extension(); - self.double_arithmetic_extension(F::ONE, F::NEG_ONE, one, a0, b0, a1, b1) + self.double_arithmetic_extension(F::ONE, F::NEG_ONE, one, a0, b0, one, a1, b1) } pub fn sub_ext_algebra( @@ -184,6 +215,7 @@ impl, const D: usize> CircuitBuilder { zero, zero, zero, + zero, ) .0 } @@ -196,6 +228,18 @@ impl, const D: usize> CircuitBuilder { self.mul_extension_with_const(F::ONE, multiplicand_0, multiplicand_1) } + /// Returns `(a0*b0, a1*b1)`. + pub fn mul_two_extension( + &mut self, + a0: ExtensionTarget, + b0: ExtensionTarget, + a1: ExtensionTarget, + b1: ExtensionTarget, + ) -> (ExtensionTarget, ExtensionTarget) { + let zero = self.zero_extension(); + self.double_arithmetic_extension(F::ONE, F::ZERO, a0, b0, zero, a1, b1, zero) + } + /// Computes `x^2`. pub fn square_extension(&mut self, x: ExtensionTarget) -> ExtensionTarget { self.mul_extension(x, x) @@ -221,12 +265,38 @@ impl, const D: usize> CircuitBuilder { ExtensionAlgebraTarget(res) } + /// Multiply 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s. + pub fn mul_three_extension( + &mut self, + a: ExtensionTarget, + b: ExtensionTarget, + c: ExtensionTarget, + ) -> ExtensionTarget { + let zero = self.zero_extension(); + let gate = self.num_gates(); + let first_out = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_output()); + self.double_arithmetic_extension(F::ONE, F::ZERO, a, b, zero, c, first_out, zero) + .1 + } + + /// Multiply `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s. pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { - let mut product = self.one_extension(); - for term in terms { - product = self.mul_extension(product, *term); + let one = self.one_extension(); + let mut terms = terms.to_vec(); + if terms.is_empty() { + return one; + } else if terms.len() < 3 { + terms.resize(3, one); + } else if terms.len().is_even() { + terms.push(one); } - product + let mut acc = self.mul_three_extension(terms[0], terms[1], terms[2]); + terms.drain(0..3); + for chunk in terms.chunks_exact(2) { + acc = self.mul_three_extension(acc, chunk[0], chunk[1]); + } + acc } /// Like `mul_add`, but for `ExtensionTarget`s. @@ -292,7 +362,7 @@ impl, const D: usize> CircuitBuilder { /// Exponentiate `base` to the power of `2^power_log`. // TODO: Test - pub fn exp_power_of_2( + pub fn exp_power_of_2_extension( &mut self, mut base: ExtensionTarget, power_log: usize, @@ -384,11 +454,11 @@ impl, const D: usize> SimpleGenerator for QuotientGeneratorE deps } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let num = witness.get_extension_target(self.numerator); let dem = witness.get_extension_target(self.denominator); let quotient = num / dem; - PartialWitness::singleton_extension_target(self.quotient, quotient) + GeneratedValues::singleton_extension_target(self.quotient, quotient) } } @@ -443,6 +513,43 @@ mod tests { use crate::verifier::verify; use crate::witness::PartialWitness; + #[test] + fn test_mul_many() -> Result<()> { + type F = CrandallField; + type FF = QuarticCrandallField; + const D: usize = 4; + + let config = CircuitConfig::large_config(); + + let mut builder = CircuitBuilder::::new(config); + let mut pw = PartialWitness::new(); + + let vs = FF::rand_vec(3); + let ts = builder.add_virtual_extension_targets(3); + for (&v, &t) in vs.iter().zip(&ts) { + pw.set_extension_target(t, v); + } + let mul0 = builder.mul_many_extension(&ts); + let mul1 = { + let mut acc = builder.one_extension(); + for &t in &ts { + acc = builder.mul_extension(acc, t); + } + acc + }; + let mul2 = builder.mul_three_extension(ts[0], ts[1], ts[2]); + let mul3 = builder.constant_extension(vs.into_iter().product()); + + builder.assert_equal_extension(mul0, mul1); + builder.assert_equal_extension(mul1, mul2); + builder.assert_equal_extension(mul2, mul3); + + let data = builder.build(); + let proof = data.prove(pw)?; + + verify(proof, &data.verifier_only, &data.common) + } + #[test] fn test_div_extension() -> Result<()> { type F = CrandallField; diff --git a/src/gadgets/range_check.rs b/src/gadgets/range_check.rs index 7fd35efc..c0848af8 100644 --- a/src/gadgets/range_check.rs +++ b/src/gadgets/range_check.rs @@ -2,7 +2,7 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::base_sum::BaseSumGate; -use crate::generator::SimpleGenerator; +use crate::generator::{GeneratedValues, SimpleGenerator}; use crate::target::Target; use crate::witness::PartialWitness; @@ -49,12 +49,12 @@ impl SimpleGenerator for LowHighGenerator { vec![self.integer] } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let integer_value = witness.get_target(self.integer).to_canonical_u64(); let low = integer_value & ((1 << self.n_log) - 1); let high = integer_value >> self.n_log; - let mut result = PartialWitness::new(); + let mut result = GeneratedValues::with_capacity(2); result.set_target(self.low, F::from_canonical_u64(low)); result.set_target(self.high, F::from_canonical_u64(high)); diff --git a/src/gadgets/split_join.rs b/src/gadgets/split_join.rs index 3a2c27f4..9cc6ab7c 100644 --- a/src/gadgets/split_join.rs +++ b/src/gadgets/split_join.rs @@ -2,7 +2,7 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::base_sum::BaseSumGate; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::target::Target; use crate::util::ceil_div_usize; use crate::wire::Wire; @@ -110,10 +110,10 @@ impl SimpleGenerator for SplitGenerator { vec![self.integer] } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let mut integer_value = witness.get_target(self.integer).to_canonical_u64(); - let mut result = PartialWitness::new(); + let mut result = GeneratedValues::with_capacity(self.bits.len()); for &b in &self.bits { let b_value = integer_value & 1; result.set_target(b, F::from_canonical_u64(b_value)); @@ -141,10 +141,10 @@ impl SimpleGenerator for WireSplitGenerator { vec![self.integer] } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let mut integer_value = witness.get_target(self.integer).to_canonical_u64(); - let mut result = PartialWitness::new(); + let mut result = GeneratedValues::with_capacity(self.gates.len()); for &gate in &self.gates { let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM); result.set_target( diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 39baa226..cf39e09b 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -4,7 +4,7 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::witness::PartialWitness; @@ -18,27 +18,30 @@ impl ArithmeticExtensionGate { GateRef::new(ArithmeticExtensionGate) } - pub fn wires_fixed_multiplicand() -> Range { + pub fn wires_first_multiplicand_0() -> Range { 0..D } - pub fn wires_multiplicand_0() -> Range { + pub fn wires_first_multiplicand_1() -> Range { D..2 * D } - pub fn wires_addend_0() -> Range { + pub fn wires_first_addend() -> Range { 2 * D..3 * D } - pub fn wires_multiplicand_1() -> Range { + pub fn wires_second_multiplicand_0() -> Range { 3 * D..4 * D } - pub fn wires_addend_1() -> Range { + pub fn wires_second_multiplicand_1() -> Range { 4 * D..5 * D } - pub fn wires_output_0() -> Range { + pub fn wires_second_addend() -> Range { 5 * D..6 * D } - pub fn wires_output_1() -> Range { + pub fn wires_first_output() -> Range { 6 * D..7 * D } + pub fn wires_second_output() -> Range { + 7 * D..8 * D + } } impl, const D: usize> Gate for ArithmeticExtensionGate { @@ -50,21 +53,24 @@ impl, const D: usize> Gate for ArithmeticExtensionGate let const_0 = vars.local_constants[0]; let const_1 = vars.local_constants[1]; - let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand()); - let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); - let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0()); - let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); - let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1()); - let output_0 = vars.get_local_ext_algebra(Self::wires_output_0()); - let output_1 = vars.get_local_ext_algebra(Self::wires_output_1()); + let first_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_0()); + let first_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_1()); + let first_addend = vars.get_local_ext_algebra(Self::wires_first_addend()); + let second_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_0()); + let second_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_1()); + let second_addend = vars.get_local_ext_algebra(Self::wires_second_addend()); + let first_output = vars.get_local_ext_algebra(Self::wires_first_output()); + let second_output = vars.get_local_ext_algebra(Self::wires_second_output()); - let computed_output_0 = - fixed_multiplicand * multiplicand_0 * const_0.into() + addend_0 * const_1.into(); - let computed_output_1 = - fixed_multiplicand * multiplicand_1 * const_0.into() + addend_1 * const_1.into(); + let first_computed_output = first_multiplicand_0 * first_multiplicand_1 * const_0.into() + + first_addend * const_1.into(); + let second_computed_output = second_multiplicand_0 * second_multiplicand_1 * const_0.into() + + second_addend * const_1.into(); - let mut constraints = (output_0 - computed_output_0).to_basefield_array().to_vec(); - constraints.extend((output_1 - computed_output_1).to_basefield_array()); + let mut constraints = (first_output - first_computed_output) + .to_basefield_array() + .to_vec(); + constraints.extend((second_output - second_computed_output).to_basefield_array()); constraints } @@ -76,26 +82,32 @@ impl, const D: usize> Gate for ArithmeticExtensionGate let const_0 = vars.local_constants[0]; let const_1 = vars.local_constants[1]; - let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand()); - let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); - let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0()); - let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); - let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1()); - let output_0 = vars.get_local_ext_algebra(Self::wires_output_0()); - let output_1 = vars.get_local_ext_algebra(Self::wires_output_1()); + let first_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_0()); + let first_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_1()); + let first_addend = vars.get_local_ext_algebra(Self::wires_first_addend()); + let second_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_0()); + let second_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_1()); + let second_addend = vars.get_local_ext_algebra(Self::wires_second_addend()); + let first_output = vars.get_local_ext_algebra(Self::wires_first_output()); + let second_output = vars.get_local_ext_algebra(Self::wires_second_output()); - let computed_output_0 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_0); - let computed_output_0 = builder.scalar_mul_ext_algebra(const_0, computed_output_0); - let scaled_addend_0 = builder.scalar_mul_ext_algebra(const_1, addend_0); - let computed_output_0 = builder.add_ext_algebra(computed_output_0, scaled_addend_0); + let first_computed_output = + builder.mul_ext_algebra(first_multiplicand_0, first_multiplicand_1); + let first_computed_output = builder.scalar_mul_ext_algebra(const_0, first_computed_output); + let first_scaled_addend = builder.scalar_mul_ext_algebra(const_1, first_addend); + let first_computed_output = + builder.add_ext_algebra(first_computed_output, first_scaled_addend); - let computed_output_1 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_1); - let computed_output_1 = builder.scalar_mul_ext_algebra(const_0, computed_output_1); - let scaled_addend_1 = builder.scalar_mul_ext_algebra(const_1, addend_1); - let computed_output_1 = builder.add_ext_algebra(computed_output_1, scaled_addend_1); + let second_computed_output = + builder.mul_ext_algebra(second_multiplicand_0, second_multiplicand_1); + let second_computed_output = + builder.scalar_mul_ext_algebra(const_0, second_computed_output); + let second_scaled_addend = builder.scalar_mul_ext_algebra(const_1, second_addend); + let second_computed_output = + builder.add_ext_algebra(second_computed_output, second_scaled_addend); - let diff_0 = builder.sub_ext_algebra(output_0, computed_output_0); - let diff_1 = builder.sub_ext_algebra(output_1, computed_output_1); + let diff_0 = builder.sub_ext_algebra(first_output, first_computed_output); + let diff_1 = builder.sub_ext_algebra(second_output, second_computed_output); let mut constraints = diff_0.to_ext_target_array().to_vec(); constraints.extend(diff_1.to_ext_target_array()); constraints @@ -120,7 +132,7 @@ impl, const D: usize> Gate for ArithmeticExtensionGate } fn num_wires(&self) -> usize { - 7 * D + 8 * D } fn num_constants(&self) -> usize { @@ -150,67 +162,67 @@ struct ArithmeticExtensionGenerator1, const D: usize> { impl, const D: usize> SimpleGenerator for ArithmeticExtensionGenerator0 { fn dependencies(&self) -> Vec { - ArithmeticExtensionGate::::wires_fixed_multiplicand() - .chain(ArithmeticExtensionGate::::wires_multiplicand_0()) - .chain(ArithmeticExtensionGate::::wires_addend_0()) + ArithmeticExtensionGate::::wires_first_multiplicand_0() + .chain(ArithmeticExtensionGate::::wires_first_multiplicand_1()) + .chain(ArithmeticExtensionGate::::wires_first_addend()) .map(|i| Target::wire(self.gate_index, i)) .collect() } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let extract_extension = |range: Range| -> F::Extension { let t = ExtensionTarget::from_range(self.gate_index, range); witness.get_extension_target(t) }; - let fixed_multiplicand = - extract_extension(ArithmeticExtensionGate::::wires_fixed_multiplicand()); let multiplicand_0 = - extract_extension(ArithmeticExtensionGate::::wires_multiplicand_0()); - let addend_0 = extract_extension(ArithmeticExtensionGate::::wires_addend_0()); + extract_extension(ArithmeticExtensionGate::::wires_first_multiplicand_0()); + let multiplicand_1 = + extract_extension(ArithmeticExtensionGate::::wires_first_multiplicand_1()); + let addend = extract_extension(ArithmeticExtensionGate::::wires_first_addend()); - let output_target_0 = ExtensionTarget::from_range( + let output_target = ExtensionTarget::from_range( self.gate_index, - ArithmeticExtensionGate::::wires_output_0(), + ArithmeticExtensionGate::::wires_first_output(), ); - let computed_output_0 = fixed_multiplicand * multiplicand_0 * self.const_0.into() - + addend_0 * self.const_1.into(); + let computed_output = + multiplicand_0 * multiplicand_1 * self.const_0.into() + addend * self.const_1.into(); - PartialWitness::singleton_extension_target(output_target_0, computed_output_0) + GeneratedValues::singleton_extension_target(output_target, computed_output) } } impl, const D: usize> SimpleGenerator for ArithmeticExtensionGenerator1 { fn dependencies(&self) -> Vec { - ArithmeticExtensionGate::::wires_fixed_multiplicand() - .chain(ArithmeticExtensionGate::::wires_multiplicand_1()) - .chain(ArithmeticExtensionGate::::wires_addend_1()) + ArithmeticExtensionGate::::wires_second_multiplicand_0() + .chain(ArithmeticExtensionGate::::wires_second_multiplicand_1()) + .chain(ArithmeticExtensionGate::::wires_second_addend()) .map(|i| Target::wire(self.gate_index, i)) .collect() } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let extract_extension = |range: Range| -> F::Extension { let t = ExtensionTarget::from_range(self.gate_index, range); witness.get_extension_target(t) }; - let fixed_multiplicand = - extract_extension(ArithmeticExtensionGate::::wires_fixed_multiplicand()); + let multiplicand_0 = + extract_extension(ArithmeticExtensionGate::::wires_second_multiplicand_0()); let multiplicand_1 = - extract_extension(ArithmeticExtensionGate::::wires_multiplicand_1()); - let addend_1 = extract_extension(ArithmeticExtensionGate::::wires_addend_1()); + extract_extension(ArithmeticExtensionGate::::wires_second_multiplicand_1()); + let addend = extract_extension(ArithmeticExtensionGate::::wires_second_addend()); - let output_target_1 = ExtensionTarget::from_range( + let output_target = ExtensionTarget::from_range( self.gate_index, - ArithmeticExtensionGate::::wires_output_1(), + ArithmeticExtensionGate::::wires_second_output(), ); - let computed_output_1 = fixed_multiplicand * multiplicand_1 * self.const_0.into() - + addend_1 * self.const_1.into(); + let computed_output = + multiplicand_0 * multiplicand_1 * self.const_0.into() + addend * self.const_1.into(); - PartialWitness::singleton_extension_target(output_target_1, computed_output_1) + GeneratedValues::singleton_extension_target(output_target, computed_output) } } diff --git a/src/gates/base_sum.rs b/src/gates/base_sum.rs index 8f453d8e..8ad189ee 100644 --- a/src/gates/base_sum.rs +++ b/src/gates/base_sum.rs @@ -5,7 +5,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::plonk_common::{reduce_with_powers, reduce_with_powers_recursive}; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; @@ -130,7 +130,7 @@ impl SimpleGenerator for BaseSplitGenerator { vec![Target::wire(self.gate_index, BaseSumGate::::WIRE_SUM)] } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let sum_value = witness .get_target(Target::wire(self.gate_index, BaseSumGate::::WIRE_SUM)) .to_canonical_u64() as usize; @@ -155,7 +155,7 @@ impl SimpleGenerator for BaseSplitGenerator { .iter() .fold(F::ZERO, |acc, &x| acc * b_field + x); - let mut result = PartialWitness::new(); + let mut result = GeneratedValues::with_capacity(self.num_limbs + 1); result.set_target( Target::wire(self.gate_index, BaseSumGate::::WIRE_REVERSED_SUM), reversed_sum, diff --git a/src/gates/constant.rs b/src/gates/constant.rs index 3845031a..4049d058 100644 --- a/src/gates/constant.rs +++ b/src/gates/constant.rs @@ -3,7 +3,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::wire::Wire; @@ -83,12 +83,12 @@ impl SimpleGenerator for ConstantGenerator { Vec::new() } - fn run_once(&self, _witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, _witness: &PartialWitness) -> GeneratedValues { let wire = Wire { gate: self.gate_index, input: ConstantGate::WIRE_OUTPUT, }; - PartialWitness::singleton_target(Target::Wire(wire), self.constant) + GeneratedValues::singleton_target(Target::Wire(wire), self.constant) } } diff --git a/src/gates/gate.rs b/src/gates/gate.rs index 9b45817e..17059332 100644 --- a/src/gates/gate.rs +++ b/src/gates/gate.rs @@ -31,9 +31,11 @@ pub trait Gate, const D: usize>: 'static + Send + Sync { .iter() .map(|w| F::Extension::from_basefield(*w)) .collect::>(); + let public_inputs_hash = &vars_base.public_inputs_hash; let vars = EvaluationVars { local_constants, local_wires, + public_inputs_hash, }; let values = self.eval_unfiltered(vars); diff --git a/src/gates/gate_testing.rs b/src/gates/gate_testing.rs index cb25acb8..2f7a3020 100644 --- a/src/gates/gate_testing.rs +++ b/src/gates/gate_testing.rs @@ -2,6 +2,7 @@ use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::gate::GateRef; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; +use crate::proof::Hash; use crate::util::{log2_ceil, transpose}; use crate::vars::EvaluationVars; @@ -17,6 +18,7 @@ pub(crate) fn test_low_degree, const D: usize>(gate: GateRef(gate.num_wires(), rate_bits); let constant_ldes = random_low_degree_matrix::(gate.num_constants(), rate_bits); assert_eq!(wire_ldes.len(), constant_ldes.len()); + let public_inputs_hash = &Hash::rand(); let constraint_evals = wire_ldes .iter() @@ -24,6 +26,7 @@ pub(crate) fn test_low_degree, const D: usize>(gate: GateRef>(); diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 768a5693..0404884b 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -5,7 +5,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::gmimc::gmimc_automatic_constants; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; @@ -239,8 +239,8 @@ impl, const D: usize, const R: usize> SimpleGenerator .collect() } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { - let mut result = PartialWitness::new(); + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { + let mut result = GeneratedValues::with_capacity(R + W + 1); let mut state = (0..W) .map(|i| { @@ -326,6 +326,7 @@ mod tests { use crate::gates::gmimc::{GMiMCGate, W}; use crate::generator::generate_partial_witness; use crate::gmimc::gmimc_permute_naive; + use crate::proof::Hash; use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::verifier::verify; use crate::wire::Wire; @@ -416,9 +417,11 @@ mod tests { let gate = Gate::with_constants(constants); let wires = FF::rand_vec(Gate::end()); + let public_inputs_hash = &Hash::rand(); let vars = EvaluationVars { local_constants: &[], local_wires: &wires, + public_inputs_hash, }; let ev = gate.0.eval_unfiltered(vars); @@ -427,9 +430,14 @@ mod tests { for i in 0..Gate::end() { pw.set_extension_target(wires_t[i], wires[i]); } + + let public_inputs_hash_t = builder.add_virtual_hash(); + pw.set_hash_target(public_inputs_hash_t, *public_inputs_hash); + let vars_t = EvaluationTargets { local_constants: &[], local_wires: &wires_t, + public_inputs_hash: &public_inputs_hash_t, }; let ev_t = gate.0.eval_unfiltered_recursively(&mut builder, vars_t); diff --git a/src/gates/insertion.rs b/src/gates/insertion.rs index 64301347..1bc0b454 100644 --- a/src/gates/insertion.rs +++ b/src/gates/insertion.rs @@ -7,7 +7,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::wire::Wire; @@ -218,7 +218,7 @@ impl, const D: usize> SimpleGenerator for InsertionGenerator deps } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let local_wire = |input| Wire { gate: self.gate_index, input, @@ -264,7 +264,7 @@ impl, const D: usize> SimpleGenerator for InsertionGenerator let mut insert_here_vals = vec![F::ZERO; vec_size]; insert_here_vals.insert(insertion_index, F::ONE); - let mut result = PartialWitness::::new(); + let mut result = GeneratedValues::::with_capacity((vec_size + 1) * (D + 2)); for i in 0..=vec_size { let output_wires = self.gate.wires_output_list_item(i).map(local_wire); result.set_ext_wires(output_wires, new_vec[i]); @@ -288,6 +288,7 @@ mod tests { use crate::gates::gate::Gate; use crate::gates::gate_testing::test_low_degree; use crate::gates::insertion::InsertionGate; + use crate::proof::Hash; use crate::vars::EvaluationVars; #[test] @@ -366,6 +367,7 @@ mod tests { let vars = EvaluationVars { local_constants: &[], local_wires: &get_wires(orig_vec, insertion_index, element_to_insert), + public_inputs_hash: &Hash::rand(), }; assert!( diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index ccf8d57d..17d34e3a 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -9,7 +9,7 @@ use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::interpolation::interpolant; use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::wire::Wire; @@ -216,7 +216,7 @@ impl, const D: usize> SimpleGenerator for InterpolationGener deps } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let n = self.gate.num_points; let local_wire = |input| Wire { @@ -244,7 +244,7 @@ impl, const D: usize> SimpleGenerator for InterpolationGener .collect::>(); let interpolant = interpolant(&points); - let mut result = PartialWitness::::new(); + let mut result = GeneratedValues::::with_capacity(D * (self.gate.num_points + 1)); for (i, &coeff) in interpolant.coeffs.iter().enumerate() { let wires = self.gate.wires_coeff(i).map(local_wire); result.set_ext_wires(wires, coeff); @@ -271,6 +271,7 @@ mod tests { use crate::gates::gate_testing::test_low_degree; use crate::gates::interpolation::InterpolationGate; use crate::polynomial::polynomial::PolynomialCoeffs; + use crate::proof::Hash; use crate::vars::EvaluationVars; #[test] @@ -352,6 +353,7 @@ mod tests { let vars = EvaluationVars { local_constants: &[], local_wires: &get_wires(2, coeffs, points, eval_point), + public_inputs_hash: &Hash::rand(), }; assert!( diff --git a/src/gates/mod.rs b/src/gates/mod.rs index 6b3f05fa..441383ec 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -10,6 +10,7 @@ pub mod gmimc; pub mod insertion; pub mod interpolation; pub(crate) mod noop; +pub(crate) mod public_input; #[cfg(test)] mod gate_testing; diff --git a/src/gates/public_input.rs b/src/gates/public_input.rs new file mode 100644 index 00000000..a86b78d5 --- /dev/null +++ b/src/gates/public_input.rs @@ -0,0 +1,84 @@ +use std::ops::Range; + +use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::gates::gate::{Gate, GateRef}; +use crate::generator::WitnessGenerator; +use crate::vars::{EvaluationTargets, EvaluationVars}; + +/// A gate whose first four wires will be equal to a hash of public inputs. +pub struct PublicInputGate; + +impl PublicInputGate { + pub fn get, const D: usize>() -> GateRef { + GateRef::new(PublicInputGate) + } + + pub fn wires_public_inputs_hash() -> Range { + 0..4 + } +} + +impl, const D: usize> Gate for PublicInputGate { + fn id(&self) -> String { + "PublicInputGate".into() + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + Self::wires_public_inputs_hash() + .zip(vars.public_inputs_hash.elements) + .map(|(wire, hash_part)| vars.local_wires[wire] - hash_part.into()) + .collect() + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + Self::wires_public_inputs_hash() + .zip(vars.public_inputs_hash.elements) + .map(|(wire, hash_part)| { + let hash_part_ext = builder.convert_to_ext(hash_part); + builder.sub_extension(vars.local_wires[wire], hash_part_ext) + }) + .collect() + } + + fn generators( + &self, + _gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + Vec::new() + } + + fn num_wires(&self) -> usize { + 4 + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 1 + } + + fn num_constraints(&self) -> usize { + 4 + } +} + +#[cfg(test)] +mod tests { + use crate::field::crandall_field::CrandallField; + use crate::gates::gate_testing::test_low_degree; + use crate::gates::public_input::PublicInputGate; + + #[test] + fn low_degree() { + test_low_degree(PublicInputGate::get::()) + } +} diff --git a/src/generator.rs b/src/generator.rs index a47c5267..a7359a7d 100644 --- a/src/generator.rs +++ b/src/generator.rs @@ -1,8 +1,13 @@ use std::collections::{HashMap, HashSet}; +use std::convert::identity; use std::fmt::Debug; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; +use crate::proof::{Hash, HashTarget}; use crate::target::Target; +use crate::wire::Wire; use crate::witness::PartialWitness; /// Given a `PartialWitness` that has only inputs set, populates the rest of the witness using the @@ -27,7 +32,7 @@ pub(crate) fn generate_partial_witness( let mut pending_generator_indices: HashSet<_> = (0..generators.len()).collect(); // We also track a list of "expired" generators which have already returned false. - let mut expired_generator_indices = HashSet::new(); + let mut generator_is_expired = vec![false; generators.len()]; // Keep running generators until no generators are queued. while !pending_generator_indices.is_empty() { @@ -36,15 +41,15 @@ pub(crate) fn generate_partial_witness( for &generator_idx in &pending_generator_indices { let (result, finished) = generators[generator_idx].run(&witness); if finished { - expired_generator_indices.insert(generator_idx); + generator_is_expired[generator_idx] = true; } // Enqueue unfinished generators that were watching one of the newly populated targets. - for watch in result.target_values.keys() { + for (watch, _) in &result.target_values { if let Some(watching_generator_indices) = generator_indices_by_watches.get(watch) { - for watching_generator_idx in watching_generator_indices { - if !expired_generator_indices.contains(watching_generator_idx) { - next_pending_generator_indices.insert(*watching_generator_idx); + for &watching_generator_idx in watching_generator_indices { + if !generator_is_expired[watching_generator_idx] { + next_pending_generator_indices.insert(watching_generator_idx); } } } @@ -55,9 +60,9 @@ pub(crate) fn generate_partial_witness( pending_generator_indices = next_pending_generator_indices; } - assert_eq!( - expired_generator_indices.len(), - generators.len(), + + assert!( + generator_is_expired.into_iter().all(identity), "Some generators weren't run." ); } @@ -72,14 +77,101 @@ pub trait WitnessGenerator: 'static + Send + Sync { /// flag indicating whether the generator is finished. If the flag is true, the generator will /// never be run again, otherwise it will be queued for another run next time a target in its /// watch list is populated. - fn run(&self, witness: &PartialWitness) -> (PartialWitness, bool); + fn run(&self, witness: &PartialWitness) -> (GeneratedValues, bool); +} + +/// Values generated by a generator invocation. +pub struct GeneratedValues { + pub(crate) target_values: Vec<(Target, F)>, +} + +impl From> for GeneratedValues { + fn from(target_values: Vec<(Target, F)>) -> Self { + Self { target_values } + } +} + +impl GeneratedValues { + pub fn with_capacity(capacity: usize) -> Self { + Vec::with_capacity(capacity).into() + } + + pub fn empty() -> Self { + Vec::new().into() + } + + pub fn singleton_wire(wire: Wire, value: F) -> Self { + Self::singleton_target(Target::Wire(wire), value) + } + + pub fn singleton_target(target: Target, value: F) -> Self { + vec![(target, value)].into() + } + + pub fn singleton_extension_target( + et: ExtensionTarget, + value: F::Extension, + ) -> Self + where + F: Extendable, + { + let mut witness = Self::with_capacity(D); + witness.set_extension_target(et, value); + witness + } + + pub fn set_target(&mut self, target: Target, value: F) { + self.target_values.push((target, value)) + } + + pub fn set_hash_target(&mut self, ht: HashTarget, value: Hash) { + ht.elements + .iter() + .zip(value.elements) + .for_each(|(&t, x)| self.set_target(t, x)); + } + + pub fn set_extension_target( + &mut self, + et: ExtensionTarget, + value: F::Extension, + ) where + F: Extendable, + { + let limbs = value.to_basefield_array(); + (0..D).for_each(|i| { + self.set_target(et.0[i], limbs[i]); + }); + } + + pub fn set_wire(&mut self, wire: Wire, value: F) { + self.set_target(Target::Wire(wire), value) + } + + pub fn set_wires(&mut self, wires: W, values: &[F]) + where + W: IntoIterator, + { + // If we used itertools, we could use zip_eq for extra safety. + for (wire, &value) in wires.into_iter().zip(values) { + self.set_wire(wire, value); + } + } + + pub fn set_ext_wires(&mut self, wires: W, value: F::Extension) + where + F: Extendable, + W: IntoIterator, + { + self.set_wires(wires, &value.to_basefield_array()); + } } /// A generator which runs once after a list of dependencies is present in the witness. pub trait SimpleGenerator: 'static + Send + Sync { fn dependencies(&self) -> Vec; - fn run_once(&self, witness: &PartialWitness) -> PartialWitness; + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues; } impl> WitnessGenerator for SG { @@ -87,11 +179,11 @@ impl> WitnessGenerator for SG { self.dependencies() } - fn run(&self, witness: &PartialWitness) -> (PartialWitness, bool) { + fn run(&self, witness: &PartialWitness) -> (GeneratedValues, bool) { if witness.contains_all(&self.dependencies()) { (self.run_once(witness), true) } else { - (PartialWitness::new(), false) + (GeneratedValues::empty(), false) } } } @@ -108,9 +200,9 @@ impl SimpleGenerator for CopyGenerator { vec![self.src] } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let value = witness.get_target(self.src); - PartialWitness::singleton_target(self.dst, value) + GeneratedValues::singleton_target(self.dst, value) } } @@ -124,10 +216,10 @@ impl SimpleGenerator for RandomValueGenerator { Vec::new() } - fn run_once(&self, _witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, _witness: &PartialWitness) -> GeneratedValues { let random_value = F::rand(); - PartialWitness::singleton_target(self.target, random_value) + GeneratedValues::singleton_target(self.target, random_value) } } @@ -142,7 +234,7 @@ impl SimpleGenerator for NonzeroTestGenerator { vec![self.to_test] } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let to_test_value = witness.get_target(self.to_test); let dummy_value = if to_test_value == F::ZERO { @@ -151,6 +243,6 @@ impl SimpleGenerator for NonzeroTestGenerator { to_test_value.inverse() }; - PartialWitness::singleton_target(self.dummy, dummy_value) + GeneratedValues::singleton_target(self.dummy, dummy_value) } } diff --git a/src/permutation_argument.rs b/src/permutation_argument.rs index 3c98c801..7bad6d8d 100644 --- a/src/permutation_argument.rs +++ b/src/permutation_argument.rs @@ -90,7 +90,7 @@ impl usize> TargetPartition { } let mut indices = HashMap::new(); - // // Here we keep just the Wire targets, filtering out everything else. + // Here we keep just the Wire targets, filtering out everything else. let partition = partition .into_values() .map(|v| { diff --git a/src/polynomial/commitment.rs b/src/polynomial/commitment.rs index e50839e7..6bbba08d 100644 --- a/src/polynomial/commitment.rs +++ b/src/polynomial/commitment.rs @@ -34,10 +34,7 @@ impl ListPolynomialCommitment { /// Creates a list polynomial commitment for the polynomials interpolating the values in `values`. pub fn new(values: Vec>, rate_bits: usize, blinding: bool) -> Self { let degree = values[0].len(); - let polynomials = values - .par_iter() - .map(|v| v.clone().ifft()) - .collect::>(); + let polynomials = values.par_iter().map(|v| v.ifft()).collect::>(); let lde_values = timed!( Self::lde_values(&polynomials, rate_bits, blinding), "to compute LDE" @@ -92,7 +89,7 @@ impl ListPolynomialCommitment { .par_iter() .map(|p| { assert_eq!(p.len(), degree, "Polynomial degree invalid."); - p.clone().lde(rate_bits).coset_fft(F::coset_shift()).values + p.lde(rate_bits).coset_fft(F::coset_shift()).values }) .chain(if blinding { // If blinding, salt with two random elements to each leaf vector. @@ -182,15 +179,15 @@ impl ListPolynomialCommitment { final_poly += zs_quotient; let lde_final_poly = final_poly.lde(config.rate_bits); - let lde_final_values = lde_final_poly.clone().coset_fft(F::coset_shift().into()); + let lde_final_values = lde_final_poly.coset_fft(F::coset_shift().into()); let fri_proof = fri_proof( &commitments .par_iter() .map(|c| &c.merkle_tree) .collect::>(), - &lde_final_poly, - &lde_final_values, + lde_final_poly, + lde_final_values, challenger, &config.fri_config, ); diff --git a/src/polynomial/division.rs b/src/polynomial/division.rs index 50e1f8a6..6fa636ea 100644 --- a/src/polynomial/division.rs +++ b/src/polynomial/division.rs @@ -88,7 +88,7 @@ impl PolynomialCoeffs { let root = F::primitive_root_of_unity(log2_strict(a.len())); // Equals to the evaluation of `a` on `{g.w^i}`. - let mut a_eval = fft(a); + let mut a_eval = fft(&a); // Compute the denominators `1/(g^n.w^(n*i) - 1)` using batch inversion. let denominator_g = g.exp(n as u64); let root_n = root.exp(n as u64); @@ -112,7 +112,7 @@ impl PolynomialCoeffs { *x *= d; }); // `p` is the interpolating polynomial of `a_eval` on `{w^i}`. - let mut p = ifft(a_eval); + let mut p = ifft(&a_eval); // We need to scale it by `g^(-i)` to get the interpolating polynomial of `a_eval` on `{g.w^i}`, // a.k.a `a/Z_H`. let g_inv = g.inverse(); diff --git a/src/polynomial/polynomial.rs b/src/polynomial/polynomial.rs index 5f9bccd9..3e11f200 100644 --- a/src/polynomial/polynomial.rs +++ b/src/polynomial/polynomial.rs @@ -33,12 +33,12 @@ impl PolynomialValues { self.values.len() } - pub fn ifft(self) -> PolynomialCoeffs { + pub fn ifft(&self) -> PolynomialCoeffs { ifft(self) } /// Returns the polynomial whose evaluation on the coset `shift*H` is `self`. - pub fn coset_ifft(self, shift: F) -> PolynomialCoeffs { + pub fn coset_ifft(&self, shift: F) -> PolynomialCoeffs { let mut shifted_coeffs = self.ifft(); shifted_coeffs .coeffs @@ -54,9 +54,9 @@ impl PolynomialValues { polys.into_iter().map(|p| p.lde(rate_bits)).collect() } - pub fn lde(self, rate_bits: usize) -> Self { + pub fn lde(&self, rate_bits: usize) -> Self { let coeffs = ifft(self).lde(rate_bits); - fft_with_options(coeffs, Some(rate_bits), None) + fft_with_options(&coeffs, Some(rate_bits), None) } pub fn degree(&self) -> usize { @@ -66,7 +66,7 @@ impl PolynomialValues { } pub fn degree_plus_one(&self) -> usize { - self.clone().ifft().degree_plus_one() + self.ifft().degree_plus_one() } } @@ -136,7 +136,7 @@ impl PolynomialCoeffs { .fold(F::ZERO, |acc, &c| acc * x + c) } - pub fn lde_multiple(polys: Vec, rate_bits: usize) -> Vec { + pub fn lde_multiple(polys: Vec<&Self>, rate_bits: usize) -> Vec { polys.into_iter().map(|p| p.lde(rate_bits)).collect() } @@ -194,16 +194,16 @@ impl PolynomialCoeffs { Self::new(self.trimmed().coeffs.into_iter().rev().collect()) } - pub fn fft(self) -> PolynomialValues { + pub fn fft(&self) -> PolynomialValues { fft(self) } /// Returns the evaluation of the polynomial on the coset `shift*H`. - pub fn coset_fft(self, shift: F) -> PolynomialValues { + pub fn coset_fft(&self, shift: F) -> PolynomialValues { let modified_poly: Self = shift .powers() - .zip(self.coeffs) - .map(|(r, c)| r * c) + .zip(&self.coeffs) + .map(|(r, &c)| r * c) .collect::>() .into(); modified_poly.fft() @@ -262,8 +262,7 @@ impl Sub for &PolynomialCoeffs { fn sub(self, rhs: Self) -> Self::Output { let len = max(self.len(), rhs.len()); - let mut coeffs = self.coeffs.clone(); - coeffs.resize(len, F::ZERO); + let mut coeffs = self.padded(len).coeffs; for (i, &c) in rhs.coeffs.iter().enumerate() { coeffs[i] -= c; } @@ -343,7 +342,7 @@ impl Mul for &PolynomialCoeffs { .zip(b_evals.values) .map(|(pa, pb)| pa * pb) .collect(); - ifft(mul_evals.into()) + ifft(&mul_evals.into()) } } @@ -390,7 +389,7 @@ mod tests { let n = 1 << k; let poly = PolynomialCoeffs::new(F::rand_vec(n)); let shift = F::rand(); - let coset_evals = poly.clone().coset_fft(shift).values; + let coset_evals = poly.coset_fft(shift).values; let generator = F::primitive_root_of_unity(k); let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n) @@ -411,7 +410,7 @@ mod tests { let n = 1 << k; let evals = PolynomialValues::new(F::rand_vec(n)); let shift = F::rand(); - let coeffs = evals.clone().coset_ifft(shift); + let coeffs = evals.coset_ifft(shift); let generator = F::primitive_root_of_unity(k); let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n) diff --git a/src/proof.rs b/src/proof.rs index 71edba03..47be8569 100644 --- a/src/proof.rs +++ b/src/proof.rs @@ -37,6 +37,12 @@ impl Hash { elements: [elements[0], elements[1], elements[2], elements[3]], } } + + pub(crate) fn rand() -> Self { + Self { + elements: [F::rand(), F::rand(), F::rand(), F::rand()], + } + } } /// Represents a ~256 bit hash output. @@ -79,6 +85,13 @@ pub struct Proof, const D: usize> { pub opening_proof: OpeningProof, } +#[derive(Serialize, Deserialize, Clone, Debug)] +#[serde(bound = "")] +pub struct ProofWithPublicInputs, const D: usize> { + pub proof: Proof, + pub public_inputs: Vec, +} + pub struct ProofTarget { pub wires_root: HashTarget, pub plonk_zs_partial_products_root: HashTarget, @@ -87,6 +100,11 @@ pub struct ProofTarget { pub opening_proof: OpeningProofTarget, } +pub struct ProofWithPublicInputsTarget { + pub proof: ProofTarget, + pub public_inputs: Vec, +} + /// Evaluations and Merkle proof produced by the prover in a FRI query step. #[derive(Serialize, Deserialize, Clone, Debug)] #[serde(bound = "")] diff --git a/src/prover.rs b/src/prover.rs index e7c85130..f1cef856 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -7,11 +7,12 @@ use rayon::prelude::*; use crate::circuit_data::{CommonCircuitData, ProverOnlyCircuitData}; use crate::field::extension_field::Extendable; use crate::generator::generate_partial_witness; +use crate::hash::hash_n_to_hash; use crate::plonk_challenger::Challenger; use crate::plonk_common::{PlonkPolynomials, ZeroPolyOnCoset}; use crate::polynomial::commitment::ListPolynomialCommitment; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; -use crate::proof::Proof; +use crate::proof::{Hash, Proof, ProofWithPublicInputs}; use crate::timed; use crate::util::partial_products::partial_products; use crate::util::{log2_ceil, transpose}; @@ -23,7 +24,7 @@ pub(crate) fn prove, const D: usize>( prover_data: &ProverOnlyCircuitData, common_data: &CommonCircuitData, inputs: PartialWitness, -) -> Result> { +) -> Result> { let config = &common_data.config; let num_wires = config.num_wires; let num_challenges = config.num_challenges; @@ -39,6 +40,9 @@ pub(crate) fn prove, const D: usize>( "to generate witness" ); + let public_inputs = partial_witness.get_targets(&prover_data.public_inputs); + let public_inputs_hash = hash_n_to_hash(public_inputs.clone(), true); + // Display the marked targets for debugging purposes. for m in &prover_data.marked_targets { m.display(&partial_witness); @@ -58,7 +62,7 @@ pub(crate) fn prove, const D: usize>( let wires_values: Vec> = timed!( witness .wire_values - .iter() + .par_iter() .map(|column| PolynomialValues::new(column.clone())) .collect(), "to compute wire polynomials" @@ -119,6 +123,7 @@ pub(crate) fn prove, const D: usize>( compute_quotient_polys( common_data, prover_data, + &public_inputs_hash, &wires_commitment, &zs_partial_products_commitment, &betas, @@ -178,12 +183,16 @@ pub(crate) fn prove, const D: usize>( start_proof_gen.elapsed().as_secs_f32() ); - Ok(Proof { + let proof = Proof { wires_root: wires_commitment.merkle_tree.root, plonk_zs_partial_products_root: zs_partial_products_commitment.merkle_tree.root, quotient_polys_root: quotient_polys_commitment.merkle_tree.root, openings, opening_proof, + }; + Ok(ProofWithPublicInputs { + proof, + public_inputs, }) } @@ -284,6 +293,7 @@ fn compute_z, const D: usize>( fn compute_quotient_polys<'a, F: Extendable, const D: usize>( common_data: &CommonCircuitData, prover_data: &'a ProverOnlyCircuitData, + public_inputs_hash: &Hash, wires_commitment: &'a ListPolynomialCommitment, zs_partial_products_commitment: &'a ListPolynomialCommitment, betas: &[F], @@ -337,6 +347,7 @@ fn compute_quotient_polys<'a, F: Extendable, const D: usize>( let vars = EvaluationVarsBase { local_constants, local_wires, + public_inputs_hash, }; let mut quotient_values = eval_vanishing_poly_base( common_data, diff --git a/src/recursive_verifier.rs b/src/recursive_verifier.rs index 51bfe6ba..66b63e7f 100644 --- a/src/recursive_verifier.rs +++ b/src/recursive_verifier.rs @@ -3,7 +3,7 @@ use crate::circuit_data::{CircuitConfig, CommonCircuitData, VerifierCircuitTarge use crate::context; use crate::field::extension_field::Extendable; use crate::plonk_challenger::RecursiveChallenger; -use crate::proof::{HashTarget, ProofTarget}; +use crate::proof::{HashTarget, ProofWithPublicInputsTarget}; use crate::util::scaling::ReducingFactorTarget; use crate::vanishing_poly::eval_vanishing_poly_recursively; use crate::vars::EvaluationTargets; @@ -15,15 +15,21 @@ impl, const D: usize> CircuitBuilder { /// Recursively verifies an inner proof. pub fn add_recursive_verifier( &mut self, - proof: ProofTarget, + proof_with_pis: ProofWithPublicInputsTarget, inner_config: &CircuitConfig, inner_verifier_data: &VerifierCircuitTarget, inner_common_data: &CommonCircuitData, ) { assert!(self.config.num_wires >= MIN_WIRES); assert!(self.config.num_wires >= MIN_ROUTED_WIRES); + let ProofWithPublicInputsTarget { + proof, + public_inputs, + } = proof_with_pis; let one = self.one_extension(); + let public_inputs_hash = &self.hash_n_to_hash(public_inputs, true); + let num_challenges = inner_config.num_challenges; let mut challenger = RecursiveChallenger::new(self); @@ -53,13 +59,14 @@ impl, const D: usize> CircuitBuilder { let vars = EvaluationTargets { local_constants, local_wires, + public_inputs_hash, }; let local_zs = &proof.openings.plonk_zs; let next_zs = &proof.openings.plonk_zs_right; let s_sigmas = &proof.openings.plonk_sigmas; let partial_products = &proof.openings.partial_products; - let zeta_pow_deg = self.exp_power_of_2(zeta, inner_common_data.degree_bits); + let zeta_pow_deg = self.exp_power_of_2_extension(zeta, inner_common_data.degree_bits); let vanishing_polys_zeta = context!( self, "evaluate the vanishing polynomial at our challenge point, zeta.", @@ -89,7 +96,7 @@ impl, const D: usize> CircuitBuilder { { let recombined_quotient = scale.reduce(chunk, self); let computed_vanishing_poly = self.mul_extension(z_h_zeta, recombined_quotient); - self.named_route_extension( + self.named_assert_equal_extension( vanishing_polys_zeta[i], computed_vanishing_poly, format!("Vanishing polynomial == Z_H * quotient, challenge {}", i), @@ -127,7 +134,7 @@ mod tests { use crate::polynomial::commitment::OpeningProofTarget; use crate::proof::{ FriInitialTreeProofTarget, FriProofTarget, FriQueryRoundTarget, FriQueryStepTarget, - OpeningSetTarget, Proof, + OpeningSetTarget, Proof, ProofTarget, ProofWithPublicInputs, }; use crate::verifier::verify; use crate::witness::PartialWitness; @@ -167,9 +174,14 @@ mod tests { // Construct a `ProofTarget` with the same dimensions as `proof`. fn proof_to_proof_target, const D: usize>( - proof: &Proof, + proof_with_pis: &ProofWithPublicInputs, builder: &mut CircuitBuilder, - ) -> ProofTarget { + ) -> ProofWithPublicInputsTarget { + let ProofWithPublicInputs { + proof, + public_inputs, + } = proof_with_pis; + let wires_root = builder.add_virtual_hash(); let plonk_zs_root = builder.add_virtual_hash(); let quotient_polys_root = builder.add_virtual_hash(); @@ -208,21 +220,41 @@ mod tests { }, }; - ProofTarget { + let proof = ProofTarget { wires_root, plonk_zs_partial_products_root: plonk_zs_root, quotient_polys_root, openings, opening_proof, + }; + + let public_inputs = builder.add_virtual_targets(public_inputs.len()); + ProofWithPublicInputsTarget { + proof, + public_inputs, } } // Set the targets in a `ProofTarget` to their corresponding values in a `Proof`. fn set_proof_target, const D: usize>( - proof: &Proof, - pt: &ProofTarget, + proof: &ProofWithPublicInputs, + pt: &ProofWithPublicInputsTarget, pw: &mut PartialWitness, ) { + let ProofWithPublicInputs { + proof, + public_inputs, + } = proof; + let ProofWithPublicInputsTarget { + proof: pt, + public_inputs: pi_targets, + } = pt; + + // Set public inputs. + for (&pi_t, &pi) in pi_targets.iter().zip(public_inputs) { + pw.set_target(pi_t, pi); + } + pw.set_hash_target(pt.wires_root, proof.wires_root); pw.set_hash_target( pt.plonk_zs_partial_products_root, @@ -343,7 +375,7 @@ mod tests { num_query_rounds: 40, }, }; - let (proof, vd, cd) = { + let (proof_with_pis, vd, cd) = { let mut builder = CircuitBuilder::::new(config.clone()); let _two = builder.two(); let _two = builder.hash_n_to_hash(vec![_two], true).elements[0]; @@ -357,12 +389,12 @@ mod tests { data.common, ) }; - verify(proof.clone(), &vd, &cd)?; + verify(proof_with_pis.clone(), &vd, &cd)?; let mut builder = CircuitBuilder::::new(config.clone()); let mut pw = PartialWitness::new(); - let pt = proof_to_proof_target(&proof, &mut builder); - set_proof_target(&proof, &pt, &mut pw); + let pt = proof_to_proof_target(&proof_with_pis, &mut builder); + set_proof_target(&proof_with_pis, &pt, &mut pw); let inner_data = VerifierCircuitTarget { constants_sigmas_root: builder.add_virtual_hash(), diff --git a/src/target.rs b/src/target.rs index 52be8b5a..857a4378 100644 --- a/src/target.rs +++ b/src/target.rs @@ -7,9 +7,6 @@ use crate::wire::Wire; #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] pub enum Target { Wire(Wire), - PublicInput { - index: usize, - }, /// A target that doesn't have any inherent location in the witness (but it can be copied to /// another target that does). This is useful for representing intermediate values in witness /// generation. @@ -26,7 +23,6 @@ impl Target { pub fn is_routable(&self, config: &CircuitConfig) -> bool { match self { Target::Wire(wire) => wire.is_routable(config), - Target::PublicInput { .. } => true, Target::VirtualTarget { .. } => true, } } diff --git a/src/util/mod.rs b/src/util/mod.rs index bfebe058..83a97881 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -51,7 +51,7 @@ pub(crate) fn transpose(matrix: &[Vec]) -> Vec> { } /// Permutes `arr` such that each index is mapped to its reverse in binary. -pub(crate) fn reverse_index_bits(arr: Vec) -> Vec { +pub(crate) fn reverse_index_bits(arr: &[T]) -> Vec { let n = arr.len(); let n_power = log2_strict(n); @@ -99,12 +99,9 @@ mod tests { #[test] fn test_reverse_index_bits() { + assert_eq!(reverse_index_bits(&[10, 20, 30, 40]), vec![10, 30, 20, 40]); assert_eq!( - reverse_index_bits(vec![10, 20, 30, 40]), - vec![10, 30, 20, 40] - ); - assert_eq!( - reverse_index_bits(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), + reverse_index_bits(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), vec![0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15] ); } diff --git a/src/util/scaling.rs b/src/util/scaling.rs index be339784..5cbffb83 100644 --- a/src/util/scaling.rs +++ b/src/util/scaling.rs @@ -16,7 +16,7 @@ use crate::polynomial::polynomial::PolynomialCoeffs; /// This struct abstract away these operations by implementing Horner's method and keeping track /// of the number of multiplications by `a` to compute the scaling factor. /// See https://github.com/mir-protocol/plonky2/pull/69 for more details and discussions. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct ReducingFactor { base: F, count: u64, @@ -79,7 +79,7 @@ impl ReducingFactor { } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct ReducingFactorTarget { base: ExtensionTarget, count: u64, @@ -122,8 +122,10 @@ impl ReducingFactorTarget { // out_0 = alpha acc + pair[0] // acc' = out_1 = alpha out_0 + pair[1] let gate = builder.num_gates(); - let out_0 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); + let out_0 = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_first_output(), + ); acc = builder .double_arithmetic_extension( F::ONE, @@ -131,6 +133,7 @@ impl ReducingFactorTarget { self.base, acc, pair[0], + self.base, out_0, pair[1], ) diff --git a/src/vanishing_poly.rs b/src/vanishing_poly.rs index b82a23c1..540ab86b 100644 --- a/src/vanishing_poly.rs +++ b/src/vanishing_poly.rs @@ -236,10 +236,13 @@ pub fn evaluate_gate_constraints_recursively, const D: usize>( ) -> Vec> { let mut constraints = vec![builder.zero_extension(); num_gate_constraints]; for gate in gates { - let gate_constraints = gate - .gate - .0 - .eval_filtered_recursively(builder, vars, &gate.prefix); + let gate_constraints = context!( + builder, + &format!("evaluate {} constraints", gate.gate.0.id()), + gate.gate + .0 + .eval_filtered_recursively(builder, vars, &gate.prefix) + ); for (i, c) in gate_constraints.into_iter().enumerate() { constraints[i] = builder.add_extension(constraints[i], c); } diff --git a/src/vars.rs b/src/vars.rs index 17d51051..8e98d41f 100644 --- a/src/vars.rs +++ b/src/vars.rs @@ -5,17 +5,20 @@ use crate::field::extension_field::algebra::ExtensionAlgebra; use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::field::extension_field::Extendable; use crate::field::field::Field; +use crate::proof::{Hash, HashTarget}; #[derive(Debug, Copy, Clone)] pub struct EvaluationVars<'a, F: Extendable, const D: usize> { pub(crate) local_constants: &'a [F::Extension], pub(crate) local_wires: &'a [F::Extension], + pub(crate) public_inputs_hash: &'a Hash, } #[derive(Debug, Copy, Clone)] pub struct EvaluationVarsBase<'a, F: Field> { pub(crate) local_constants: &'a [F], pub(crate) local_wires: &'a [F], + pub(crate) public_inputs_hash: &'a Hash, } impl<'a, F: Extendable, const D: usize> EvaluationVars<'a, F, D> { @@ -49,6 +52,7 @@ impl<'a, const D: usize> EvaluationTargets<'a, D> { pub struct EvaluationTargets<'a, const D: usize> { pub(crate) local_constants: &'a [ExtensionTarget], pub(crate) local_wires: &'a [ExtensionTarget], + pub(crate) public_inputs_hash: &'a HashTarget, } impl<'a, const D: usize> EvaluationTargets<'a, D> { diff --git a/src/verifier.rs b/src/verifier.rs index 6b4df627..878b630a 100644 --- a/src/verifier.rs +++ b/src/verifier.rs @@ -3,20 +3,27 @@ use anyhow::{ensure, Result}; use crate::circuit_data::{CommonCircuitData, VerifierOnlyCircuitData}; use crate::field::extension_field::Extendable; use crate::field::field::Field; +use crate::hash::hash_n_to_hash; use crate::plonk_challenger::Challenger; use crate::plonk_common::reduce_with_powers; -use crate::proof::Proof; +use crate::proof::ProofWithPublicInputs; use crate::vanishing_poly::eval_vanishing_poly; use crate::vars::EvaluationVars; pub(crate) fn verify, const D: usize>( - proof: Proof, + proof_with_pis: ProofWithPublicInputs, verifier_data: &VerifierOnlyCircuitData, common_data: &CommonCircuitData, ) -> Result<()> { + let ProofWithPublicInputs { + proof, + public_inputs, + } = proof_with_pis; let config = &common_data.config; let num_challenges = config.num_challenges; + let public_inputs_hash = &hash_n_to_hash(public_inputs, true); + let mut challenger = Challenger::new(); // Observe the instance. // TODO: Need to include public inputs as well. @@ -37,6 +44,7 @@ pub(crate) fn verify, const D: usize>( let vars = EvaluationVars { local_constants, local_wires, + public_inputs_hash, }; let local_zs = &proof.openings.plonk_zs; let next_zs = &proof.openings.plonk_zs_right; diff --git a/src/witness.rs b/src/witness.rs index 25885ba3..ce4a95af 100644 --- a/src/witness.rs +++ b/src/witness.rs @@ -8,6 +8,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; use crate::gates::gate::GateInstance; +use crate::generator::GeneratedValues; use crate::proof::{Hash, HashTarget}; use crate::target::Target; use crate::wire::Wire; @@ -35,28 +36,6 @@ impl PartialWitness { } } - pub fn singleton_wire(wire: Wire, value: F) -> Self { - Self::singleton_target(Target::Wire(wire), value) - } - - pub fn singleton_target(target: Target, value: F) -> Self { - let mut witness = PartialWitness::new(); - witness.set_target(target, value); - witness - } - - pub fn singleton_extension_target( - et: ExtensionTarget, - value: F::Extension, - ) -> Self - where - F: Extendable, - { - let mut witness = PartialWitness::new(); - witness.set_extension_target(et, value); - witness - } - pub fn is_empty(&self) -> bool { self.target_values.is_empty() } @@ -157,7 +136,7 @@ impl PartialWitness { self.set_wires(wires, &value.to_basefield_array()); } - pub fn extend(&mut self, other: PartialWitness) { + pub fn extend(&mut self, other: GeneratedValues) { for (target, value) in other.target_values { self.set_target(target, value); } @@ -193,7 +172,6 @@ impl PartialWitness { gate, gate_instances[*gate].gate_type.0.id() ), - Target::PublicInput { index } => format!("{}-th public input", index), Target::VirtualTarget { index } => format!("{}-th virtual target", index), } };