diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index ad8a8281..9df6360c 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -336,62 +336,53 @@ mod tests { use crate::field::crandall_field::CrandallField; use crate::field::field_types::Field; - use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::gates::gmimc::GMiMCGate; use crate::hash::gmimc::GMiMC; use crate::iop::generator::generate_partial_witness; - use crate::iop::target::Target; use crate::iop::wire::Wire; - use crate::iop::witness::{PartialWitness, PartitionWitness, Witness}; - use crate::util::timing::TimingTree; + use crate::iop::witness::{PartialWitness, Witness}; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; #[test] fn generated_output() { type F = CrandallField; const WIDTH: usize = 12; + + let config = CircuitConfig::large_config(); + let mut builder = CircuitBuilder::new(config); type Gate = GMiMCGate; let gate = Gate::new(); + let gate_index = builder.add_gate(gate, vec![]); + let circuit = builder.build_prover(); let permutation_inputs = (0..WIDTH).map(F::from_canonical_usize).collect::>(); - let mut witness = PartialWitness::new(); - witness.set_wire( + let mut inputs = PartialWitness::new(); + inputs.set_wire( Wire { - gate: 0, + gate: gate_index, input: Gate::WIRE_SWAP, }, F::ZERO, ); for i in 0..WIDTH { - witness.set_wire( + inputs.set_wire( Wire { - gate: 0, + gate: gate_index, input: Gate::wire_input(i), }, permutation_inputs[i], ); } - let mut partition_witness = PartitionWitness::new(gate.num_wires(), gate.num_wires(), 1, 0); - for input in 0..gate.num_wires() { - partition_witness.add(Target::Wire(Wire { gate: 0, input })); - } - for (&t, &v) in witness.target_values.iter() { - partition_witness.set_target(t, v); - } - let generators = gate.generators(0, &[]); - generate_partial_witness( - &mut partition_witness, - &generators, - &mut TimingTree::default(), - ); + let witness = generate_partial_witness(inputs, &circuit.prover_only); let expected_outputs: [F; WIDTH] = F::gmimc_permute_naive(permutation_inputs.try_into().unwrap()); - for i in 0..WIDTH { - let out = partition_witness.get_wire(Wire { + let out = witness.get_wire(Wire { gate: 0, input: Gate::wire_output(i), }); diff --git a/src/iop/challenger.rs b/src/iop/challenger.rs index f8b918ea..47d57db8 100644 --- a/src/iop/challenger.rs +++ b/src/iop/challenger.rs @@ -330,10 +330,9 @@ mod tests { use crate::iop::challenger::{Challenger, RecursiveChallenger}; use crate::iop::generator::generate_partial_witness; use crate::iop::target::Target; - use crate::iop::witness::Witness; + use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; - use crate::util::timing::TimingTree; #[test] fn no_duplicate_challenges() { @@ -377,11 +376,7 @@ mod tests { outputs_per_round.push(challenger.get_n_challenges(num_outputs_per_round[r])); } - let config = CircuitConfig { - num_wires: 12 + 12 + 1 + 101, - num_routed_wires: 27, - ..CircuitConfig::default() - }; + let config = CircuitConfig::large_config(); let mut builder = CircuitBuilder::::new(config.clone()); let mut recursive_challenger = RecursiveChallenger::new(&mut builder); let mut recursive_outputs_per_round: Vec> = Vec::new(); @@ -392,15 +387,11 @@ mod tests { ); } let circuit = builder.build(); - let mut partition_witness = circuit.prover_only.partition_witness.clone(); - generate_partial_witness( - &mut partition_witness, - &circuit.prover_only.generators, - &mut TimingTree::default(), - ); + let inputs = PartialWitness::new(); + let witness = generate_partial_witness(inputs, &circuit.prover_only); let recursive_output_values_per_round: Vec> = recursive_outputs_per_round .iter() - .map(|outputs| partition_witness.get_targets(outputs)) + .map(|outputs| witness.get_targets(outputs)) .collect(); assert_eq!(outputs_per_round, recursive_output_values_per_round); diff --git a/src/iop/generator.rs b/src/iop/generator.rs index 483a7a62..434e18ca 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -3,31 +3,26 @@ use std::marker::PhantomData; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::field::field_types::Field; +use crate::field::field_types::{Field, RichField}; use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::iop::target::Target; use crate::iop::wire::Wire; -use crate::iop::witness::{PartitionWitness, Witness}; -use crate::timed; -use crate::util::timing::TimingTree; +use crate::iop::witness::{PartialWitness, PartitionWitness, Witness}; +use crate::plonk::circuit_data::ProverOnlyCircuitData; /// Given a `PartitionWitness` that has only inputs set, populates the rest of the witness using the /// given set of generators. -pub(crate) fn generate_partial_witness( - witness: &mut PartitionWitness, - generators: &[Box>], - timing: &mut TimingTree, -) { - let max_target_index = witness.forest.len(); - // Index generator indices by their watched targets. - let mut generator_indices_by_watches = vec![Vec::new(); max_target_index]; - timed!(timing, "index generators by their watched targets", { - for (i, generator) in generators.iter().enumerate() { - for watch in generator.watch_list() { - generator_indices_by_watches[witness.target_index(watch)].push(i); - } - } - }); +pub(crate) fn generate_partial_witness, const D: usize>( + inputs: PartialWitness, + prover_data: &ProverOnlyCircuitData, +) -> PartitionWitness { + let generators = &prover_data.generators; + let generator_indices_by_watches = &prover_data.generator_indices_by_watches; + + let mut witness = prover_data.partition_witness.clone(); + for (t, v) in inputs.target_values.into_iter() { + witness.set_target(t, v); + } // Build a list of "pending" generators which are queued to be run. Initially, all generators // are queued. @@ -39,8 +34,8 @@ pub(crate) fn generate_partial_witness( let mut buffer = GeneratedValues::empty(); - // Keep running generators until all generators have been run. - while remaining_generators > 0 { + // Keep running generators until we fail to make progress. + while !pending_generator_indices.is_empty() { let mut next_pending_generator_indices = Vec::new(); for &generator_idx in &pending_generator_indices { @@ -56,11 +51,12 @@ pub(crate) fn generate_partial_witness( // Enqueue unfinished generators that were watching one of the newly populated targets. for &(watch, _) in &buffer.target_values { - for &watching_generator_idx in - &generator_indices_by_watches[witness.target_index(watch)] - { - if !generator_is_expired[watching_generator_idx] { - next_pending_generator_indices.push(watching_generator_idx); + let opt_watchers = generator_indices_by_watches.get(&witness.target_index(watch)); + if let Some(watchers) = opt_watchers { + for &watching_generator_idx in watchers { + if !generator_is_expired[watching_generator_idx] { + next_pending_generator_indices.push(watching_generator_idx); + } } } } @@ -68,14 +64,16 @@ pub(crate) fn generate_partial_witness( witness.extend(buffer.target_values.drain(..)); } - // If we still need to run som generators, but none were enqueued, we enqueue all generators. - pending_generator_indices = - if remaining_generators > 0 && next_pending_generator_indices.is_empty() { - (0..generators.len()).collect() - } else { - next_pending_generator_indices - }; + pending_generator_indices = next_pending_generator_indices; } + + assert_eq!( + remaining_generators, 0, + "{} generators weren't run", + remaining_generators + ); + + witness } /// A generator participates in the generation of the witness. diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 9c34df15..0627c91d 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::convert::TryInto; use std::time::Instant; @@ -619,8 +619,27 @@ impl, const D: usize> CircuitBuilder { constants_sigmas_cap: constants_sigmas_cap.clone(), }; + // Index generator indices by their watched targets. + let max_target_index = partition_witness.forest.len(); + let mut generator_indices_by_watches = BTreeMap::new(); + for (i, generator) in self.generators.iter().enumerate() { + for watch in generator.watch_list() { + let watch_index = partition_witness.target_index(watch); + let watch_rep_index = partition_witness.forest[watch_index].parent; + generator_indices_by_watches + .entry(watch_rep_index) + .or_insert(vec![]) + .push(i); + } + } + for indices in generator_indices_by_watches.values_mut() { + indices.dedup(); + indices.shrink_to_fit(); + } + let prover_only = ProverOnlyCircuitData { generators: self.generators, + generator_indices_by_watches, constants_sigmas_commitment, sigmas: transpose_poly_values(sigma_vecs), subgroup, diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 4ec27cc7..21dfc28b 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -1,3 +1,4 @@ +use std::collections::BTreeMap; use std::ops::{Range, RangeFrom}; use anyhow::Result; @@ -141,6 +142,9 @@ impl, const D: usize> VerifierCircuitData { /// Circuit data required by the prover, but not the verifier. pub(crate) struct ProverOnlyCircuitData, const D: usize> { pub generators: Vec>>, + /// Generator indices (within the `Vec` above), indexed by the representative of each target + /// they watch. + pub generator_indices_by_watches: BTreeMap>, /// Commitments to the constants polynomials and sigma polynomials. pub constants_sigmas_commitment: PolynomialBatchCommitment, /// The transpose of the list of sigma polynomials. diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 7f16af04..7f609826 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -33,19 +33,10 @@ pub(crate) fn prove, const D: usize>( let quotient_degree = common_data.quotient_degree(); let degree = common_data.degree(); - let mut partition_witness = prover_data.partition_witness.clone(); - timed!( - timing, - "fill partition witness", - for (t, v) in inputs.target_values.into_iter() { - partition_witness.set_target(t, v); - } - ); - - timed!( + let partition_witness = timed!( timing, &format!("run {} generators", prover_data.generators.len()), - generate_partial_witness(&mut partition_witness, &prover_data.generators, &mut timing) + generate_partial_witness(inputs, &prover_data) ); let public_inputs = partition_witness.get_targets(&prover_data.public_inputs);