diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 5c031cfe..4ef87be6 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -375,7 +375,7 @@ mod tests { ); } - let witness = generate_partial_witness(inputs, &circuit.prover_only); + let witness = generate_partial_witness(inputs, &circuit.prover_only, &circuit.common); let expected_outputs: [F; WIDTH] = F::gmimc_permute_naive(permutation_inputs.try_into().unwrap()); diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index c45be25f..b6902ada 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -541,7 +541,7 @@ mod tests { ); } - let witness = generate_partial_witness(inputs, &circuit.prover_only); + let witness = generate_partial_witness(inputs, &circuit.prover_only, &circuit.common); let expected_outputs: [F; WIDTH] = F::poseidon(permutation_inputs.try_into().unwrap()); for i in 0..WIDTH { diff --git a/src/iop/challenger.rs b/src/iop/challenger.rs index d33c19ce..afecb1b2 100644 --- a/src/iop/challenger.rs +++ b/src/iop/challenger.rs @@ -388,7 +388,7 @@ mod tests { } let circuit = builder.build(); let inputs = PartialWitness::new(); - let witness = generate_partial_witness(inputs, &circuit.prover_only); + let witness = generate_partial_witness(inputs, &circuit.prover_only, &circuit.common); let recursive_output_values_per_round: Vec> = recursive_outputs_per_round .iter() .map(|outputs| witness.get_targets(outputs)) diff --git a/src/iop/generator.rs b/src/iop/generator.rs index dd876c6c..a64afed6 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -8,18 +8,26 @@ use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::iop::target::Target; use crate::iop::wire::Wire; use crate::iop::witness::{PartialWitness, PartitionWitness, Witness}; -use crate::plonk::circuit_data::ProverOnlyCircuitData; +use crate::plonk::circuit_data::{CommonCircuitData, ProverOnlyCircuitData}; /// Given a `PartitionWitness` that has only inputs set, populates the rest of the witness using the /// given set of generators. -pub(crate) fn generate_partial_witness, const D: usize>( +pub(crate) fn generate_partial_witness<'a, F: RichField + Extendable, const D: usize>( inputs: PartialWitness, - prover_data: &ProverOnlyCircuitData, -) -> PartitionWitness { + prover_data: &'a ProverOnlyCircuitData, + common_data: &'a CommonCircuitData, +) -> PartitionWitness<'a, F> { + let config = &common_data.config; let generators = &prover_data.generators; let generator_indices_by_watches = &prover_data.generator_indices_by_watches; - let mut witness = prover_data.partition_witness.clone(); + let mut witness = PartitionWitness::new( + config.num_wires, + common_data.degree(), + common_data.num_virtual_targets, + &prover_data.representative_map, + ); + for (t, v) in inputs.target_values.into_iter() { witness.set_target(t, v); } @@ -51,7 +59,10 @@ pub(crate) fn generate_partial_witness, const D: us // Merge any generated values into our witness, and get a list of newly-populated // targets' representatives. - let new_target_reps = witness.extend_returning_parents(buffer.target_values.drain(..)); + let new_target_reps = buffer + .target_values + .drain(..) + .flat_map(|(t, v)| witness.set_target_returning_parent(t, v)); // Enqueue unfinished generators that were watching one of the newly populated targets. for watch in new_target_reps { diff --git a/src/iop/target.rs b/src/iop/target.rs index 877da5b8..8d4cbcfb 100644 --- a/src/iop/target.rs +++ b/src/iop/target.rs @@ -30,6 +30,13 @@ impl Target { pub fn wires_from_range(gate: usize, range: Range) -> Vec { range.map(|i| Self::wire(gate, i)).collect() } + + pub fn index(&self, num_wires: usize, degree: usize) -> usize { + match self { + Target::Wire(Wire { gate, input }) => gate * num_wires + input, + Target::VirtualTarget { index } => degree * num_wires + index, + } + } } /// A `Target` which has already been constrained such that it can only be 0 or 1. diff --git a/src/iop/witness.rs b/src/iop/witness.rs index 3e884891..13d1fa2c 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -9,7 +9,6 @@ use crate::hash::hash_types::{HashOut, MerkleCapTarget}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; -use crate::plonk::permutation_argument::ForestNode; /// A witness holds information on the values of targets in a circuit. pub trait Witness { @@ -189,30 +188,37 @@ impl Witness for PartialWitness { /// `PartitionWitness` holds a disjoint-set forest of the targets respecting a circuit's copy constraints. /// The value of a target is defined to be the value of its root in the forest. #[derive(Clone)] -pub struct PartitionWitness { - pub forest: Vec>, +pub struct PartitionWitness<'a, F: Field> { + pub values: Vec>, + pub representative_map: &'a [usize], pub num_wires: usize, - pub num_routed_wires: usize, pub degree: usize, } -impl Witness for PartitionWitness { - fn try_get_target(&self, target: Target) -> Option { - let parent_index = self.forest[self.target_index(target)].parent; - self.forest[parent_index].value +impl<'a, F: Field> PartitionWitness<'a, F> { + pub fn new( + num_wires: usize, + degree: usize, + num_virtual_targets: usize, + representative_map: &'a [usize], + ) -> Self { + Self { + values: vec![None; degree * num_wires + num_virtual_targets], + representative_map, + num_wires, + degree, + } } - fn set_target(&mut self, target: Target, value: F) { - self.set_target_returning_parent(target, value); - } -} - -impl PartitionWitness { /// Set a `Target`. On success, returns the representative index of the newly-set target. If the /// target was already set, returns `None`. - fn set_target_returning_parent(&mut self, target: Target, value: F) -> Option { - let parent_index = self.forest[self.target_index(target)].parent; - let parent_value = &mut self.forest[parent_index].value; + pub(crate) fn set_target_returning_parent( + &mut self, + target: Target, + value: F, + ) -> Option { + let parent_index = self.representative_map[self.target_index(target)]; + let parent_value = &mut self.values[parent_index]; if let Some(old_value) = *parent_value { assert_eq!( value, old_value, @@ -226,19 +232,8 @@ impl PartitionWitness { } } - /// Returns the representative indices of any newly-set targets. - pub(crate) fn extend_returning_parents<'a, I: 'a + Iterator>( - &'a mut self, - pairs: I, - ) -> impl Iterator + 'a { - pairs.flat_map(move |(t, v)| self.set_target_returning_parent(t, v)) - } - - pub fn target_index(&self, target: Target) -> usize { - match target { - Target::Wire(Wire { gate, input }) => gate * self.num_wires + input, - Target::VirtualTarget { index } => self.degree * self.num_wires + index, - } + pub(crate) fn target_index(&self, target: Target) -> usize { + target.index(self.num_wires, self.degree) } pub fn full_witness(self) -> MatrixWitness { @@ -255,3 +250,14 @@ impl PartitionWitness { MatrixWitness { wire_values } } } + +impl<'a, F: Field> Witness for PartitionWitness<'a, F> { + fn try_get_target(&self, target: Target) -> Option { + let parent_index = self.representative_map[self.target_index(target)]; + self.values[parent_index] + } + + fn set_target(&mut self, target: Target, value: F) { + self.set_target_returning_parent(target, value); + } +} diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index db9e855b..035a6d90 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -25,12 +25,12 @@ use crate::iop::generator::{ }; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; -use crate::iop::witness::PartitionWitness; use crate::plonk::circuit_data::{ CircuitConfig, CircuitData, CommonCircuitData, ProverCircuitData, ProverOnlyCircuitData, VerifierCircuitData, VerifierOnlyCircuitData, }; use crate::plonk::copy_constraint::CopyConstraint; +use crate::plonk::permutation_argument::Forest; use crate::plonk::plonk_common::PlonkPolynomials; use crate::polynomial::polynomial::PolynomialValues; use crate::util::context_tree::ContextTree; @@ -456,40 +456,37 @@ impl, const D: usize> CircuitBuilder { .collect() } - fn sigma_vecs( - &self, - k_is: &[F], - subgroup: &[F], - ) -> (Vec>, PartitionWitness) { + fn sigma_vecs(&self, k_is: &[F], subgroup: &[F]) -> (Vec>, Forest) { let degree = self.gate_instances.len(); let degree_log = log2_strict(degree); - let mut partition_witness = PartitionWitness::new( - self.config.num_wires, - self.config.num_routed_wires, + let config = &self.config; + let mut forest = Forest::new( + config.num_wires, + config.num_routed_wires, degree, self.virtual_target_index, ); for gate in 0..degree { - for input in 0..self.config.num_wires { - partition_witness.add(Target::Wire(Wire { gate, input })); + for input in 0..config.num_wires { + forest.add(Target::Wire(Wire { gate, input })); } } for index in 0..self.virtual_target_index { - partition_witness.add(Target::VirtualTarget { index }); + forest.add(Target::VirtualTarget { index }); } for &CopyConstraint { pair: (a, b), .. } in &self.copy_constraints { - partition_witness.merge(a, b); + forest.merge(a, b); } - partition_witness.compress_paths(); + forest.compress_paths(); - let wire_partition = partition_witness.wire_partition(); + let wire_partition = forest.wire_partition(); ( wire_partition.get_sigma_polys(degree_log, k_is, subgroup), - partition_witness, + forest, ) } @@ -607,7 +604,7 @@ impl, const D: usize> CircuitBuilder { let constant_vecs = self.constant_polys(&prefixed_gates, num_constants); let k_is = get_unique_coset_shifts(degree, self.config.num_routed_wires); - let (sigma_vecs, partition_witness) = self.sigma_vecs(&k_is, &subgroup); + let (sigma_vecs, forest) = self.sigma_vecs(&k_is, &subgroup); // Precompute FFT roots. let max_fft_points = @@ -633,8 +630,8 @@ impl, const D: usize> CircuitBuilder { let mut generator_indices_by_watches = BTreeMap::new(); for (i, generator) in self.generators.iter().enumerate() { for watch in generator.watch_list() { - let watch_index = partition_witness.target_index(watch); - let watch_rep_index = partition_witness.forest[watch_index].parent; + let watch_index = forest.target_index(watch); + let watch_rep_index = forest.parents[watch_index]; generator_indices_by_watches .entry(watch_rep_index) .or_insert(vec![]) @@ -654,7 +651,7 @@ impl, const D: usize> CircuitBuilder { subgroup, public_inputs: self.public_inputs, marked_targets: self.marked_targets, - partition_witness, + representative_map: forest.parents, fft_root_table: Some(fft_root_table), }; @@ -686,6 +683,7 @@ impl, const D: usize> CircuitBuilder { quotient_degree_factor, num_gate_constraints, num_constants, + num_virtual_targets: self.virtual_target_index, k_is, num_partial_products, circuit_digest, diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 82f1c1cd..644ab370 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -13,7 +13,7 @@ use crate::hash::hash_types::{HashOut, MerkleCapTarget}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::generator::WitnessGenerator; use crate::iop::target::Target; -use crate::iop::witness::{PartialWitness, PartitionWitness}; +use crate::iop::witness::PartialWitness; use crate::plonk::proof::ProofWithPublicInputs; use crate::plonk::prover::prove; use crate::plonk::verifier::verify; @@ -156,8 +156,9 @@ pub(crate) struct ProverOnlyCircuitData, const D: u pub public_inputs: Vec, /// A vector of marked targets. The values assigned to these targets will be displayed by the prover. pub marked_targets: Vec>, - /// Partial witness holding the copy constraints information. - pub partition_witness: PartitionWitness, + /// A map from each `Target`'s index to the index of its representative in the disjoint-set + /// forest. + pub representative_map: Vec, /// Pre-computed roots for faster FFT. pub fft_root_table: Option>, } @@ -188,6 +189,8 @@ pub struct CommonCircuitData, const D: usize> { /// The number of constant wires. pub(crate) num_constants: usize, + pub(crate) num_virtual_targets: usize, + /// The `{k_i}` valued used in `S_ID_i` in Plonk's permutation argument. pub(crate) k_is: Vec, diff --git a/src/plonk/permutation_argument.rs b/src/plonk/permutation_argument.rs index 761c1bd8..acf3686c 100644 --- a/src/plonk/permutation_argument.rs +++ b/src/plonk/permutation_argument.rs @@ -1,55 +1,55 @@ use std::collections::HashMap; -use std::fmt::Debug; use rayon::prelude::*; use crate::field::field_types::Field; use crate::iop::target::Target; use crate::iop::wire::Wire; -use crate::iop::witness::PartitionWitness; use crate::polynomial::polynomial::PolynomialValues; -/// Node in the Disjoint Set Forest. -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct ForestNode { - pub parent: usize, - pub size: usize, - pub value: Option, +/// Disjoint Set Forest data-structure following https://en.wikipedia.org/wiki/Disjoint-set_data_structure. +pub struct Forest { + /// A map of parent pointers, stored as indices. + pub(crate) parents: Vec, + + num_wires: usize, + num_routed_wires: usize, + degree: usize, } -/// Disjoint Set Forest data-structure following https://en.wikipedia.org/wiki/Disjoint-set_data_structure. -impl PartitionWitness { +impl Forest { pub fn new( num_wires: usize, num_routed_wires: usize, degree: usize, num_virtual_targets: usize, ) -> Self { + let capacity = num_wires * degree + num_virtual_targets; Self { - forest: Vec::with_capacity(degree * num_wires + num_virtual_targets), + parents: Vec::with_capacity(capacity), num_wires, num_routed_wires, degree, } } + pub(crate) fn target_index(&self, target: Target) -> usize { + target.index(self.num_wires, self.degree) + } + /// Add a new partition with a single member. pub fn add(&mut self, t: Target) { - let index = self.forest.len(); + let index = self.parents.len(); debug_assert_eq!(self.target_index(t), index); - self.forest.push(ForestNode { - parent: index, - size: 1, - value: None, - }); + self.parents.push(index); } /// Path compression method, see https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Finding_set_representatives. pub fn find(&mut self, x_index: usize) -> usize { - let x = self.forest[x_index]; - if x.parent != x_index { - let root_index = self.find(x.parent); - self.forest[x_index].parent = root_index; + let x_parent = self.parents[x_index]; + if x_parent != x_index { + let root_index = self.find(x_parent); + self.parents[x_index] = root_index; root_index } else { x_index @@ -65,25 +65,13 @@ impl PartitionWitness { return; } - let mut x = self.forest[x_index]; - let mut y = self.forest[y_index]; - - if x.size >= y.size { - y.parent = x_index; - x.size += y.size; - } else { - x.parent = y_index; - y.size += x.size; - } - - self.forest[x_index] = x; - self.forest[y_index] = y; + self.parents[y_index] = x_index; } /// Compress all paths. After calling this, every `parent` value will point to the node's /// representative. pub(crate) fn compress_paths(&mut self) { - for i in 0..self.forest.len() { + for i in 0..self.parents.len() { self.find(i); } } @@ -96,8 +84,8 @@ impl PartitionWitness { for input in 0..self.num_routed_wires { let w = Wire { gate, input }; let t = Target::Wire(w); - let x = self.forest[self.target_index(t)]; - partition.entry(x.parent).or_default().push(w); + let x_parent = self.parents[self.target_index(t)]; + partition.entry(x_parent).or_default().push(w); } } diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 9f4303ae..19852cff 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -36,7 +36,7 @@ pub(crate) fn prove, const D: usize>( let partition_witness = timed!( timing, &format!("run {} generators", prover_data.generators.len()), - generate_partial_witness(inputs, &prover_data) + generate_partial_witness(inputs, prover_data, common_data) ); let public_inputs = partition_witness.get_targets(&prover_data.public_inputs);