diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index fe3f7d70..311ba841 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -364,7 +364,14 @@ mod tests { } let generators = gate.generators(0, &[]); - generate_partial_witness(&mut witness, &generators, &mut TimingTree::default()); + generate_partial_witness( + &mut witness, + &generators, + gate.num_wires(), + 1, + 1, + &mut TimingTree::default(), + ); let expected_outputs: [F; W] = gmimc_permute_naive(permutation_inputs.try_into().unwrap(), constants); diff --git a/src/iop/challenger.rs b/src/iop/challenger.rs index 53867de7..72314048 100644 --- a/src/iop/challenger.rs +++ b/src/iop/challenger.rs @@ -399,7 +399,7 @@ mod tests { num_routed_wires: 27, ..CircuitConfig::default() }; - let mut builder = CircuitBuilder::::new(config); + let mut builder = CircuitBuilder::::new(config.clone()); let mut recursive_challenger = RecursiveChallenger::new(&mut builder); let mut recursive_outputs_per_round: Vec> = Vec::new(); for (r, inputs) in inputs_per_round.iter().enumerate() { @@ -413,6 +413,9 @@ mod tests { generate_partial_witness( &mut witness, &circuit.prover_only.generators, + config.num_wires, + circuit.common.degree(), + circuit.prover_only.num_virtual_targets, &mut TimingTree::default(), ); let recursive_output_values_per_round: Vec> = recursive_outputs_per_round diff --git a/src/iop/generator.rs b/src/iop/generator.rs index ad492810..6ab1b412 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::convert::identity; use std::fmt::Debug; @@ -17,17 +16,26 @@ use crate::util::timing::TimingTree; pub(crate) fn generate_partial_witness( witness: &mut PartialWitness, generators: &[Box>], + num_wires: usize, + degree: usize, + max_virtual_target: usize, timing: &mut TimingTree, ) { + let target_index = |t: Target| -> usize { + match t { + Target::Wire(Wire { gate, input }) => gate * num_wires + input, + Target::VirtualTarget { index } => degree * num_wires + index, + } + }; + let max_target_index = target_index(Target::VirtualTarget { + index: max_virtual_target, + }); // Index generator indices by their watched targets. - let mut generator_indices_by_watches = HashMap::new(); + let mut generator_indices_by_watches = vec![Vec::new(); max_target_index]; timed!(timing, "index generators by their watched targets", { for (i, generator) in generators.iter().enumerate() { for watch in generator.watch_list() { - generator_indices_by_watches - .entry(watch) - .or_insert_with(Vec::new) - .push(i); + generator_indices_by_watches[target_index(watch)].push(i); } } }); @@ -56,11 +64,9 @@ pub(crate) fn generate_partial_witness( } // Enqueue unfinished generators that were watching one of the newly populated targets. - for (watch, _) in &buffer.target_values { - if let Some(watching_generator_indices) = generator_indices_by_watches.get(watch) { - for &watching_generator_idx in watching_generator_indices { - next_pending_generator_indices.push(watching_generator_idx); - } + for &(watch, _) in &buffer.target_values { + for &watching_generator_idx in &generator_indices_by_watches[target_index(watch)] { + next_pending_generator_indices.push(watching_generator_idx); } } diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 50e5f3e6..f4eaa881 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -576,6 +576,7 @@ impl, const D: usize> CircuitBuilder { gate_instances: self.gate_instances, public_inputs: self.public_inputs, marked_targets: self.marked_targets, + num_virtual_targets: self.virtual_target_index, }; // The HashSet of gates will have a non-deterministic order. When converting to a Vec, we diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 576bedfd..48ec13be 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -136,6 +136,8 @@ pub(crate) struct ProverOnlyCircuitData, const D: usize> { pub public_inputs: Vec, /// A vector of marked targets. The values assigned to these targets will be displayed by the prover. pub marked_targets: Vec>, + /// Number of virtual targets used in the circuit. + pub num_virtual_targets: usize, } /// Circuit data required by the verifier, but not the prover. diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 06670a08..de85d94b 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -37,7 +37,14 @@ pub(crate) fn prove, const D: usize>( timed!( timing, &format!("run {} generators", prover_data.generators.len()), - generate_partial_witness(&mut partial_witness, &prover_data.generators, &mut timing) + generate_partial_witness( + &mut partial_witness, + &prover_data.generators, + config.num_wires, + degree, + prover_data.num_virtual_targets, + &mut timing + ) ); let public_inputs = partial_witness.get_targets(&prover_data.public_inputs);