Optimize witness generation a bit (#153)

Mainly storing pending generators in a Vec rather than a HashMap.  Requires an extra check to make sure we don't run one twice after adding it to the Vec twice.
This commit is contained in:
Daniel Lubarov 2021-08-04 09:55:11 -07:00 committed by GitHub
parent 79af87535a
commit 8b8e4d223d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 30 additions and 19 deletions

View File

@ -334,6 +334,7 @@ mod tests {
use crate::iop::generator::generate_partial_witness;
use crate::iop::wire::Wire;
use crate::iop::witness::PartialWitness;
use crate::util::timing::TimingTree;
#[test]
fn generated_output() {
@ -364,7 +365,7 @@ mod tests {
}
let generators = gate.generators(0, &[]);
generate_partial_witness(&mut witness, &generators);
generate_partial_witness(&mut witness, &generators, &mut TimingTree::default());
let expected_outputs: [F; W] =
gmimc_permute_naive(permutation_inputs.try_into().unwrap(), constants);

View File

@ -350,6 +350,7 @@ mod tests {
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::util::timing::TimingTree;
#[test]
fn no_duplicate_challenges() {
@ -409,7 +410,11 @@ mod tests {
}
let circuit = builder.build();
let mut witness = PartialWitness::new();
generate_partial_witness(&mut witness, &circuit.prover_only.generators);
generate_partial_witness(
&mut witness,
&circuit.prover_only.generators,
&mut TimingTree::default(),
);
let recursive_output_values_per_round: Vec<Vec<F>> = recursive_outputs_per_round
.iter()
.map(|outputs| witness.get_targets(outputs))

View File

@ -1,4 +1,4 @@
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;
use std::convert::identity;
use std::fmt::Debug;
@ -9,27 +9,32 @@ use crate::hash::hash_types::{HashOut, HashOutTarget};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
use crate::iop::witness::PartialWitness;
use crate::timed;
use crate::util::timing::TimingTree;
/// Given a `PartialWitness` that has only inputs set, populates the rest of the witness using the
/// given set of generators.
pub(crate) fn generate_partial_witness<F: Field>(
witness: &mut PartialWitness<F>,
generators: &[Box<dyn WitnessGenerator<F>>],
timing: &mut TimingTree,
) {
// Index generator indices by their watched targets.
let mut generator_indices_by_watches = HashMap::new();
for (i, generator) in generators.iter().enumerate() {
for watch in generator.watch_list() {
generator_indices_by_watches
.entry(watch)
.or_insert_with(Vec::new)
.push(i);
timed!(timing, "index generators by their watched targets", {
for (i, generator) in generators.iter().enumerate() {
for watch in generator.watch_list() {
generator_indices_by_watches
.entry(watch)
.or_insert_with(Vec::new)
.push(i);
}
}
}
});
// Build a list of "pending" generators which are queued to be run. Initially, all generators
// are queued.
let mut pending_generator_indices: HashSet<_> = (0..generators.len()).collect();
let mut pending_generator_indices: Vec<_> = (0..generators.len()).collect();
// We also track a list of "expired" generators which have already returned false.
let mut generator_is_expired = vec![false; generators.len()];
@ -38,9 +43,13 @@ pub(crate) fn generate_partial_witness<F: Field>(
// Keep running generators until no generators are queued.
while !pending_generator_indices.is_empty() {
let mut next_pending_generator_indices = HashSet::new();
let mut next_pending_generator_indices = Vec::new();
for &generator_idx in &pending_generator_indices {
if generator_is_expired[generator_idx] {
continue;
}
let finished = generators[generator_idx].run(&witness, &mut buffer);
if finished {
generator_is_expired[generator_idx] = true;
@ -50,9 +59,7 @@ pub(crate) fn generate_partial_witness<F: Field>(
for (watch, _) in &buffer.target_values {
if let Some(watching_generator_indices) = generator_indices_by_watches.get(watch) {
for &watching_generator_idx in watching_generator_indices {
if !generator_is_expired[watching_generator_idx] {
next_pending_generator_indices.insert(watching_generator_idx);
}
next_pending_generator_indices.push(watching_generator_idx);
}
}
}

View File

@ -162,9 +162,7 @@ impl<F: Field> PartialWitness<F> {
}
pub fn extend<I: Iterator<Item = (Target, F)>>(&mut self, pairs: I) {
for (target, value) in pairs {
self.set_target(target, value);
}
self.target_values.extend(pairs);
}
pub fn full_witness(self, degree: usize, num_wires: usize) -> Witness<F> {

View File

@ -37,7 +37,7 @@ pub(crate) fn prove<F: Extendable<D>, const D: usize>(
timed!(
timing,
&format!("run {} generators", prover_data.generators.len()),
generate_partial_witness(&mut partial_witness, &prover_data.generators)
generate_partial_witness(&mut partial_witness, &prover_data.generators, &mut timing)
);
let public_inputs = partial_witness.get_targets(&prover_data.public_inputs);