diff --git a/src/bin/bench_recursion.rs b/src/bin/bench_recursion.rs index 8867415d..8e798f72 100644 --- a/src/bin/bench_recursion.rs +++ b/src/bin/bench_recursion.rs @@ -50,8 +50,8 @@ fn bench_prove, const D: usize>() -> Result<()> { let circuit = builder.build(); let inputs = PartialWitness::new(); - let proof = circuit.prove(inputs)?; - let proof_bytes = serde_cbor::to_vec(&proof).unwrap(); + let proof_with_pis = circuit.prove(inputs)?; + let proof_bytes = serde_cbor::to_vec(&proof_with_pis).unwrap(); info!("Proof length: {} bytes", proof_bytes.len()); - circuit.verify(proof) + circuit.verify(proof_with_pis) } diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index 2cfeb720..2e12ce0f 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -17,6 +17,7 @@ use crate::gates::constant::ConstantGate; use crate::gates::gate::{GateInstance, GateRef, PrefixedGate}; use crate::gates::gate_tree::Tree; use crate::gates::noop::NoopGate; +use crate::gates::public_input::PublicInputGate; use crate::generator::{CopyGenerator, RandomValueGenerator, WitnessGenerator}; use crate::hash::hash_n_to_hash; use crate::permutation_argument::TargetPartition; @@ -39,8 +40,8 @@ pub struct CircuitBuilder, const D: usize> { /// The concrete placement of each gate. gate_instances: Vec>, - /// The next available index for a public input. - public_input_index: usize, + /// Targets to be made public. + public_inputs: Vec, /// The next available index for a `VirtualTarget`. virtual_target_index: usize, @@ -66,7 +67,7 @@ impl, const D: usize> CircuitBuilder { config, gates: HashSet::new(), gate_instances: Vec::new(), - public_input_index: 0, + public_inputs: Vec::new(), virtual_target_index: 0, copy_constraints: Vec::new(), context_log: ContextTree::new(), @@ -81,14 +82,14 @@ impl, const D: usize> CircuitBuilder { self.gate_instances.len() } - pub fn add_public_input(&mut self) -> Target { - let index = self.public_input_index; - self.public_input_index += 1; - Target::PublicInput { index } + /// Registers the given target as a public input. + pub fn register_public_input(&mut self, target: Target) { + self.public_inputs.push(target); } - pub fn add_public_inputs(&mut self, n: usize) -> Vec { - (0..n).map(|_i| self.add_public_input()).collect() + /// Registers the given targets as public inputs. + pub fn register_public_inputs(&mut self, targets: &[Target]) { + targets.iter().for_each(|&t| self.register_public_input(t)); } /// Adds a new "virtual" target. This is not an actual wire in the witness, but just a target @@ -462,10 +463,7 @@ impl, const D: usize> CircuitBuilder { let degree_log = log2_strict(degree); let mut target_partition = TargetPartition::new(|t| match t { Target::Wire(Wire { gate, input }) => gate * self.config.num_routed_wires + input, - Target::PublicInput { index } => degree * self.config.num_routed_wires + index, - Target::VirtualTarget { index } => { - degree * self.config.num_routed_wires + self.public_input_index + index - } + Target::VirtualTarget { index } => degree * self.config.num_routed_wires + index, }); for gate in 0..degree { @@ -474,10 +472,6 @@ impl, const D: usize> CircuitBuilder { } } - for index in 0..self.public_input_index { - target_partition.add(Target::PublicInput { index }); - } - for index in 0..self.virtual_target_index { target_partition.add(Target::VirtualTarget { index }); } @@ -500,6 +494,19 @@ impl, const D: usize> CircuitBuilder { pub fn build(mut self) -> CircuitData { let quotient_degree_factor = 7; // TODO: add this as a parameter. let start = Instant::now(); + + // Hash the public inputs, and route them to a `PublicInputGate` which will enforce that + // those hash wires match the claimed public inputs. + let public_inputs_hash = self.hash_n_to_hash(self.public_inputs.clone(), true); + let pi_gate = self.add_gate_no_constants(PublicInputGate::get()); + for (&hash_part, wire) in public_inputs_hash + .elements + .iter() + .zip(PublicInputGate::wires_public_inputs_hash()) + { + self.route(hash_part, Target::wire(pi_gate, wire)) + } + info!( "Degree before blinding & padding: {}", self.gate_instances.len() @@ -552,6 +559,7 @@ impl, const D: usize> CircuitBuilder { subgroup, copy_constraints: self.copy_constraints, gate_instances: self.gate_instances, + public_inputs: self.public_inputs, marked_targets: self.marked_targets, }; diff --git a/src/circuit_data.rs b/src/circuit_data.rs index 9c68ca46..912f05cc 100644 --- a/src/circuit_data.rs +++ b/src/circuit_data.rs @@ -9,8 +9,9 @@ use crate::fri::FriConfig; use crate::gates::gate::{GateInstance, PrefixedGate}; use crate::generator::WitnessGenerator; use crate::polynomial::commitment::ListPolynomialCommitment; -use crate::proof::{Hash, HashTarget, Proof}; +use crate::proof::{Hash, HashTarget, ProofWithPublicInputs}; use crate::prover::prove; +use crate::target::Target; use crate::util::marking::MarkedTargets; use crate::verifier::verify; use crate::witness::PartialWitness; @@ -78,12 +79,12 @@ pub struct CircuitData, const D: usize> { } impl, const D: usize> CircuitData { - pub fn prove(&self, inputs: PartialWitness) -> Result> { + pub fn prove(&self, inputs: PartialWitness) -> Result> { prove(&self.prover_only, &self.common, inputs) } - pub fn verify(&self, proof: Proof) -> Result<()> { - verify(proof, &self.verifier_only, &self.common) + pub fn verify(&self, proof_with_pis: ProofWithPublicInputs) -> Result<()> { + verify(proof_with_pis, &self.verifier_only, &self.common) } } @@ -100,7 +101,7 @@ pub struct ProverCircuitData, const D: usize> { } impl, const D: usize> ProverCircuitData { - pub fn prove(&self, inputs: PartialWitness) -> Result> { + pub fn prove(&self, inputs: PartialWitness) -> Result> { prove(&self.prover_only, &self.common, inputs) } } @@ -112,8 +113,8 @@ pub struct VerifierCircuitData, const D: usize> { } impl, const D: usize> VerifierCircuitData { - pub fn verify(&self, proof: Proof) -> Result<()> { - verify(proof, &self.verifier_only, &self.common) + pub fn verify(&self, proof_with_pis: ProofWithPublicInputs) -> Result<()> { + verify(proof_with_pis, &self.verifier_only, &self.common) } } @@ -130,6 +131,8 @@ pub(crate) struct ProverOnlyCircuitData, const D: usize> { pub copy_constraints: Vec, /// The concrete placement of each gate in the circuit. pub gate_instances: Vec>, + /// Targets to be made public. + pub public_inputs: Vec, /// A vector of marked targets. The values assigned to these targets will be displayed by the prover. pub marked_targets: Vec>, } diff --git a/src/field/fft.rs b/src/field/fft.rs index 35b11852..3fd37b1c 100644 --- a/src/field/fft.rs +++ b/src/field/fft.rs @@ -67,7 +67,7 @@ fn fft_unrolled_root_table(n: usize) -> FftRootTable { #[inline] fn fft_dispatch( - input: Vec, + input: &[F], zero_factor: Option, root_table: Option>, ) -> Vec { @@ -87,13 +87,13 @@ fn fft_dispatch( } #[inline] -pub fn fft(poly: PolynomialCoeffs) -> PolynomialValues { +pub fn fft(poly: &PolynomialCoeffs) -> PolynomialValues { fft_with_options(poly, None, None) } #[inline] pub fn fft_with_options( - poly: PolynomialCoeffs, + poly: &PolynomialCoeffs, zero_factor: Option, root_table: Option>, ) -> PolynomialValues { @@ -104,12 +104,12 @@ pub fn fft_with_options( } #[inline] -pub fn ifft(poly: PolynomialValues) -> PolynomialCoeffs { +pub fn ifft(poly: &PolynomialValues) -> PolynomialCoeffs { ifft_with_options(poly, None, None) } pub fn ifft_with_options( - poly: PolynomialValues, + poly: &PolynomialValues, zero_factor: Option, root_table: Option>, ) -> PolynomialCoeffs { @@ -139,11 +139,7 @@ pub fn ifft_with_options( /// The parameter r signifies that the first 1/2^r of the entries of /// input may be non-zero, but the last 1 - 1/2^r entries are /// definitely zero. -pub(crate) fn fft_classic( - input: Vec, - r: usize, - root_table: FftRootTable, -) -> Vec { +pub(crate) fn fft_classic(input: &[F], r: usize, root_table: FftRootTable) -> Vec { let mut values = reverse_index_bits(input); let n = values.len(); @@ -196,7 +192,7 @@ pub(crate) fn fft_classic( /// The parameter r signifies that the first 1/2^r of the entries of /// input may be non-zero, but the last 1 - 1/2^r entries are /// definitely zero. -fn fft_unrolled(input: Vec, r_orig: usize, root_table: FftRootTable) -> Vec { +fn fft_unrolled(input: &[F], r_orig: usize, root_table: FftRootTable) -> Vec { let n = input.len(); let lg_n = log2_strict(input.len()); @@ -325,10 +321,10 @@ mod tests { } let coefficients = PolynomialCoeffs::new_padded(coefficients); - let points = fft(coefficients.clone()); + let points = fft(&coefficients); assert_eq!(points, evaluate_naive(&coefficients)); - let interpolated_coefficients = ifft(points); + let interpolated_coefficients = ifft(&points); for i in 0..degree { assert_eq!(interpolated_coefficients.coeffs[i], coefficients.coeffs[i]); } @@ -337,12 +333,9 @@ mod tests { } for r in 0..4 { - // expand ceofficients by factor 2^r by filling with zeros - let zero_tail = coefficients.clone().lde(r); - assert_eq!( - fft(zero_tail.clone()), - fft_with_options(zero_tail, Some(r), None) - ); + // expand coefficients by factor 2^r by filling with zeros + let zero_tail = coefficients.lde(r); + assert_eq!(fft(&zero_tail), fft_with_options(&zero_tail, Some(r), None)); } } @@ -350,10 +343,7 @@ mod tests { let degree = coefficients.len(); let degree_padded = 1 << log2_ceil(degree); - let mut coefficients_padded = coefficients.clone(); - for _i in degree..degree_padded { - coefficients_padded.coeffs.push(F::ZERO); - } + let coefficients_padded = coefficients.padded(degree_padded); evaluate_naive_power_of_2(&coefficients_padded) } diff --git a/src/field/interpolation.rs b/src/field/interpolation.rs index 3d5e609c..83414f1b 100644 --- a/src/field/interpolation.rs +++ b/src/field/interpolation.rs @@ -18,7 +18,7 @@ pub(crate) fn interpolant(points: &[(F, F)]) -> PolynomialCoeffs { .map(|x| interpolate(points, x, &barycentric_weights)) .collect(); - let mut coeffs = ifft(PolynomialValues { + let mut coeffs = ifft(&PolynomialValues { values: subgroup_evals, }); coeffs.trim(); diff --git a/src/fri/prover.rs b/src/fri/prover.rs index 1c642d17..5a8f09e5 100644 --- a/src/fri/prover.rs +++ b/src/fri/prover.rs @@ -16,9 +16,9 @@ use crate::util::reverse_index_bits_in_place; pub fn fri_proof, const D: usize>( initial_merkle_trees: &[&MerkleTree], // Coefficients of the polynomial on which the LDT is performed. Only the first `1/rate` coefficients are non-zero. - lde_polynomial_coeffs: &PolynomialCoeffs, + lde_polynomial_coeffs: PolynomialCoeffs, // Evaluation of the polynomial on the large domain. - lde_polynomial_values: &PolynomialValues, + lde_polynomial_values: PolynomialValues, challenger: &mut Challenger, config: &FriConfig, ) -> FriProof { @@ -53,14 +53,11 @@ pub fn fri_proof, const D: usize>( } fn fri_committed_trees, const D: usize>( - polynomial_coeffs: &PolynomialCoeffs, - polynomial_values: &PolynomialValues, + mut coeffs: PolynomialCoeffs, + mut values: PolynomialValues, challenger: &mut Challenger, config: &FriConfig, ) -> (Vec>, PolynomialCoeffs) { - let mut values = polynomial_values.clone(); - let mut coeffs = polynomial_coeffs.clone(); - let mut trees = Vec::new(); let mut shift = F::MULTIPLICATIVE_GROUP_GENERATOR; @@ -91,8 +88,7 @@ fn fri_committed_trees, const D: usize>( .collect::>(), ); shift = shift.exp_u32(arity as u32); - // TODO: Is it faster to interpolate? - values = coeffs.clone().coset_fft(shift.into()) + values = coeffs.coset_fft(shift.into()) } coeffs.trim(); diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index 3ed199dd..c2d5155e 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -246,7 +246,6 @@ impl, const D: usize> CircuitBuilder { ) { let config = &common_data.config.fri_config; let n_log = log2_strict(n); - let mut evaluations: Vec>> = Vec::new(); // TODO: Do we need to range check `x_index` to a target smaller than `p`? let mut x_index = challenger.get_challenge(self); x_index = self.split_low_high(x_index, n_log, 64).0; @@ -273,6 +272,7 @@ impl, const D: usize> CircuitBuilder { self.mul(g, phi) }); + let mut evaluations: Vec>> = Vec::new(); for (i, &arity_bits) in config.reduction_arity_bits.iter().enumerate() { let next_domain_size = domain_size >> arity_bits; let e_x = if i == 0 { @@ -308,23 +308,21 @@ impl, const D: usize> CircuitBuilder { let (low_x_index, high_x_index) = self.split_low_high(x_index, arity_bits, x_index_num_bits); evals = self.insert(low_x_index, e_x, evals); - evaluations.push(evals); context!( self, "verify FRI round Merkle proof.", self.verify_merkle_proof( - flatten_target(&evaluations[i]), + flatten_target(&evals), high_x_index, proof.commit_phase_merkle_roots[i], &round_proof.steps[i].merkle_proof, ) ); + evaluations.push(evals); if i > 0 { // Update the point x to x^arity. - for _ in 0..config.reduction_arity_bits[i - 1] { - subgroup_x = self.square(subgroup_x); - } + subgroup_x = self.exp_power_of_2(subgroup_x, config.reduction_arity_bits[i - 1]); } domain_size = next_domain_size; old_x_index = low_x_index; @@ -345,9 +343,7 @@ impl, const D: usize> CircuitBuilder { *betas.last().unwrap(), ) ); - for _ in 0..final_arity_bits { - subgroup_x = self.square(subgroup_x); - } + subgroup_x = self.exp_power_of_2(subgroup_x, final_arity_bits); // Final check of FRI. After all the reductions, we check that the final polynomial is equal // to the one sent by the prover. diff --git a/src/fri/verifier.rs b/src/fri/verifier.rs index 22e0f24c..c05db9cd 100644 --- a/src/fri/verifier.rs +++ b/src/fri/verifier.rs @@ -248,7 +248,6 @@ fn fri_verifier_query_round, const D: usize>( common_data: &CommonCircuitData, ) -> Result<()> { let config = &common_data.config.fri_config; - let mut evaluations: Vec> = Vec::new(); let x = challenger.get_challenge(); let mut domain_size = n; let mut x_index = x.to_canonical_u64() as usize % n; @@ -262,6 +261,8 @@ fn fri_verifier_query_round, const D: usize>( let log_n = log2_strict(n); let mut subgroup_x = F::MULTIPLICATIVE_GROUP_GENERATOR * F::primitive_root_of_unity(log_n).exp(reverse_bits(x_index, log_n) as u64); + + let mut evaluations: Vec> = Vec::new(); for (i, &arity_bits) in config.reduction_arity_bits.iter().enumerate() { let arity = 1 << arity_bits; let next_domain_size = domain_size >> arity_bits; @@ -288,20 +289,18 @@ fn fri_verifier_query_round, const D: usize>( let mut evals = round_proof.steps[i].evals.clone(); // Insert P(y) into the evaluation vector, since it wasn't included by the prover. evals.insert(x_index & (arity - 1), e_x); - evaluations.push(evals); verify_merkle_proof( - flatten(&evaluations[i]), + flatten(&evals), x_index >> arity_bits, proof.commit_phase_merkle_roots[i], &round_proof.steps[i].merkle_proof, false, )?; + evaluations.push(evals); if i > 0 { // Update the point x to x^arity. - for _ in 0..config.reduction_arity_bits[i - 1] { - subgroup_x = subgroup_x.square(); - } + subgroup_x = subgroup_x.exp_power_of_2(config.reduction_arity_bits[i - 1]); } domain_size = next_domain_size; old_x_index = x_index & (arity - 1); @@ -317,9 +316,7 @@ fn fri_verifier_query_round, const D: usize>( last_evals, *betas.last().unwrap(), ); - for _ in 0..final_arity_bits { - subgroup_x = subgroup_x.square(); - } + subgroup_x = subgroup_x.exp_power_of_2(final_arity_bits); // Final check of FRI. After all the reductions, we check that the final polynomial is equal // to the one sent by the prover. diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index f20b2f01..c60572fc 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -156,6 +156,15 @@ impl, const D: usize> CircuitBuilder { self.mul_many_extension(&terms_ext).to_target_array()[0] } + /// Exponentiate `base` to the power of `2^power_log`. + // TODO: Test + pub fn exp_power_of_2(&mut self, mut base: Target, power_log: usize) -> Target { + for _ in 0..power_log { + base = self.square(base); + } + base + } + // TODO: Optimize this, maybe with a new gate. // TODO: Test /// Exponentiate `base` to the power of `exponent`, where `exponent < 2^num_bits`. diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 22f70884..9aedc2fe 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -7,7 +7,7 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::field::extension_field::{Extendable, OEF}; use crate::gates::arithmetic::ArithmeticExtensionGate; -use crate::generator::SimpleGenerator; +use crate::generator::{GeneratedValues, SimpleGenerator}; use crate::target::Target; use crate::util::bits_u64; use crate::witness::PartialWitness; @@ -357,7 +357,7 @@ impl, const D: usize> CircuitBuilder { /// Exponentiate `base` to the power of `2^power_log`. // TODO: Test - pub fn exp_power_of_2( + pub fn exp_power_of_2_extension( &mut self, mut base: ExtensionTarget, power_log: usize, @@ -449,11 +449,11 @@ impl, const D: usize> SimpleGenerator for QuotientGeneratorE deps } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let num = witness.get_extension_target(self.numerator); let dem = witness.get_extension_target(self.denominator); let quotient = num / dem; - PartialWitness::singleton_extension_target(self.quotient, quotient) + GeneratedValues::singleton_extension_target(self.quotient, quotient) } } diff --git a/src/gadgets/range_check.rs b/src/gadgets/range_check.rs index 7fd35efc..c0848af8 100644 --- a/src/gadgets/range_check.rs +++ b/src/gadgets/range_check.rs @@ -2,7 +2,7 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::base_sum::BaseSumGate; -use crate::generator::SimpleGenerator; +use crate::generator::{GeneratedValues, SimpleGenerator}; use crate::target::Target; use crate::witness::PartialWitness; @@ -49,12 +49,12 @@ impl SimpleGenerator for LowHighGenerator { vec![self.integer] } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let integer_value = witness.get_target(self.integer).to_canonical_u64(); let low = integer_value & ((1 << self.n_log) - 1); let high = integer_value >> self.n_log; - let mut result = PartialWitness::new(); + let mut result = GeneratedValues::with_capacity(2); result.set_target(self.low, F::from_canonical_u64(low)); result.set_target(self.high, F::from_canonical_u64(high)); diff --git a/src/gadgets/split_join.rs b/src/gadgets/split_join.rs index 3a2c27f4..9cc6ab7c 100644 --- a/src/gadgets/split_join.rs +++ b/src/gadgets/split_join.rs @@ -2,7 +2,7 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::base_sum::BaseSumGate; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::target::Target; use crate::util::ceil_div_usize; use crate::wire::Wire; @@ -110,10 +110,10 @@ impl SimpleGenerator for SplitGenerator { vec![self.integer] } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let mut integer_value = witness.get_target(self.integer).to_canonical_u64(); - let mut result = PartialWitness::new(); + let mut result = GeneratedValues::with_capacity(self.bits.len()); for &b in &self.bits { let b_value = integer_value & 1; result.set_target(b, F::from_canonical_u64(b_value)); @@ -141,10 +141,10 @@ impl SimpleGenerator for WireSplitGenerator { vec![self.integer] } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let mut integer_value = witness.get_target(self.integer).to_canonical_u64(); - let mut result = PartialWitness::new(); + let mut result = GeneratedValues::with_capacity(self.gates.len()); for &gate in &self.gates { let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM); result.set_target( diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 2548c498..cf39e09b 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -4,7 +4,7 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::witness::PartialWitness; @@ -169,7 +169,7 @@ impl, const D: usize> SimpleGenerator for ArithmeticExtensio .collect() } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let extract_extension = |range: Range| -> F::Extension { let t = ExtensionTarget::from_range(self.gate_index, range); witness.get_extension_target(t) @@ -189,7 +189,7 @@ impl, const D: usize> SimpleGenerator for ArithmeticExtensio let computed_output = multiplicand_0 * multiplicand_1 * self.const_0.into() + addend * self.const_1.into(); - PartialWitness::singleton_extension_target(output_target, computed_output) + GeneratedValues::singleton_extension_target(output_target, computed_output) } } @@ -202,7 +202,7 @@ impl, const D: usize> SimpleGenerator for ArithmeticExtensio .collect() } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let extract_extension = |range: Range| -> F::Extension { let t = ExtensionTarget::from_range(self.gate_index, range); witness.get_extension_target(t) @@ -222,7 +222,7 @@ impl, const D: usize> SimpleGenerator for ArithmeticExtensio let computed_output = multiplicand_0 * multiplicand_1 * self.const_0.into() + addend * self.const_1.into(); - PartialWitness::singleton_extension_target(output_target, computed_output) + GeneratedValues::singleton_extension_target(output_target, computed_output) } } diff --git a/src/gates/base_sum.rs b/src/gates/base_sum.rs index 8f453d8e..8ad189ee 100644 --- a/src/gates/base_sum.rs +++ b/src/gates/base_sum.rs @@ -5,7 +5,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::plonk_common::{reduce_with_powers, reduce_with_powers_recursive}; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; @@ -130,7 +130,7 @@ impl SimpleGenerator for BaseSplitGenerator { vec![Target::wire(self.gate_index, BaseSumGate::::WIRE_SUM)] } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let sum_value = witness .get_target(Target::wire(self.gate_index, BaseSumGate::::WIRE_SUM)) .to_canonical_u64() as usize; @@ -155,7 +155,7 @@ impl SimpleGenerator for BaseSplitGenerator { .iter() .fold(F::ZERO, |acc, &x| acc * b_field + x); - let mut result = PartialWitness::new(); + let mut result = GeneratedValues::with_capacity(self.num_limbs + 1); result.set_target( Target::wire(self.gate_index, BaseSumGate::::WIRE_REVERSED_SUM), reversed_sum, diff --git a/src/gates/constant.rs b/src/gates/constant.rs index 3845031a..4049d058 100644 --- a/src/gates/constant.rs +++ b/src/gates/constant.rs @@ -3,7 +3,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::wire::Wire; @@ -83,12 +83,12 @@ impl SimpleGenerator for ConstantGenerator { Vec::new() } - fn run_once(&self, _witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, _witness: &PartialWitness) -> GeneratedValues { let wire = Wire { gate: self.gate_index, input: ConstantGate::WIRE_OUTPUT, }; - PartialWitness::singleton_target(Target::Wire(wire), self.constant) + GeneratedValues::singleton_target(Target::Wire(wire), self.constant) } } diff --git a/src/gates/gate.rs b/src/gates/gate.rs index 9b45817e..17059332 100644 --- a/src/gates/gate.rs +++ b/src/gates/gate.rs @@ -31,9 +31,11 @@ pub trait Gate, const D: usize>: 'static + Send + Sync { .iter() .map(|w| F::Extension::from_basefield(*w)) .collect::>(); + let public_inputs_hash = &vars_base.public_inputs_hash; let vars = EvaluationVars { local_constants, local_wires, + public_inputs_hash, }; let values = self.eval_unfiltered(vars); diff --git a/src/gates/gate_testing.rs b/src/gates/gate_testing.rs index cb25acb8..2f7a3020 100644 --- a/src/gates/gate_testing.rs +++ b/src/gates/gate_testing.rs @@ -2,6 +2,7 @@ use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::gate::GateRef; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; +use crate::proof::Hash; use crate::util::{log2_ceil, transpose}; use crate::vars::EvaluationVars; @@ -17,6 +18,7 @@ pub(crate) fn test_low_degree, const D: usize>(gate: GateRef(gate.num_wires(), rate_bits); let constant_ldes = random_low_degree_matrix::(gate.num_constants(), rate_bits); assert_eq!(wire_ldes.len(), constant_ldes.len()); + let public_inputs_hash = &Hash::rand(); let constraint_evals = wire_ldes .iter() @@ -24,6 +26,7 @@ pub(crate) fn test_low_degree, const D: usize>(gate: GateRef>(); diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 768a5693..0404884b 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -5,7 +5,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::gmimc::gmimc_automatic_constants; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; @@ -239,8 +239,8 @@ impl, const D: usize, const R: usize> SimpleGenerator .collect() } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { - let mut result = PartialWitness::new(); + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { + let mut result = GeneratedValues::with_capacity(R + W + 1); let mut state = (0..W) .map(|i| { @@ -326,6 +326,7 @@ mod tests { use crate::gates::gmimc::{GMiMCGate, W}; use crate::generator::generate_partial_witness; use crate::gmimc::gmimc_permute_naive; + use crate::proof::Hash; use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::verifier::verify; use crate::wire::Wire; @@ -416,9 +417,11 @@ mod tests { let gate = Gate::with_constants(constants); let wires = FF::rand_vec(Gate::end()); + let public_inputs_hash = &Hash::rand(); let vars = EvaluationVars { local_constants: &[], local_wires: &wires, + public_inputs_hash, }; let ev = gate.0.eval_unfiltered(vars); @@ -427,9 +430,14 @@ mod tests { for i in 0..Gate::end() { pw.set_extension_target(wires_t[i], wires[i]); } + + let public_inputs_hash_t = builder.add_virtual_hash(); + pw.set_hash_target(public_inputs_hash_t, *public_inputs_hash); + let vars_t = EvaluationTargets { local_constants: &[], local_wires: &wires_t, + public_inputs_hash: &public_inputs_hash_t, }; let ev_t = gate.0.eval_unfiltered_recursively(&mut builder, vars_t); diff --git a/src/gates/insertion.rs b/src/gates/insertion.rs index 64301347..1bc0b454 100644 --- a/src/gates/insertion.rs +++ b/src/gates/insertion.rs @@ -7,7 +7,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::wire::Wire; @@ -218,7 +218,7 @@ impl, const D: usize> SimpleGenerator for InsertionGenerator deps } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let local_wire = |input| Wire { gate: self.gate_index, input, @@ -264,7 +264,7 @@ impl, const D: usize> SimpleGenerator for InsertionGenerator let mut insert_here_vals = vec![F::ZERO; vec_size]; insert_here_vals.insert(insertion_index, F::ONE); - let mut result = PartialWitness::::new(); + let mut result = GeneratedValues::::with_capacity((vec_size + 1) * (D + 2)); for i in 0..=vec_size { let output_wires = self.gate.wires_output_list_item(i).map(local_wire); result.set_ext_wires(output_wires, new_vec[i]); @@ -288,6 +288,7 @@ mod tests { use crate::gates::gate::Gate; use crate::gates::gate_testing::test_low_degree; use crate::gates::insertion::InsertionGate; + use crate::proof::Hash; use crate::vars::EvaluationVars; #[test] @@ -366,6 +367,7 @@ mod tests { let vars = EvaluationVars { local_constants: &[], local_wires: &get_wires(orig_vec, insertion_index, element_to_insert), + public_inputs_hash: &Hash::rand(), }; assert!( diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index ccf8d57d..17d34e3a 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -9,7 +9,7 @@ use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::interpolation::interpolant; use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::wire::Wire; @@ -216,7 +216,7 @@ impl, const D: usize> SimpleGenerator for InterpolationGener deps } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let n = self.gate.num_points; let local_wire = |input| Wire { @@ -244,7 +244,7 @@ impl, const D: usize> SimpleGenerator for InterpolationGener .collect::>(); let interpolant = interpolant(&points); - let mut result = PartialWitness::::new(); + let mut result = GeneratedValues::::with_capacity(D * (self.gate.num_points + 1)); for (i, &coeff) in interpolant.coeffs.iter().enumerate() { let wires = self.gate.wires_coeff(i).map(local_wire); result.set_ext_wires(wires, coeff); @@ -271,6 +271,7 @@ mod tests { use crate::gates::gate_testing::test_low_degree; use crate::gates::interpolation::InterpolationGate; use crate::polynomial::polynomial::PolynomialCoeffs; + use crate::proof::Hash; use crate::vars::EvaluationVars; #[test] @@ -352,6 +353,7 @@ mod tests { let vars = EvaluationVars { local_constants: &[], local_wires: &get_wires(2, coeffs, points, eval_point), + public_inputs_hash: &Hash::rand(), }; assert!( diff --git a/src/gates/mod.rs b/src/gates/mod.rs index 6b3f05fa..441383ec 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -10,6 +10,7 @@ pub mod gmimc; pub mod insertion; pub mod interpolation; pub(crate) mod noop; +pub(crate) mod public_input; #[cfg(test)] mod gate_testing; diff --git a/src/gates/public_input.rs b/src/gates/public_input.rs new file mode 100644 index 00000000..a86b78d5 --- /dev/null +++ b/src/gates/public_input.rs @@ -0,0 +1,84 @@ +use std::ops::Range; + +use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::gates::gate::{Gate, GateRef}; +use crate::generator::WitnessGenerator; +use crate::vars::{EvaluationTargets, EvaluationVars}; + +/// A gate whose first four wires will be equal to a hash of public inputs. +pub struct PublicInputGate; + +impl PublicInputGate { + pub fn get, const D: usize>() -> GateRef { + GateRef::new(PublicInputGate) + } + + pub fn wires_public_inputs_hash() -> Range { + 0..4 + } +} + +impl, const D: usize> Gate for PublicInputGate { + fn id(&self) -> String { + "PublicInputGate".into() + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + Self::wires_public_inputs_hash() + .zip(vars.public_inputs_hash.elements) + .map(|(wire, hash_part)| vars.local_wires[wire] - hash_part.into()) + .collect() + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + Self::wires_public_inputs_hash() + .zip(vars.public_inputs_hash.elements) + .map(|(wire, hash_part)| { + let hash_part_ext = builder.convert_to_ext(hash_part); + builder.sub_extension(vars.local_wires[wire], hash_part_ext) + }) + .collect() + } + + fn generators( + &self, + _gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + Vec::new() + } + + fn num_wires(&self) -> usize { + 4 + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 1 + } + + fn num_constraints(&self) -> usize { + 4 + } +} + +#[cfg(test)] +mod tests { + use crate::field::crandall_field::CrandallField; + use crate::gates::gate_testing::test_low_degree; + use crate::gates::public_input::PublicInputGate; + + #[test] + fn low_degree() { + test_low_degree(PublicInputGate::get::()) + } +} diff --git a/src/generator.rs b/src/generator.rs index a47c5267..a7359a7d 100644 --- a/src/generator.rs +++ b/src/generator.rs @@ -1,8 +1,13 @@ use std::collections::{HashMap, HashSet}; +use std::convert::identity; use std::fmt::Debug; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; +use crate::proof::{Hash, HashTarget}; use crate::target::Target; +use crate::wire::Wire; use crate::witness::PartialWitness; /// Given a `PartialWitness` that has only inputs set, populates the rest of the witness using the @@ -27,7 +32,7 @@ pub(crate) fn generate_partial_witness( let mut pending_generator_indices: HashSet<_> = (0..generators.len()).collect(); // We also track a list of "expired" generators which have already returned false. - let mut expired_generator_indices = HashSet::new(); + let mut generator_is_expired = vec![false; generators.len()]; // Keep running generators until no generators are queued. while !pending_generator_indices.is_empty() { @@ -36,15 +41,15 @@ pub(crate) fn generate_partial_witness( for &generator_idx in &pending_generator_indices { let (result, finished) = generators[generator_idx].run(&witness); if finished { - expired_generator_indices.insert(generator_idx); + generator_is_expired[generator_idx] = true; } // Enqueue unfinished generators that were watching one of the newly populated targets. - for watch in result.target_values.keys() { + for (watch, _) in &result.target_values { if let Some(watching_generator_indices) = generator_indices_by_watches.get(watch) { - for watching_generator_idx in watching_generator_indices { - if !expired_generator_indices.contains(watching_generator_idx) { - next_pending_generator_indices.insert(*watching_generator_idx); + for &watching_generator_idx in watching_generator_indices { + if !generator_is_expired[watching_generator_idx] { + next_pending_generator_indices.insert(watching_generator_idx); } } } @@ -55,9 +60,9 @@ pub(crate) fn generate_partial_witness( pending_generator_indices = next_pending_generator_indices; } - assert_eq!( - expired_generator_indices.len(), - generators.len(), + + assert!( + generator_is_expired.into_iter().all(identity), "Some generators weren't run." ); } @@ -72,14 +77,101 @@ pub trait WitnessGenerator: 'static + Send + Sync { /// flag indicating whether the generator is finished. If the flag is true, the generator will /// never be run again, otherwise it will be queued for another run next time a target in its /// watch list is populated. - fn run(&self, witness: &PartialWitness) -> (PartialWitness, bool); + fn run(&self, witness: &PartialWitness) -> (GeneratedValues, bool); +} + +/// Values generated by a generator invocation. +pub struct GeneratedValues { + pub(crate) target_values: Vec<(Target, F)>, +} + +impl From> for GeneratedValues { + fn from(target_values: Vec<(Target, F)>) -> Self { + Self { target_values } + } +} + +impl GeneratedValues { + pub fn with_capacity(capacity: usize) -> Self { + Vec::with_capacity(capacity).into() + } + + pub fn empty() -> Self { + Vec::new().into() + } + + pub fn singleton_wire(wire: Wire, value: F) -> Self { + Self::singleton_target(Target::Wire(wire), value) + } + + pub fn singleton_target(target: Target, value: F) -> Self { + vec![(target, value)].into() + } + + pub fn singleton_extension_target( + et: ExtensionTarget, + value: F::Extension, + ) -> Self + where + F: Extendable, + { + let mut witness = Self::with_capacity(D); + witness.set_extension_target(et, value); + witness + } + + pub fn set_target(&mut self, target: Target, value: F) { + self.target_values.push((target, value)) + } + + pub fn set_hash_target(&mut self, ht: HashTarget, value: Hash) { + ht.elements + .iter() + .zip(value.elements) + .for_each(|(&t, x)| self.set_target(t, x)); + } + + pub fn set_extension_target( + &mut self, + et: ExtensionTarget, + value: F::Extension, + ) where + F: Extendable, + { + let limbs = value.to_basefield_array(); + (0..D).for_each(|i| { + self.set_target(et.0[i], limbs[i]); + }); + } + + pub fn set_wire(&mut self, wire: Wire, value: F) { + self.set_target(Target::Wire(wire), value) + } + + pub fn set_wires(&mut self, wires: W, values: &[F]) + where + W: IntoIterator, + { + // If we used itertools, we could use zip_eq for extra safety. + for (wire, &value) in wires.into_iter().zip(values) { + self.set_wire(wire, value); + } + } + + pub fn set_ext_wires(&mut self, wires: W, value: F::Extension) + where + F: Extendable, + W: IntoIterator, + { + self.set_wires(wires, &value.to_basefield_array()); + } } /// A generator which runs once after a list of dependencies is present in the witness. pub trait SimpleGenerator: 'static + Send + Sync { fn dependencies(&self) -> Vec; - fn run_once(&self, witness: &PartialWitness) -> PartialWitness; + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues; } impl> WitnessGenerator for SG { @@ -87,11 +179,11 @@ impl> WitnessGenerator for SG { self.dependencies() } - fn run(&self, witness: &PartialWitness) -> (PartialWitness, bool) { + fn run(&self, witness: &PartialWitness) -> (GeneratedValues, bool) { if witness.contains_all(&self.dependencies()) { (self.run_once(witness), true) } else { - (PartialWitness::new(), false) + (GeneratedValues::empty(), false) } } } @@ -108,9 +200,9 @@ impl SimpleGenerator for CopyGenerator { vec![self.src] } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let value = witness.get_target(self.src); - PartialWitness::singleton_target(self.dst, value) + GeneratedValues::singleton_target(self.dst, value) } } @@ -124,10 +216,10 @@ impl SimpleGenerator for RandomValueGenerator { Vec::new() } - fn run_once(&self, _witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, _witness: &PartialWitness) -> GeneratedValues { let random_value = F::rand(); - PartialWitness::singleton_target(self.target, random_value) + GeneratedValues::singleton_target(self.target, random_value) } } @@ -142,7 +234,7 @@ impl SimpleGenerator for NonzeroTestGenerator { vec![self.to_test] } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let to_test_value = witness.get_target(self.to_test); let dummy_value = if to_test_value == F::ZERO { @@ -151,6 +243,6 @@ impl SimpleGenerator for NonzeroTestGenerator { to_test_value.inverse() }; - PartialWitness::singleton_target(self.dummy, dummy_value) + GeneratedValues::singleton_target(self.dummy, dummy_value) } } diff --git a/src/permutation_argument.rs b/src/permutation_argument.rs index 3c98c801..7bad6d8d 100644 --- a/src/permutation_argument.rs +++ b/src/permutation_argument.rs @@ -90,7 +90,7 @@ impl usize> TargetPartition { } let mut indices = HashMap::new(); - // // Here we keep just the Wire targets, filtering out everything else. + // Here we keep just the Wire targets, filtering out everything else. let partition = partition .into_values() .map(|v| { diff --git a/src/polynomial/commitment.rs b/src/polynomial/commitment.rs index e50839e7..6bbba08d 100644 --- a/src/polynomial/commitment.rs +++ b/src/polynomial/commitment.rs @@ -34,10 +34,7 @@ impl ListPolynomialCommitment { /// Creates a list polynomial commitment for the polynomials interpolating the values in `values`. pub fn new(values: Vec>, rate_bits: usize, blinding: bool) -> Self { let degree = values[0].len(); - let polynomials = values - .par_iter() - .map(|v| v.clone().ifft()) - .collect::>(); + let polynomials = values.par_iter().map(|v| v.ifft()).collect::>(); let lde_values = timed!( Self::lde_values(&polynomials, rate_bits, blinding), "to compute LDE" @@ -92,7 +89,7 @@ impl ListPolynomialCommitment { .par_iter() .map(|p| { assert_eq!(p.len(), degree, "Polynomial degree invalid."); - p.clone().lde(rate_bits).coset_fft(F::coset_shift()).values + p.lde(rate_bits).coset_fft(F::coset_shift()).values }) .chain(if blinding { // If blinding, salt with two random elements to each leaf vector. @@ -182,15 +179,15 @@ impl ListPolynomialCommitment { final_poly += zs_quotient; let lde_final_poly = final_poly.lde(config.rate_bits); - let lde_final_values = lde_final_poly.clone().coset_fft(F::coset_shift().into()); + let lde_final_values = lde_final_poly.coset_fft(F::coset_shift().into()); let fri_proof = fri_proof( &commitments .par_iter() .map(|c| &c.merkle_tree) .collect::>(), - &lde_final_poly, - &lde_final_values, + lde_final_poly, + lde_final_values, challenger, &config.fri_config, ); diff --git a/src/polynomial/division.rs b/src/polynomial/division.rs index 50e1f8a6..6fa636ea 100644 --- a/src/polynomial/division.rs +++ b/src/polynomial/division.rs @@ -88,7 +88,7 @@ impl PolynomialCoeffs { let root = F::primitive_root_of_unity(log2_strict(a.len())); // Equals to the evaluation of `a` on `{g.w^i}`. - let mut a_eval = fft(a); + let mut a_eval = fft(&a); // Compute the denominators `1/(g^n.w^(n*i) - 1)` using batch inversion. let denominator_g = g.exp(n as u64); let root_n = root.exp(n as u64); @@ -112,7 +112,7 @@ impl PolynomialCoeffs { *x *= d; }); // `p` is the interpolating polynomial of `a_eval` on `{w^i}`. - let mut p = ifft(a_eval); + let mut p = ifft(&a_eval); // We need to scale it by `g^(-i)` to get the interpolating polynomial of `a_eval` on `{g.w^i}`, // a.k.a `a/Z_H`. let g_inv = g.inverse(); diff --git a/src/polynomial/polynomial.rs b/src/polynomial/polynomial.rs index 5f9bccd9..3e11f200 100644 --- a/src/polynomial/polynomial.rs +++ b/src/polynomial/polynomial.rs @@ -33,12 +33,12 @@ impl PolynomialValues { self.values.len() } - pub fn ifft(self) -> PolynomialCoeffs { + pub fn ifft(&self) -> PolynomialCoeffs { ifft(self) } /// Returns the polynomial whose evaluation on the coset `shift*H` is `self`. - pub fn coset_ifft(self, shift: F) -> PolynomialCoeffs { + pub fn coset_ifft(&self, shift: F) -> PolynomialCoeffs { let mut shifted_coeffs = self.ifft(); shifted_coeffs .coeffs @@ -54,9 +54,9 @@ impl PolynomialValues { polys.into_iter().map(|p| p.lde(rate_bits)).collect() } - pub fn lde(self, rate_bits: usize) -> Self { + pub fn lde(&self, rate_bits: usize) -> Self { let coeffs = ifft(self).lde(rate_bits); - fft_with_options(coeffs, Some(rate_bits), None) + fft_with_options(&coeffs, Some(rate_bits), None) } pub fn degree(&self) -> usize { @@ -66,7 +66,7 @@ impl PolynomialValues { } pub fn degree_plus_one(&self) -> usize { - self.clone().ifft().degree_plus_one() + self.ifft().degree_plus_one() } } @@ -136,7 +136,7 @@ impl PolynomialCoeffs { .fold(F::ZERO, |acc, &c| acc * x + c) } - pub fn lde_multiple(polys: Vec, rate_bits: usize) -> Vec { + pub fn lde_multiple(polys: Vec<&Self>, rate_bits: usize) -> Vec { polys.into_iter().map(|p| p.lde(rate_bits)).collect() } @@ -194,16 +194,16 @@ impl PolynomialCoeffs { Self::new(self.trimmed().coeffs.into_iter().rev().collect()) } - pub fn fft(self) -> PolynomialValues { + pub fn fft(&self) -> PolynomialValues { fft(self) } /// Returns the evaluation of the polynomial on the coset `shift*H`. - pub fn coset_fft(self, shift: F) -> PolynomialValues { + pub fn coset_fft(&self, shift: F) -> PolynomialValues { let modified_poly: Self = shift .powers() - .zip(self.coeffs) - .map(|(r, c)| r * c) + .zip(&self.coeffs) + .map(|(r, &c)| r * c) .collect::>() .into(); modified_poly.fft() @@ -262,8 +262,7 @@ impl Sub for &PolynomialCoeffs { fn sub(self, rhs: Self) -> Self::Output { let len = max(self.len(), rhs.len()); - let mut coeffs = self.coeffs.clone(); - coeffs.resize(len, F::ZERO); + let mut coeffs = self.padded(len).coeffs; for (i, &c) in rhs.coeffs.iter().enumerate() { coeffs[i] -= c; } @@ -343,7 +342,7 @@ impl Mul for &PolynomialCoeffs { .zip(b_evals.values) .map(|(pa, pb)| pa * pb) .collect(); - ifft(mul_evals.into()) + ifft(&mul_evals.into()) } } @@ -390,7 +389,7 @@ mod tests { let n = 1 << k; let poly = PolynomialCoeffs::new(F::rand_vec(n)); let shift = F::rand(); - let coset_evals = poly.clone().coset_fft(shift).values; + let coset_evals = poly.coset_fft(shift).values; let generator = F::primitive_root_of_unity(k); let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n) @@ -411,7 +410,7 @@ mod tests { let n = 1 << k; let evals = PolynomialValues::new(F::rand_vec(n)); let shift = F::rand(); - let coeffs = evals.clone().coset_ifft(shift); + let coeffs = evals.coset_ifft(shift); let generator = F::primitive_root_of_unity(k); let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n) diff --git a/src/proof.rs b/src/proof.rs index 71edba03..47be8569 100644 --- a/src/proof.rs +++ b/src/proof.rs @@ -37,6 +37,12 @@ impl Hash { elements: [elements[0], elements[1], elements[2], elements[3]], } } + + pub(crate) fn rand() -> Self { + Self { + elements: [F::rand(), F::rand(), F::rand(), F::rand()], + } + } } /// Represents a ~256 bit hash output. @@ -79,6 +85,13 @@ pub struct Proof, const D: usize> { pub opening_proof: OpeningProof, } +#[derive(Serialize, Deserialize, Clone, Debug)] +#[serde(bound = "")] +pub struct ProofWithPublicInputs, const D: usize> { + pub proof: Proof, + pub public_inputs: Vec, +} + pub struct ProofTarget { pub wires_root: HashTarget, pub plonk_zs_partial_products_root: HashTarget, @@ -87,6 +100,11 @@ pub struct ProofTarget { pub opening_proof: OpeningProofTarget, } +pub struct ProofWithPublicInputsTarget { + pub proof: ProofTarget, + pub public_inputs: Vec, +} + /// Evaluations and Merkle proof produced by the prover in a FRI query step. #[derive(Serialize, Deserialize, Clone, Debug)] #[serde(bound = "")] diff --git a/src/prover.rs b/src/prover.rs index e7c85130..f1cef856 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -7,11 +7,12 @@ use rayon::prelude::*; use crate::circuit_data::{CommonCircuitData, ProverOnlyCircuitData}; use crate::field::extension_field::Extendable; use crate::generator::generate_partial_witness; +use crate::hash::hash_n_to_hash; use crate::plonk_challenger::Challenger; use crate::plonk_common::{PlonkPolynomials, ZeroPolyOnCoset}; use crate::polynomial::commitment::ListPolynomialCommitment; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; -use crate::proof::Proof; +use crate::proof::{Hash, Proof, ProofWithPublicInputs}; use crate::timed; use crate::util::partial_products::partial_products; use crate::util::{log2_ceil, transpose}; @@ -23,7 +24,7 @@ pub(crate) fn prove, const D: usize>( prover_data: &ProverOnlyCircuitData, common_data: &CommonCircuitData, inputs: PartialWitness, -) -> Result> { +) -> Result> { let config = &common_data.config; let num_wires = config.num_wires; let num_challenges = config.num_challenges; @@ -39,6 +40,9 @@ pub(crate) fn prove, const D: usize>( "to generate witness" ); + let public_inputs = partial_witness.get_targets(&prover_data.public_inputs); + let public_inputs_hash = hash_n_to_hash(public_inputs.clone(), true); + // Display the marked targets for debugging purposes. for m in &prover_data.marked_targets { m.display(&partial_witness); @@ -58,7 +62,7 @@ pub(crate) fn prove, const D: usize>( let wires_values: Vec> = timed!( witness .wire_values - .iter() + .par_iter() .map(|column| PolynomialValues::new(column.clone())) .collect(), "to compute wire polynomials" @@ -119,6 +123,7 @@ pub(crate) fn prove, const D: usize>( compute_quotient_polys( common_data, prover_data, + &public_inputs_hash, &wires_commitment, &zs_partial_products_commitment, &betas, @@ -178,12 +183,16 @@ pub(crate) fn prove, const D: usize>( start_proof_gen.elapsed().as_secs_f32() ); - Ok(Proof { + let proof = Proof { wires_root: wires_commitment.merkle_tree.root, plonk_zs_partial_products_root: zs_partial_products_commitment.merkle_tree.root, quotient_polys_root: quotient_polys_commitment.merkle_tree.root, openings, opening_proof, + }; + Ok(ProofWithPublicInputs { + proof, + public_inputs, }) } @@ -284,6 +293,7 @@ fn compute_z, const D: usize>( fn compute_quotient_polys<'a, F: Extendable, const D: usize>( common_data: &CommonCircuitData, prover_data: &'a ProverOnlyCircuitData, + public_inputs_hash: &Hash, wires_commitment: &'a ListPolynomialCommitment, zs_partial_products_commitment: &'a ListPolynomialCommitment, betas: &[F], @@ -337,6 +347,7 @@ fn compute_quotient_polys<'a, F: Extendable, const D: usize>( let vars = EvaluationVarsBase { local_constants, local_wires, + public_inputs_hash, }; let mut quotient_values = eval_vanishing_poly_base( common_data, diff --git a/src/recursive_verifier.rs b/src/recursive_verifier.rs index 51bfe6ba..66b63e7f 100644 --- a/src/recursive_verifier.rs +++ b/src/recursive_verifier.rs @@ -3,7 +3,7 @@ use crate::circuit_data::{CircuitConfig, CommonCircuitData, VerifierCircuitTarge use crate::context; use crate::field::extension_field::Extendable; use crate::plonk_challenger::RecursiveChallenger; -use crate::proof::{HashTarget, ProofTarget}; +use crate::proof::{HashTarget, ProofWithPublicInputsTarget}; use crate::util::scaling::ReducingFactorTarget; use crate::vanishing_poly::eval_vanishing_poly_recursively; use crate::vars::EvaluationTargets; @@ -15,15 +15,21 @@ impl, const D: usize> CircuitBuilder { /// Recursively verifies an inner proof. pub fn add_recursive_verifier( &mut self, - proof: ProofTarget, + proof_with_pis: ProofWithPublicInputsTarget, inner_config: &CircuitConfig, inner_verifier_data: &VerifierCircuitTarget, inner_common_data: &CommonCircuitData, ) { assert!(self.config.num_wires >= MIN_WIRES); assert!(self.config.num_wires >= MIN_ROUTED_WIRES); + let ProofWithPublicInputsTarget { + proof, + public_inputs, + } = proof_with_pis; let one = self.one_extension(); + let public_inputs_hash = &self.hash_n_to_hash(public_inputs, true); + let num_challenges = inner_config.num_challenges; let mut challenger = RecursiveChallenger::new(self); @@ -53,13 +59,14 @@ impl, const D: usize> CircuitBuilder { let vars = EvaluationTargets { local_constants, local_wires, + public_inputs_hash, }; let local_zs = &proof.openings.plonk_zs; let next_zs = &proof.openings.plonk_zs_right; let s_sigmas = &proof.openings.plonk_sigmas; let partial_products = &proof.openings.partial_products; - let zeta_pow_deg = self.exp_power_of_2(zeta, inner_common_data.degree_bits); + let zeta_pow_deg = self.exp_power_of_2_extension(zeta, inner_common_data.degree_bits); let vanishing_polys_zeta = context!( self, "evaluate the vanishing polynomial at our challenge point, zeta.", @@ -89,7 +96,7 @@ impl, const D: usize> CircuitBuilder { { let recombined_quotient = scale.reduce(chunk, self); let computed_vanishing_poly = self.mul_extension(z_h_zeta, recombined_quotient); - self.named_route_extension( + self.named_assert_equal_extension( vanishing_polys_zeta[i], computed_vanishing_poly, format!("Vanishing polynomial == Z_H * quotient, challenge {}", i), @@ -127,7 +134,7 @@ mod tests { use crate::polynomial::commitment::OpeningProofTarget; use crate::proof::{ FriInitialTreeProofTarget, FriProofTarget, FriQueryRoundTarget, FriQueryStepTarget, - OpeningSetTarget, Proof, + OpeningSetTarget, Proof, ProofTarget, ProofWithPublicInputs, }; use crate::verifier::verify; use crate::witness::PartialWitness; @@ -167,9 +174,14 @@ mod tests { // Construct a `ProofTarget` with the same dimensions as `proof`. fn proof_to_proof_target, const D: usize>( - proof: &Proof, + proof_with_pis: &ProofWithPublicInputs, builder: &mut CircuitBuilder, - ) -> ProofTarget { + ) -> ProofWithPublicInputsTarget { + let ProofWithPublicInputs { + proof, + public_inputs, + } = proof_with_pis; + let wires_root = builder.add_virtual_hash(); let plonk_zs_root = builder.add_virtual_hash(); let quotient_polys_root = builder.add_virtual_hash(); @@ -208,21 +220,41 @@ mod tests { }, }; - ProofTarget { + let proof = ProofTarget { wires_root, plonk_zs_partial_products_root: plonk_zs_root, quotient_polys_root, openings, opening_proof, + }; + + let public_inputs = builder.add_virtual_targets(public_inputs.len()); + ProofWithPublicInputsTarget { + proof, + public_inputs, } } // Set the targets in a `ProofTarget` to their corresponding values in a `Proof`. fn set_proof_target, const D: usize>( - proof: &Proof, - pt: &ProofTarget, + proof: &ProofWithPublicInputs, + pt: &ProofWithPublicInputsTarget, pw: &mut PartialWitness, ) { + let ProofWithPublicInputs { + proof, + public_inputs, + } = proof; + let ProofWithPublicInputsTarget { + proof: pt, + public_inputs: pi_targets, + } = pt; + + // Set public inputs. + for (&pi_t, &pi) in pi_targets.iter().zip(public_inputs) { + pw.set_target(pi_t, pi); + } + pw.set_hash_target(pt.wires_root, proof.wires_root); pw.set_hash_target( pt.plonk_zs_partial_products_root, @@ -343,7 +375,7 @@ mod tests { num_query_rounds: 40, }, }; - let (proof, vd, cd) = { + let (proof_with_pis, vd, cd) = { let mut builder = CircuitBuilder::::new(config.clone()); let _two = builder.two(); let _two = builder.hash_n_to_hash(vec![_two], true).elements[0]; @@ -357,12 +389,12 @@ mod tests { data.common, ) }; - verify(proof.clone(), &vd, &cd)?; + verify(proof_with_pis.clone(), &vd, &cd)?; let mut builder = CircuitBuilder::::new(config.clone()); let mut pw = PartialWitness::new(); - let pt = proof_to_proof_target(&proof, &mut builder); - set_proof_target(&proof, &pt, &mut pw); + let pt = proof_to_proof_target(&proof_with_pis, &mut builder); + set_proof_target(&proof_with_pis, &pt, &mut pw); let inner_data = VerifierCircuitTarget { constants_sigmas_root: builder.add_virtual_hash(), diff --git a/src/target.rs b/src/target.rs index 52be8b5a..857a4378 100644 --- a/src/target.rs +++ b/src/target.rs @@ -7,9 +7,6 @@ use crate::wire::Wire; #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] pub enum Target { Wire(Wire), - PublicInput { - index: usize, - }, /// A target that doesn't have any inherent location in the witness (but it can be copied to /// another target that does). This is useful for representing intermediate values in witness /// generation. @@ -26,7 +23,6 @@ impl Target { pub fn is_routable(&self, config: &CircuitConfig) -> bool { match self { Target::Wire(wire) => wire.is_routable(config), - Target::PublicInput { .. } => true, Target::VirtualTarget { .. } => true, } } diff --git a/src/util/mod.rs b/src/util/mod.rs index bfebe058..83a97881 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -51,7 +51,7 @@ pub(crate) fn transpose(matrix: &[Vec]) -> Vec> { } /// Permutes `arr` such that each index is mapped to its reverse in binary. -pub(crate) fn reverse_index_bits(arr: Vec) -> Vec { +pub(crate) fn reverse_index_bits(arr: &[T]) -> Vec { let n = arr.len(); let n_power = log2_strict(n); @@ -99,12 +99,9 @@ mod tests { #[test] fn test_reverse_index_bits() { + assert_eq!(reverse_index_bits(&[10, 20, 30, 40]), vec![10, 30, 20, 40]); assert_eq!( - reverse_index_bits(vec![10, 20, 30, 40]), - vec![10, 30, 20, 40] - ); - assert_eq!( - reverse_index_bits(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), + reverse_index_bits(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), vec![0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15] ); } diff --git a/src/util/scaling.rs b/src/util/scaling.rs index af32201d..5cbffb83 100644 --- a/src/util/scaling.rs +++ b/src/util/scaling.rs @@ -16,7 +16,7 @@ use crate::polynomial::polynomial::PolynomialCoeffs; /// This struct abstract away these operations by implementing Horner's method and keeping track /// of the number of multiplications by `a` to compute the scaling factor. /// See https://github.com/mir-protocol/plonky2/pull/69 for more details and discussions. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct ReducingFactor { base: F, count: u64, @@ -79,7 +79,7 @@ impl ReducingFactor { } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct ReducingFactorTarget { base: ExtensionTarget, count: u64, diff --git a/src/vanishing_poly.rs b/src/vanishing_poly.rs index b82a23c1..540ab86b 100644 --- a/src/vanishing_poly.rs +++ b/src/vanishing_poly.rs @@ -236,10 +236,13 @@ pub fn evaluate_gate_constraints_recursively, const D: usize>( ) -> Vec> { let mut constraints = vec![builder.zero_extension(); num_gate_constraints]; for gate in gates { - let gate_constraints = gate - .gate - .0 - .eval_filtered_recursively(builder, vars, &gate.prefix); + let gate_constraints = context!( + builder, + &format!("evaluate {} constraints", gate.gate.0.id()), + gate.gate + .0 + .eval_filtered_recursively(builder, vars, &gate.prefix) + ); for (i, c) in gate_constraints.into_iter().enumerate() { constraints[i] = builder.add_extension(constraints[i], c); } diff --git a/src/vars.rs b/src/vars.rs index 17d51051..8e98d41f 100644 --- a/src/vars.rs +++ b/src/vars.rs @@ -5,17 +5,20 @@ use crate::field::extension_field::algebra::ExtensionAlgebra; use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::field::extension_field::Extendable; use crate::field::field::Field; +use crate::proof::{Hash, HashTarget}; #[derive(Debug, Copy, Clone)] pub struct EvaluationVars<'a, F: Extendable, const D: usize> { pub(crate) local_constants: &'a [F::Extension], pub(crate) local_wires: &'a [F::Extension], + pub(crate) public_inputs_hash: &'a Hash, } #[derive(Debug, Copy, Clone)] pub struct EvaluationVarsBase<'a, F: Field> { pub(crate) local_constants: &'a [F], pub(crate) local_wires: &'a [F], + pub(crate) public_inputs_hash: &'a Hash, } impl<'a, F: Extendable, const D: usize> EvaluationVars<'a, F, D> { @@ -49,6 +52,7 @@ impl<'a, const D: usize> EvaluationTargets<'a, D> { pub struct EvaluationTargets<'a, const D: usize> { pub(crate) local_constants: &'a [ExtensionTarget], pub(crate) local_wires: &'a [ExtensionTarget], + pub(crate) public_inputs_hash: &'a HashTarget, } impl<'a, const D: usize> EvaluationTargets<'a, D> { diff --git a/src/verifier.rs b/src/verifier.rs index 6b4df627..878b630a 100644 --- a/src/verifier.rs +++ b/src/verifier.rs @@ -3,20 +3,27 @@ use anyhow::{ensure, Result}; use crate::circuit_data::{CommonCircuitData, VerifierOnlyCircuitData}; use crate::field::extension_field::Extendable; use crate::field::field::Field; +use crate::hash::hash_n_to_hash; use crate::plonk_challenger::Challenger; use crate::plonk_common::reduce_with_powers; -use crate::proof::Proof; +use crate::proof::ProofWithPublicInputs; use crate::vanishing_poly::eval_vanishing_poly; use crate::vars::EvaluationVars; pub(crate) fn verify, const D: usize>( - proof: Proof, + proof_with_pis: ProofWithPublicInputs, verifier_data: &VerifierOnlyCircuitData, common_data: &CommonCircuitData, ) -> Result<()> { + let ProofWithPublicInputs { + proof, + public_inputs, + } = proof_with_pis; let config = &common_data.config; let num_challenges = config.num_challenges; + let public_inputs_hash = &hash_n_to_hash(public_inputs, true); + let mut challenger = Challenger::new(); // Observe the instance. // TODO: Need to include public inputs as well. @@ -37,6 +44,7 @@ pub(crate) fn verify, const D: usize>( let vars = EvaluationVars { local_constants, local_wires, + public_inputs_hash, }; let local_zs = &proof.openings.plonk_zs; let next_zs = &proof.openings.plonk_zs_right; diff --git a/src/witness.rs b/src/witness.rs index 25885ba3..ce4a95af 100644 --- a/src/witness.rs +++ b/src/witness.rs @@ -8,6 +8,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; use crate::gates::gate::GateInstance; +use crate::generator::GeneratedValues; use crate::proof::{Hash, HashTarget}; use crate::target::Target; use crate::wire::Wire; @@ -35,28 +36,6 @@ impl PartialWitness { } } - pub fn singleton_wire(wire: Wire, value: F) -> Self { - Self::singleton_target(Target::Wire(wire), value) - } - - pub fn singleton_target(target: Target, value: F) -> Self { - let mut witness = PartialWitness::new(); - witness.set_target(target, value); - witness - } - - pub fn singleton_extension_target( - et: ExtensionTarget, - value: F::Extension, - ) -> Self - where - F: Extendable, - { - let mut witness = PartialWitness::new(); - witness.set_extension_target(et, value); - witness - } - pub fn is_empty(&self) -> bool { self.target_values.is_empty() } @@ -157,7 +136,7 @@ impl PartialWitness { self.set_wires(wires, &value.to_basefield_array()); } - pub fn extend(&mut self, other: PartialWitness) { + pub fn extend(&mut self, other: GeneratedValues) { for (target, value) in other.target_values { self.set_target(target, value); } @@ -193,7 +172,6 @@ impl PartialWitness { gate, gate_instances[*gate].gate_type.0.id() ), - Target::PublicInput { index } => format!("{}-th public input", index), Target::VirtualTarget { index } => format!("{}-th virtual target", index), } };