From 8b8e4d223d961255f75904f71aa69c4422dd43d2 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Wed, 4 Aug 2021 09:55:11 -0700 Subject: [PATCH] Optimize witness generation a bit (#153) Mainly storing pending generators in a Vec rather than a HashMap. Requires an extra check to make sure we don't run one twice after adding it to the Vec twice. --- src/gates/gmimc.rs | 3 ++- src/iop/challenger.rs | 7 ++++++- src/iop/generator.rs | 33 ++++++++++++++++++++------------- src/iop/witness.rs | 4 +--- src/plonk/prover.rs | 2 +- 5 files changed, 30 insertions(+), 19 deletions(-) diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 4a9425af..97847d41 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -334,6 +334,7 @@ mod tests { use crate::iop::generator::generate_partial_witness; use crate::iop::wire::Wire; use crate::iop::witness::PartialWitness; + use crate::util::timing::TimingTree; #[test] fn generated_output() { @@ -364,7 +365,7 @@ mod tests { } let generators = gate.generators(0, &[]); - generate_partial_witness(&mut witness, &generators); + generate_partial_witness(&mut witness, &generators, &mut TimingTree::default()); let expected_outputs: [F; W] = gmimc_permute_naive(permutation_inputs.try_into().unwrap(), constants); diff --git a/src/iop/challenger.rs b/src/iop/challenger.rs index d15eab30..53867de7 100644 --- a/src/iop/challenger.rs +++ b/src/iop/challenger.rs @@ -350,6 +350,7 @@ mod tests { use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; + use crate::util::timing::TimingTree; #[test] fn no_duplicate_challenges() { @@ -409,7 +410,11 @@ mod tests { } let circuit = builder.build(); let mut witness = PartialWitness::new(); - generate_partial_witness(&mut witness, &circuit.prover_only.generators); + generate_partial_witness( + &mut witness, + &circuit.prover_only.generators, + &mut TimingTree::default(), + ); let recursive_output_values_per_round: Vec> = recursive_outputs_per_round .iter() .map(|outputs| witness.get_targets(outputs)) diff --git a/src/iop/generator.rs b/src/iop/generator.rs index 9f421b14..ad492810 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::convert::identity; use std::fmt::Debug; @@ -9,27 +9,32 @@ use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::iop::target::Target; use crate::iop::wire::Wire; use crate::iop::witness::PartialWitness; +use crate::timed; +use crate::util::timing::TimingTree; /// Given a `PartialWitness` that has only inputs set, populates the rest of the witness using the /// given set of generators. pub(crate) fn generate_partial_witness( witness: &mut PartialWitness, generators: &[Box>], + timing: &mut TimingTree, ) { // Index generator indices by their watched targets. let mut generator_indices_by_watches = HashMap::new(); - for (i, generator) in generators.iter().enumerate() { - for watch in generator.watch_list() { - generator_indices_by_watches - .entry(watch) - .or_insert_with(Vec::new) - .push(i); + timed!(timing, "index generators by their watched targets", { + for (i, generator) in generators.iter().enumerate() { + for watch in generator.watch_list() { + generator_indices_by_watches + .entry(watch) + .or_insert_with(Vec::new) + .push(i); + } } - } + }); // Build a list of "pending" generators which are queued to be run. Initially, all generators // are queued. - let mut pending_generator_indices: HashSet<_> = (0..generators.len()).collect(); + let mut pending_generator_indices: Vec<_> = (0..generators.len()).collect(); // We also track a list of "expired" generators which have already returned false. let mut generator_is_expired = vec![false; generators.len()]; @@ -38,9 +43,13 @@ pub(crate) fn generate_partial_witness( // Keep running generators until no generators are queued. while !pending_generator_indices.is_empty() { - let mut next_pending_generator_indices = HashSet::new(); + let mut next_pending_generator_indices = Vec::new(); for &generator_idx in &pending_generator_indices { + if generator_is_expired[generator_idx] { + continue; + } + let finished = generators[generator_idx].run(&witness, &mut buffer); if finished { generator_is_expired[generator_idx] = true; @@ -50,9 +59,7 @@ pub(crate) fn generate_partial_witness( for (watch, _) in &buffer.target_values { if let Some(watching_generator_indices) = generator_indices_by_watches.get(watch) { for &watching_generator_idx in watching_generator_indices { - if !generator_is_expired[watching_generator_idx] { - next_pending_generator_indices.insert(watching_generator_idx); - } + next_pending_generator_indices.push(watching_generator_idx); } } } diff --git a/src/iop/witness.rs b/src/iop/witness.rs index 1c2f2e3d..5b97f546 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -162,9 +162,7 @@ impl PartialWitness { } pub fn extend>(&mut self, pairs: I) { - for (target, value) in pairs { - self.set_target(target, value); - } + self.target_values.extend(pairs); } pub fn full_witness(self, degree: usize, num_wires: usize) -> Witness { diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index a8520dde..288f4c29 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -37,7 +37,7 @@ pub(crate) fn prove, const D: usize>( timed!( timing, &format!("run {} generators", prover_data.generators.len()), - generate_partial_witness(&mut partial_witness, &prover_data.generators) + generate_partial_witness(&mut partial_witness, &prover_data.generators, &mut timing) ); let public_inputs = partial_witness.get_targets(&prover_data.public_inputs);