diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index c7951d78..a4da2c95 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -7,7 +7,7 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::field::extension_field::{Extendable, OEF}; use crate::gates::arithmetic::ArithmeticExtensionGate; -use crate::generator::SimpleGenerator; +use crate::generator::{GeneratedValues, SimpleGenerator}; use crate::target::Target; use crate::util::bits_u64; use crate::witness::PartialWitness; @@ -384,11 +384,11 @@ impl, const D: usize> SimpleGenerator for QuotientGeneratorE deps } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let num = witness.get_extension_target(self.numerator); let dem = witness.get_extension_target(self.denominator); let quotient = num / dem; - PartialWitness::singleton_extension_target(self.quotient, quotient) + GeneratedValues::singleton_extension_target(self.quotient, quotient) } } diff --git a/src/gadgets/range_check.rs b/src/gadgets/range_check.rs index 7fd35efc..c0848af8 100644 --- a/src/gadgets/range_check.rs +++ b/src/gadgets/range_check.rs @@ -2,7 +2,7 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::base_sum::BaseSumGate; -use crate::generator::SimpleGenerator; +use crate::generator::{GeneratedValues, SimpleGenerator}; use crate::target::Target; use crate::witness::PartialWitness; @@ -49,12 +49,12 @@ impl SimpleGenerator for LowHighGenerator { vec![self.integer] } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let integer_value = witness.get_target(self.integer).to_canonical_u64(); let low = integer_value & ((1 << self.n_log) - 1); let high = integer_value >> self.n_log; - let mut result = PartialWitness::new(); + let mut result = GeneratedValues::with_capacity(2); result.set_target(self.low, F::from_canonical_u64(low)); result.set_target(self.high, F::from_canonical_u64(high)); diff --git a/src/gadgets/split_join.rs b/src/gadgets/split_join.rs index 3a2c27f4..9cc6ab7c 100644 --- a/src/gadgets/split_join.rs +++ b/src/gadgets/split_join.rs @@ -2,7 +2,7 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::base_sum::BaseSumGate; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::target::Target; use crate::util::ceil_div_usize; use crate::wire::Wire; @@ -110,10 +110,10 @@ impl SimpleGenerator for SplitGenerator { vec![self.integer] } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let mut integer_value = witness.get_target(self.integer).to_canonical_u64(); - let mut result = PartialWitness::new(); + let mut result = GeneratedValues::with_capacity(self.bits.len()); for &b in &self.bits { let b_value = integer_value & 1; result.set_target(b, F::from_canonical_u64(b_value)); @@ -141,10 +141,10 @@ impl SimpleGenerator for WireSplitGenerator { vec![self.integer] } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let mut integer_value = witness.get_target(self.integer).to_canonical_u64(); - let mut result = PartialWitness::new(); + let mut result = GeneratedValues::with_capacity(self.gates.len()); for &gate in &self.gates { let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM); result.set_target( diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 39baa226..73865f07 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -4,7 +4,7 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::witness::PartialWitness; @@ -157,7 +157,7 @@ impl, const D: usize> SimpleGenerator for ArithmeticExtensio .collect() } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let extract_extension = |range: Range| -> F::Extension { let t = ExtensionTarget::from_range(self.gate_index, range); witness.get_extension_target(t) @@ -177,7 +177,7 @@ impl, const D: usize> SimpleGenerator for ArithmeticExtensio let computed_output_0 = fixed_multiplicand * multiplicand_0 * self.const_0.into() + addend_0 * self.const_1.into(); - PartialWitness::singleton_extension_target(output_target_0, computed_output_0) + GeneratedValues::singleton_extension_target(output_target_0, computed_output_0) } } @@ -190,7 +190,7 @@ impl, const D: usize> SimpleGenerator for ArithmeticExtensio .collect() } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let extract_extension = |range: Range| -> F::Extension { let t = ExtensionTarget::from_range(self.gate_index, range); witness.get_extension_target(t) @@ -210,7 +210,7 @@ impl, const D: usize> SimpleGenerator for ArithmeticExtensio let computed_output_1 = fixed_multiplicand * multiplicand_1 * self.const_0.into() + addend_1 * self.const_1.into(); - PartialWitness::singleton_extension_target(output_target_1, computed_output_1) + GeneratedValues::singleton_extension_target(output_target_1, computed_output_1) } } diff --git a/src/gates/base_sum.rs b/src/gates/base_sum.rs index 8f453d8e..8ad189ee 100644 --- a/src/gates/base_sum.rs +++ b/src/gates/base_sum.rs @@ -5,7 +5,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::plonk_common::{reduce_with_powers, reduce_with_powers_recursive}; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; @@ -130,7 +130,7 @@ impl SimpleGenerator for BaseSplitGenerator { vec![Target::wire(self.gate_index, BaseSumGate::::WIRE_SUM)] } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let sum_value = witness .get_target(Target::wire(self.gate_index, BaseSumGate::::WIRE_SUM)) .to_canonical_u64() as usize; @@ -155,7 +155,7 @@ impl SimpleGenerator for BaseSplitGenerator { .iter() .fold(F::ZERO, |acc, &x| acc * b_field + x); - let mut result = PartialWitness::new(); + let mut result = GeneratedValues::with_capacity(self.num_limbs + 1); result.set_target( Target::wire(self.gate_index, BaseSumGate::::WIRE_REVERSED_SUM), reversed_sum, diff --git a/src/gates/constant.rs b/src/gates/constant.rs index 3845031a..4049d058 100644 --- a/src/gates/constant.rs +++ b/src/gates/constant.rs @@ -3,7 +3,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::wire::Wire; @@ -83,12 +83,12 @@ impl SimpleGenerator for ConstantGenerator { Vec::new() } - fn run_once(&self, _witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, _witness: &PartialWitness) -> GeneratedValues { let wire = Wire { gate: self.gate_index, input: ConstantGate::WIRE_OUTPUT, }; - PartialWitness::singleton_target(Target::Wire(wire), self.constant) + GeneratedValues::singleton_target(Target::Wire(wire), self.constant) } } diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index e61ed8d9..0404884b 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -5,7 +5,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::gmimc::gmimc_automatic_constants; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; @@ -239,8 +239,8 @@ impl, const D: usize, const R: usize> SimpleGenerator .collect() } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { - let mut result = PartialWitness::new(); + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { + let mut result = GeneratedValues::with_capacity(R + W + 1); let mut state = (0..W) .map(|i| { diff --git a/src/gates/insertion.rs b/src/gates/insertion.rs index 2fbeda15..1bc0b454 100644 --- a/src/gates/insertion.rs +++ b/src/gates/insertion.rs @@ -7,7 +7,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::wire::Wire; @@ -218,7 +218,7 @@ impl, const D: usize> SimpleGenerator for InsertionGenerator deps } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let local_wire = |input| Wire { gate: self.gate_index, input, @@ -264,7 +264,7 @@ impl, const D: usize> SimpleGenerator for InsertionGenerator let mut insert_here_vals = vec![F::ZERO; vec_size]; insert_here_vals.insert(insertion_index, F::ONE); - let mut result = PartialWitness::::new(); + let mut result = GeneratedValues::::with_capacity((vec_size + 1) * (D + 2)); for i in 0..=vec_size { let output_wires = self.gate.wires_output_list_item(i).map(local_wire); result.set_ext_wires(output_wires, new_vec[i]); diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index 2d6745b4..17d34e3a 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -9,7 +9,7 @@ use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::interpolation::interpolant; use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::wire::Wire; @@ -216,7 +216,7 @@ impl, const D: usize> SimpleGenerator for InterpolationGener deps } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let n = self.gate.num_points; let local_wire = |input| Wire { @@ -244,7 +244,7 @@ impl, const D: usize> SimpleGenerator for InterpolationGener .collect::>(); let interpolant = interpolant(&points); - let mut result = PartialWitness::::new(); + let mut result = GeneratedValues::::with_capacity(D * (self.gate.num_points + 1)); for (i, &coeff) in interpolant.coeffs.iter().enumerate() { let wires = self.gate.wires_coeff(i).map(local_wire); result.set_ext_wires(wires, coeff); diff --git a/src/generator.rs b/src/generator.rs index a47c5267..a7359a7d 100644 --- a/src/generator.rs +++ b/src/generator.rs @@ -1,8 +1,13 @@ use std::collections::{HashMap, HashSet}; +use std::convert::identity; use std::fmt::Debug; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; +use crate::proof::{Hash, HashTarget}; use crate::target::Target; +use crate::wire::Wire; use crate::witness::PartialWitness; /// Given a `PartialWitness` that has only inputs set, populates the rest of the witness using the @@ -27,7 +32,7 @@ pub(crate) fn generate_partial_witness( let mut pending_generator_indices: HashSet<_> = (0..generators.len()).collect(); // We also track a list of "expired" generators which have already returned false. - let mut expired_generator_indices = HashSet::new(); + let mut generator_is_expired = vec![false; generators.len()]; // Keep running generators until no generators are queued. while !pending_generator_indices.is_empty() { @@ -36,15 +41,15 @@ pub(crate) fn generate_partial_witness( for &generator_idx in &pending_generator_indices { let (result, finished) = generators[generator_idx].run(&witness); if finished { - expired_generator_indices.insert(generator_idx); + generator_is_expired[generator_idx] = true; } // Enqueue unfinished generators that were watching one of the newly populated targets. - for watch in result.target_values.keys() { + for (watch, _) in &result.target_values { if let Some(watching_generator_indices) = generator_indices_by_watches.get(watch) { - for watching_generator_idx in watching_generator_indices { - if !expired_generator_indices.contains(watching_generator_idx) { - next_pending_generator_indices.insert(*watching_generator_idx); + for &watching_generator_idx in watching_generator_indices { + if !generator_is_expired[watching_generator_idx] { + next_pending_generator_indices.insert(watching_generator_idx); } } } @@ -55,9 +60,9 @@ pub(crate) fn generate_partial_witness( pending_generator_indices = next_pending_generator_indices; } - assert_eq!( - expired_generator_indices.len(), - generators.len(), + + assert!( + generator_is_expired.into_iter().all(identity), "Some generators weren't run." ); } @@ -72,14 +77,101 @@ pub trait WitnessGenerator: 'static + Send + Sync { /// flag indicating whether the generator is finished. If the flag is true, the generator will /// never be run again, otherwise it will be queued for another run next time a target in its /// watch list is populated. - fn run(&self, witness: &PartialWitness) -> (PartialWitness, bool); + fn run(&self, witness: &PartialWitness) -> (GeneratedValues, bool); +} + +/// Values generated by a generator invocation. +pub struct GeneratedValues { + pub(crate) target_values: Vec<(Target, F)>, +} + +impl From> for GeneratedValues { + fn from(target_values: Vec<(Target, F)>) -> Self { + Self { target_values } + } +} + +impl GeneratedValues { + pub fn with_capacity(capacity: usize) -> Self { + Vec::with_capacity(capacity).into() + } + + pub fn empty() -> Self { + Vec::new().into() + } + + pub fn singleton_wire(wire: Wire, value: F) -> Self { + Self::singleton_target(Target::Wire(wire), value) + } + + pub fn singleton_target(target: Target, value: F) -> Self { + vec![(target, value)].into() + } + + pub fn singleton_extension_target( + et: ExtensionTarget, + value: F::Extension, + ) -> Self + where + F: Extendable, + { + let mut witness = Self::with_capacity(D); + witness.set_extension_target(et, value); + witness + } + + pub fn set_target(&mut self, target: Target, value: F) { + self.target_values.push((target, value)) + } + + pub fn set_hash_target(&mut self, ht: HashTarget, value: Hash) { + ht.elements + .iter() + .zip(value.elements) + .for_each(|(&t, x)| self.set_target(t, x)); + } + + pub fn set_extension_target( + &mut self, + et: ExtensionTarget, + value: F::Extension, + ) where + F: Extendable, + { + let limbs = value.to_basefield_array(); + (0..D).for_each(|i| { + self.set_target(et.0[i], limbs[i]); + }); + } + + pub fn set_wire(&mut self, wire: Wire, value: F) { + self.set_target(Target::Wire(wire), value) + } + + pub fn set_wires(&mut self, wires: W, values: &[F]) + where + W: IntoIterator, + { + // If we used itertools, we could use zip_eq for extra safety. + for (wire, &value) in wires.into_iter().zip(values) { + self.set_wire(wire, value); + } + } + + pub fn set_ext_wires(&mut self, wires: W, value: F::Extension) + where + F: Extendable, + W: IntoIterator, + { + self.set_wires(wires, &value.to_basefield_array()); + } } /// A generator which runs once after a list of dependencies is present in the witness. pub trait SimpleGenerator: 'static + Send + Sync { fn dependencies(&self) -> Vec; - fn run_once(&self, witness: &PartialWitness) -> PartialWitness; + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues; } impl> WitnessGenerator for SG { @@ -87,11 +179,11 @@ impl> WitnessGenerator for SG { self.dependencies() } - fn run(&self, witness: &PartialWitness) -> (PartialWitness, bool) { + fn run(&self, witness: &PartialWitness) -> (GeneratedValues, bool) { if witness.contains_all(&self.dependencies()) { (self.run_once(witness), true) } else { - (PartialWitness::new(), false) + (GeneratedValues::empty(), false) } } } @@ -108,9 +200,9 @@ impl SimpleGenerator for CopyGenerator { vec![self.src] } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let value = witness.get_target(self.src); - PartialWitness::singleton_target(self.dst, value) + GeneratedValues::singleton_target(self.dst, value) } } @@ -124,10 +216,10 @@ impl SimpleGenerator for RandomValueGenerator { Vec::new() } - fn run_once(&self, _witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, _witness: &PartialWitness) -> GeneratedValues { let random_value = F::rand(); - PartialWitness::singleton_target(self.target, random_value) + GeneratedValues::singleton_target(self.target, random_value) } } @@ -142,7 +234,7 @@ impl SimpleGenerator for NonzeroTestGenerator { vec![self.to_test] } - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { let to_test_value = witness.get_target(self.to_test); let dummy_value = if to_test_value == F::ZERO { @@ -151,6 +243,6 @@ impl SimpleGenerator for NonzeroTestGenerator { to_test_value.inverse() }; - PartialWitness::singleton_target(self.dummy, dummy_value) + GeneratedValues::singleton_target(self.dummy, dummy_value) } } diff --git a/src/prover.rs b/src/prover.rs index 4c2e86bc..f1cef856 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -62,7 +62,7 @@ pub(crate) fn prove, const D: usize>( let wires_values: Vec> = timed!( witness .wire_values - .iter() + .par_iter() .map(|column| PolynomialValues::new(column.clone())) .collect(), "to compute wire polynomials" diff --git a/src/witness.rs b/src/witness.rs index cdeafcce..ce4a95af 100644 --- a/src/witness.rs +++ b/src/witness.rs @@ -8,6 +8,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; use crate::gates::gate::GateInstance; +use crate::generator::GeneratedValues; use crate::proof::{Hash, HashTarget}; use crate::target::Target; use crate::wire::Wire; @@ -35,28 +36,6 @@ impl PartialWitness { } } - pub fn singleton_wire(wire: Wire, value: F) -> Self { - Self::singleton_target(Target::Wire(wire), value) - } - - pub fn singleton_target(target: Target, value: F) -> Self { - let mut witness = PartialWitness::new(); - witness.set_target(target, value); - witness - } - - pub fn singleton_extension_target( - et: ExtensionTarget, - value: F::Extension, - ) -> Self - where - F: Extendable, - { - let mut witness = PartialWitness::new(); - witness.set_extension_target(et, value); - witness - } - pub fn is_empty(&self) -> bool { self.target_values.is_empty() } @@ -157,7 +136,7 @@ impl PartialWitness { self.set_wires(wires, &value.to_basefield_array()); } - pub fn extend(&mut self, other: PartialWitness) { + pub fn extend(&mut self, other: GeneratedValues) { for (target, value) in other.target_values { self.set_target(target, value); }