From b3008b94756e908c27596f719c035eab0dc02919 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Mon, 13 Sep 2021 16:38:55 -0700 Subject: [PATCH] Some changes to generator_indices_by_watches (#234) * Some changes to generator_indices_by_watches - Index generators by the representatives (in disjoint-set forest terminology) of their watched targets, rather than the watched targets themselves. Enqueuing generators based on their watch lists then works correctly, so we no longer need the step where we reenqueue all generators. - In #195, it was pointed out that this slows down witness generation a bit. I moved the indexing code to preprocessing, so the prover is a bit faster (~7ms for me). * Outdated comment * Panic instead of infinite loop if we get stuck * BTree * fmt --- src/gates/gmimc.rs | 39 +++++++++------------- src/iop/challenger.rs | 19 +++-------- src/iop/generator.rs | 64 +++++++++++++++++------------------- src/plonk/circuit_builder.rs | 21 +++++++++++- src/plonk/circuit_data.rs | 4 +++ src/plonk/prover.rs | 13 ++------ 6 files changed, 77 insertions(+), 83 deletions(-) diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index ad8a8281..9df6360c 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -336,62 +336,53 @@ mod tests { use crate::field::crandall_field::CrandallField; use crate::field::field_types::Field; - use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::gates::gmimc::GMiMCGate; use crate::hash::gmimc::GMiMC; use crate::iop::generator::generate_partial_witness; - use crate::iop::target::Target; use crate::iop::wire::Wire; - use crate::iop::witness::{PartialWitness, PartitionWitness, Witness}; - use crate::util::timing::TimingTree; + use crate::iop::witness::{PartialWitness, Witness}; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; #[test] fn generated_output() { type F = CrandallField; const WIDTH: usize = 12; + + let config = CircuitConfig::large_config(); + let mut builder = CircuitBuilder::new(config); type Gate = GMiMCGate; let gate = Gate::new(); + let gate_index = builder.add_gate(gate, vec![]); + let circuit = builder.build_prover(); let permutation_inputs = (0..WIDTH).map(F::from_canonical_usize).collect::>(); - let mut witness = PartialWitness::new(); - witness.set_wire( + let mut inputs = PartialWitness::new(); + inputs.set_wire( Wire { - gate: 0, + gate: gate_index, input: Gate::WIRE_SWAP, }, F::ZERO, ); for i in 0..WIDTH { - witness.set_wire( + inputs.set_wire( Wire { - gate: 0, + gate: gate_index, input: Gate::wire_input(i), }, permutation_inputs[i], ); } - let mut partition_witness = PartitionWitness::new(gate.num_wires(), gate.num_wires(), 1, 0); - for input in 0..gate.num_wires() { - partition_witness.add(Target::Wire(Wire { gate: 0, input })); - } - for (&t, &v) in witness.target_values.iter() { - partition_witness.set_target(t, v); - } - let generators = gate.generators(0, &[]); - generate_partial_witness( - &mut partition_witness, - &generators, - &mut TimingTree::default(), - ); + let witness = generate_partial_witness(inputs, &circuit.prover_only); let expected_outputs: [F; WIDTH] = F::gmimc_permute_naive(permutation_inputs.try_into().unwrap()); - for i in 0..WIDTH { - let out = partition_witness.get_wire(Wire { + let out = witness.get_wire(Wire { gate: 0, input: Gate::wire_output(i), }); diff --git a/src/iop/challenger.rs b/src/iop/challenger.rs index f8b918ea..47d57db8 100644 --- a/src/iop/challenger.rs +++ b/src/iop/challenger.rs @@ -330,10 +330,9 @@ mod tests { use crate::iop::challenger::{Challenger, RecursiveChallenger}; use crate::iop::generator::generate_partial_witness; use crate::iop::target::Target; - use crate::iop::witness::Witness; + use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; - use crate::util::timing::TimingTree; #[test] fn no_duplicate_challenges() { @@ -377,11 +376,7 @@ mod tests { outputs_per_round.push(challenger.get_n_challenges(num_outputs_per_round[r])); } - let config = CircuitConfig { - num_wires: 12 + 12 + 1 + 101, - num_routed_wires: 27, - ..CircuitConfig::default() - }; + let config = CircuitConfig::large_config(); let mut builder = CircuitBuilder::::new(config.clone()); let mut recursive_challenger = RecursiveChallenger::new(&mut builder); let mut recursive_outputs_per_round: Vec> = Vec::new(); @@ -392,15 +387,11 @@ mod tests { ); } let circuit = builder.build(); - let mut partition_witness = circuit.prover_only.partition_witness.clone(); - generate_partial_witness( - &mut partition_witness, - &circuit.prover_only.generators, - &mut TimingTree::default(), - ); + let inputs = PartialWitness::new(); + let witness = generate_partial_witness(inputs, &circuit.prover_only); let recursive_output_values_per_round: Vec> = recursive_outputs_per_round .iter() - .map(|outputs| partition_witness.get_targets(outputs)) + .map(|outputs| witness.get_targets(outputs)) .collect(); assert_eq!(outputs_per_round, recursive_output_values_per_round); diff --git a/src/iop/generator.rs b/src/iop/generator.rs index 483a7a62..434e18ca 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -3,31 +3,26 @@ use std::marker::PhantomData; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::field::field_types::Field; +use crate::field::field_types::{Field, RichField}; use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::iop::target::Target; use crate::iop::wire::Wire; -use crate::iop::witness::{PartitionWitness, Witness}; -use crate::timed; -use crate::util::timing::TimingTree; +use crate::iop::witness::{PartialWitness, PartitionWitness, Witness}; +use crate::plonk::circuit_data::ProverOnlyCircuitData; /// Given a `PartitionWitness` that has only inputs set, populates the rest of the witness using the /// given set of generators. -pub(crate) fn generate_partial_witness( - witness: &mut PartitionWitness, - generators: &[Box>], - timing: &mut TimingTree, -) { - let max_target_index = witness.forest.len(); - // Index generator indices by their watched targets. - let mut generator_indices_by_watches = vec![Vec::new(); max_target_index]; - timed!(timing, "index generators by their watched targets", { - for (i, generator) in generators.iter().enumerate() { - for watch in generator.watch_list() { - generator_indices_by_watches[witness.target_index(watch)].push(i); - } - } - }); +pub(crate) fn generate_partial_witness, const D: usize>( + inputs: PartialWitness, + prover_data: &ProverOnlyCircuitData, +) -> PartitionWitness { + let generators = &prover_data.generators; + let generator_indices_by_watches = &prover_data.generator_indices_by_watches; + + let mut witness = prover_data.partition_witness.clone(); + for (t, v) in inputs.target_values.into_iter() { + witness.set_target(t, v); + } // Build a list of "pending" generators which are queued to be run. Initially, all generators // are queued. @@ -39,8 +34,8 @@ pub(crate) fn generate_partial_witness( let mut buffer = GeneratedValues::empty(); - // Keep running generators until all generators have been run. - while remaining_generators > 0 { + // Keep running generators until we fail to make progress. + while !pending_generator_indices.is_empty() { let mut next_pending_generator_indices = Vec::new(); for &generator_idx in &pending_generator_indices { @@ -56,11 +51,12 @@ pub(crate) fn generate_partial_witness( // Enqueue unfinished generators that were watching one of the newly populated targets. for &(watch, _) in &buffer.target_values { - for &watching_generator_idx in - &generator_indices_by_watches[witness.target_index(watch)] - { - if !generator_is_expired[watching_generator_idx] { - next_pending_generator_indices.push(watching_generator_idx); + let opt_watchers = generator_indices_by_watches.get(&witness.target_index(watch)); + if let Some(watchers) = opt_watchers { + for &watching_generator_idx in watchers { + if !generator_is_expired[watching_generator_idx] { + next_pending_generator_indices.push(watching_generator_idx); + } } } } @@ -68,14 +64,16 @@ pub(crate) fn generate_partial_witness( witness.extend(buffer.target_values.drain(..)); } - // If we still need to run som generators, but none were enqueued, we enqueue all generators. - pending_generator_indices = - if remaining_generators > 0 && next_pending_generator_indices.is_empty() { - (0..generators.len()).collect() - } else { - next_pending_generator_indices - }; + pending_generator_indices = next_pending_generator_indices; } + + assert_eq!( + remaining_generators, 0, + "{} generators weren't run", + remaining_generators + ); + + witness } /// A generator participates in the generation of the witness. diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 9c34df15..0627c91d 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::convert::TryInto; use std::time::Instant; @@ -619,8 +619,27 @@ impl, const D: usize> CircuitBuilder { constants_sigmas_cap: constants_sigmas_cap.clone(), }; + // Index generator indices by their watched targets. + let max_target_index = partition_witness.forest.len(); + let mut generator_indices_by_watches = BTreeMap::new(); + for (i, generator) in self.generators.iter().enumerate() { + for watch in generator.watch_list() { + let watch_index = partition_witness.target_index(watch); + let watch_rep_index = partition_witness.forest[watch_index].parent; + generator_indices_by_watches + .entry(watch_rep_index) + .or_insert(vec![]) + .push(i); + } + } + for indices in generator_indices_by_watches.values_mut() { + indices.dedup(); + indices.shrink_to_fit(); + } + let prover_only = ProverOnlyCircuitData { generators: self.generators, + generator_indices_by_watches, constants_sigmas_commitment, sigmas: transpose_poly_values(sigma_vecs), subgroup, diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 4ec27cc7..21dfc28b 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -1,3 +1,4 @@ +use std::collections::BTreeMap; use std::ops::{Range, RangeFrom}; use anyhow::Result; @@ -141,6 +142,9 @@ impl, const D: usize> VerifierCircuitData { /// Circuit data required by the prover, but not the verifier. pub(crate) struct ProverOnlyCircuitData, const D: usize> { pub generators: Vec>>, + /// Generator indices (within the `Vec` above), indexed by the representative of each target + /// they watch. + pub generator_indices_by_watches: BTreeMap>, /// Commitments to the constants polynomials and sigma polynomials. pub constants_sigmas_commitment: PolynomialBatchCommitment, /// The transpose of the list of sigma polynomials. diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 7f16af04..7f609826 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -33,19 +33,10 @@ pub(crate) fn prove, const D: usize>( let quotient_degree = common_data.quotient_degree(); let degree = common_data.degree(); - let mut partition_witness = prover_data.partition_witness.clone(); - timed!( - timing, - "fill partition witness", - for (t, v) in inputs.target_values.into_iter() { - partition_witness.set_target(t, v); - } - ); - - timed!( + let partition_witness = timed!( timing, &format!("run {} generators", prover_data.generators.len()), - generate_partial_witness(&mut partition_witness, &prover_data.generators, &mut timing) + generate_partial_witness(inputs, &prover_data) ); let public_inputs = partition_witness.get_targets(&prover_data.public_inputs);