From 3f226632961994565582c0d9775d2d0152b54532 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Tue, 28 Sep 2021 22:31:20 -0700 Subject: [PATCH] Split up `PartitionWitness` data (#273) * Split up `PartitionWitness` data This addresses two minor inefficiencies: - Some preprocessed forest data was being cloned during proving. - Some of the `ForestNode` data (like node sizes) is only needed in preprocessing, not proving. It was taking up cache space during proving because it was interleaved with data that is used during proving (parents, values). Now `Forest` contains the disjoint-set forest. `PartitionWitness` is now mainly a Vec of target values; it also holds a reference to the (preprocessed) representative map. On my laptop, this speeds up witness generation ~12%, resulting in an overall ~0.5% speedup. * Feedback * No size data (#278) * No size data * feedback --- src/gates/gmimc.rs | 2 +- src/gates/poseidon.rs | 2 +- src/iop/challenger.rs | 2 +- src/iop/generator.rs | 23 ++++++++--- src/iop/target.rs | 7 ++++ src/iop/witness.rs | 66 +++++++++++++++++-------------- src/plonk/circuit_builder.rs | 38 +++++++++--------- src/plonk/circuit_data.rs | 9 +++-- src/plonk/permutation_argument.rs | 62 ++++++++++++----------------- src/plonk/prover.rs | 2 +- 10 files changed, 113 insertions(+), 100 deletions(-) diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 5c031cfe..4ef87be6 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -375,7 +375,7 @@ mod tests { ); } - let witness = generate_partial_witness(inputs, &circuit.prover_only); + let witness = generate_partial_witness(inputs, &circuit.prover_only, &circuit.common); let expected_outputs: [F; WIDTH] = F::gmimc_permute_naive(permutation_inputs.try_into().unwrap()); diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index c45be25f..b6902ada 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -541,7 +541,7 @@ mod tests { ); } - let witness = generate_partial_witness(inputs, &circuit.prover_only); + let witness = generate_partial_witness(inputs, &circuit.prover_only, &circuit.common); let expected_outputs: [F; WIDTH] = F::poseidon(permutation_inputs.try_into().unwrap()); for i in 0..WIDTH { diff --git a/src/iop/challenger.rs b/src/iop/challenger.rs index d33c19ce..afecb1b2 100644 --- a/src/iop/challenger.rs +++ b/src/iop/challenger.rs @@ -388,7 +388,7 @@ mod tests { } let circuit = builder.build(); let inputs = PartialWitness::new(); - let witness = generate_partial_witness(inputs, &circuit.prover_only); + let witness = generate_partial_witness(inputs, &circuit.prover_only, &circuit.common); let recursive_output_values_per_round: Vec> = recursive_outputs_per_round .iter() .map(|outputs| witness.get_targets(outputs)) diff --git a/src/iop/generator.rs b/src/iop/generator.rs index dd876c6c..a64afed6 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -8,18 +8,26 @@ use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::iop::target::Target; use crate::iop::wire::Wire; use crate::iop::witness::{PartialWitness, PartitionWitness, Witness}; -use crate::plonk::circuit_data::ProverOnlyCircuitData; +use crate::plonk::circuit_data::{CommonCircuitData, ProverOnlyCircuitData}; /// Given a `PartitionWitness` that has only inputs set, populates the rest of the witness using the /// given set of generators. -pub(crate) fn generate_partial_witness, const D: usize>( +pub(crate) fn generate_partial_witness<'a, F: RichField + Extendable, const D: usize>( inputs: PartialWitness, - prover_data: &ProverOnlyCircuitData, -) -> PartitionWitness { + prover_data: &'a ProverOnlyCircuitData, + common_data: &'a CommonCircuitData, +) -> PartitionWitness<'a, F> { + let config = &common_data.config; let generators = &prover_data.generators; let generator_indices_by_watches = &prover_data.generator_indices_by_watches; - let mut witness = prover_data.partition_witness.clone(); + let mut witness = PartitionWitness::new( + config.num_wires, + common_data.degree(), + common_data.num_virtual_targets, + &prover_data.representative_map, + ); + for (t, v) in inputs.target_values.into_iter() { witness.set_target(t, v); } @@ -51,7 +59,10 @@ pub(crate) fn generate_partial_witness, const D: us // Merge any generated values into our witness, and get a list of newly-populated // targets' representatives. - let new_target_reps = witness.extend_returning_parents(buffer.target_values.drain(..)); + let new_target_reps = buffer + .target_values + .drain(..) + .flat_map(|(t, v)| witness.set_target_returning_parent(t, v)); // Enqueue unfinished generators that were watching one of the newly populated targets. for watch in new_target_reps { diff --git a/src/iop/target.rs b/src/iop/target.rs index 877da5b8..8d4cbcfb 100644 --- a/src/iop/target.rs +++ b/src/iop/target.rs @@ -30,6 +30,13 @@ impl Target { pub fn wires_from_range(gate: usize, range: Range) -> Vec { range.map(|i| Self::wire(gate, i)).collect() } + + pub fn index(&self, num_wires: usize, degree: usize) -> usize { + match self { + Target::Wire(Wire { gate, input }) => gate * num_wires + input, + Target::VirtualTarget { index } => degree * num_wires + index, + } + } } /// A `Target` which has already been constrained such that it can only be 0 or 1. diff --git a/src/iop/witness.rs b/src/iop/witness.rs index 3e884891..13d1fa2c 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -9,7 +9,6 @@ use crate::hash::hash_types::{HashOut, MerkleCapTarget}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; -use crate::plonk::permutation_argument::ForestNode; /// A witness holds information on the values of targets in a circuit. pub trait Witness { @@ -189,30 +188,37 @@ impl Witness for PartialWitness { /// `PartitionWitness` holds a disjoint-set forest of the targets respecting a circuit's copy constraints. /// The value of a target is defined to be the value of its root in the forest. #[derive(Clone)] -pub struct PartitionWitness { - pub forest: Vec>, +pub struct PartitionWitness<'a, F: Field> { + pub values: Vec>, + pub representative_map: &'a [usize], pub num_wires: usize, - pub num_routed_wires: usize, pub degree: usize, } -impl Witness for PartitionWitness { - fn try_get_target(&self, target: Target) -> Option { - let parent_index = self.forest[self.target_index(target)].parent; - self.forest[parent_index].value +impl<'a, F: Field> PartitionWitness<'a, F> { + pub fn new( + num_wires: usize, + degree: usize, + num_virtual_targets: usize, + representative_map: &'a [usize], + ) -> Self { + Self { + values: vec![None; degree * num_wires + num_virtual_targets], + representative_map, + num_wires, + degree, + } } - fn set_target(&mut self, target: Target, value: F) { - self.set_target_returning_parent(target, value); - } -} - -impl PartitionWitness { /// Set a `Target`. On success, returns the representative index of the newly-set target. If the /// target was already set, returns `None`. - fn set_target_returning_parent(&mut self, target: Target, value: F) -> Option { - let parent_index = self.forest[self.target_index(target)].parent; - let parent_value = &mut self.forest[parent_index].value; + pub(crate) fn set_target_returning_parent( + &mut self, + target: Target, + value: F, + ) -> Option { + let parent_index = self.representative_map[self.target_index(target)]; + let parent_value = &mut self.values[parent_index]; if let Some(old_value) = *parent_value { assert_eq!( value, old_value, @@ -226,19 +232,8 @@ impl PartitionWitness { } } - /// Returns the representative indices of any newly-set targets. - pub(crate) fn extend_returning_parents<'a, I: 'a + Iterator>( - &'a mut self, - pairs: I, - ) -> impl Iterator + 'a { - pairs.flat_map(move |(t, v)| self.set_target_returning_parent(t, v)) - } - - pub fn target_index(&self, target: Target) -> usize { - match target { - Target::Wire(Wire { gate, input }) => gate * self.num_wires + input, - Target::VirtualTarget { index } => self.degree * self.num_wires + index, - } + pub(crate) fn target_index(&self, target: Target) -> usize { + target.index(self.num_wires, self.degree) } pub fn full_witness(self) -> MatrixWitness { @@ -255,3 +250,14 @@ impl PartitionWitness { MatrixWitness { wire_values } } } + +impl<'a, F: Field> Witness for PartitionWitness<'a, F> { + fn try_get_target(&self, target: Target) -> Option { + let parent_index = self.representative_map[self.target_index(target)]; + self.values[parent_index] + } + + fn set_target(&mut self, target: Target, value: F) { + self.set_target_returning_parent(target, value); + } +} diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index db9e855b..035a6d90 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -25,12 +25,12 @@ use crate::iop::generator::{ }; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; -use crate::iop::witness::PartitionWitness; use crate::plonk::circuit_data::{ CircuitConfig, CircuitData, CommonCircuitData, ProverCircuitData, ProverOnlyCircuitData, VerifierCircuitData, VerifierOnlyCircuitData, }; use crate::plonk::copy_constraint::CopyConstraint; +use crate::plonk::permutation_argument::Forest; use crate::plonk::plonk_common::PlonkPolynomials; use crate::polynomial::polynomial::PolynomialValues; use crate::util::context_tree::ContextTree; @@ -456,40 +456,37 @@ impl, const D: usize> CircuitBuilder { .collect() } - fn sigma_vecs( - &self, - k_is: &[F], - subgroup: &[F], - ) -> (Vec>, PartitionWitness) { + fn sigma_vecs(&self, k_is: &[F], subgroup: &[F]) -> (Vec>, Forest) { let degree = self.gate_instances.len(); let degree_log = log2_strict(degree); - let mut partition_witness = PartitionWitness::new( - self.config.num_wires, - self.config.num_routed_wires, + let config = &self.config; + let mut forest = Forest::new( + config.num_wires, + config.num_routed_wires, degree, self.virtual_target_index, ); for gate in 0..degree { - for input in 0..self.config.num_wires { - partition_witness.add(Target::Wire(Wire { gate, input })); + for input in 0..config.num_wires { + forest.add(Target::Wire(Wire { gate, input })); } } for index in 0..self.virtual_target_index { - partition_witness.add(Target::VirtualTarget { index }); + forest.add(Target::VirtualTarget { index }); } for &CopyConstraint { pair: (a, b), .. } in &self.copy_constraints { - partition_witness.merge(a, b); + forest.merge(a, b); } - partition_witness.compress_paths(); + forest.compress_paths(); - let wire_partition = partition_witness.wire_partition(); + let wire_partition = forest.wire_partition(); ( wire_partition.get_sigma_polys(degree_log, k_is, subgroup), - partition_witness, + forest, ) } @@ -607,7 +604,7 @@ impl, const D: usize> CircuitBuilder { let constant_vecs = self.constant_polys(&prefixed_gates, num_constants); let k_is = get_unique_coset_shifts(degree, self.config.num_routed_wires); - let (sigma_vecs, partition_witness) = self.sigma_vecs(&k_is, &subgroup); + let (sigma_vecs, forest) = self.sigma_vecs(&k_is, &subgroup); // Precompute FFT roots. let max_fft_points = @@ -633,8 +630,8 @@ impl, const D: usize> CircuitBuilder { let mut generator_indices_by_watches = BTreeMap::new(); for (i, generator) in self.generators.iter().enumerate() { for watch in generator.watch_list() { - let watch_index = partition_witness.target_index(watch); - let watch_rep_index = partition_witness.forest[watch_index].parent; + let watch_index = forest.target_index(watch); + let watch_rep_index = forest.parents[watch_index]; generator_indices_by_watches .entry(watch_rep_index) .or_insert(vec![]) @@ -654,7 +651,7 @@ impl, const D: usize> CircuitBuilder { subgroup, public_inputs: self.public_inputs, marked_targets: self.marked_targets, - partition_witness, + representative_map: forest.parents, fft_root_table: Some(fft_root_table), }; @@ -686,6 +683,7 @@ impl, const D: usize> CircuitBuilder { quotient_degree_factor, num_gate_constraints, num_constants, + num_virtual_targets: self.virtual_target_index, k_is, num_partial_products, circuit_digest, diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 82f1c1cd..644ab370 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -13,7 +13,7 @@ use crate::hash::hash_types::{HashOut, MerkleCapTarget}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::generator::WitnessGenerator; use crate::iop::target::Target; -use crate::iop::witness::{PartialWitness, PartitionWitness}; +use crate::iop::witness::PartialWitness; use crate::plonk::proof::ProofWithPublicInputs; use crate::plonk::prover::prove; use crate::plonk::verifier::verify; @@ -156,8 +156,9 @@ pub(crate) struct ProverOnlyCircuitData, const D: u pub public_inputs: Vec, /// A vector of marked targets. The values assigned to these targets will be displayed by the prover. pub marked_targets: Vec>, - /// Partial witness holding the copy constraints information. - pub partition_witness: PartitionWitness, + /// A map from each `Target`'s index to the index of its representative in the disjoint-set + /// forest. + pub representative_map: Vec, /// Pre-computed roots for faster FFT. pub fft_root_table: Option>, } @@ -188,6 +189,8 @@ pub struct CommonCircuitData, const D: usize> { /// The number of constant wires. pub(crate) num_constants: usize, + pub(crate) num_virtual_targets: usize, + /// The `{k_i}` valued used in `S_ID_i` in Plonk's permutation argument. pub(crate) k_is: Vec, diff --git a/src/plonk/permutation_argument.rs b/src/plonk/permutation_argument.rs index 761c1bd8..acf3686c 100644 --- a/src/plonk/permutation_argument.rs +++ b/src/plonk/permutation_argument.rs @@ -1,55 +1,55 @@ use std::collections::HashMap; -use std::fmt::Debug; use rayon::prelude::*; use crate::field::field_types::Field; use crate::iop::target::Target; use crate::iop::wire::Wire; -use crate::iop::witness::PartitionWitness; use crate::polynomial::polynomial::PolynomialValues; -/// Node in the Disjoint Set Forest. -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct ForestNode { - pub parent: usize, - pub size: usize, - pub value: Option, +/// Disjoint Set Forest data-structure following https://en.wikipedia.org/wiki/Disjoint-set_data_structure. +pub struct Forest { + /// A map of parent pointers, stored as indices. + pub(crate) parents: Vec, + + num_wires: usize, + num_routed_wires: usize, + degree: usize, } -/// Disjoint Set Forest data-structure following https://en.wikipedia.org/wiki/Disjoint-set_data_structure. -impl PartitionWitness { +impl Forest { pub fn new( num_wires: usize, num_routed_wires: usize, degree: usize, num_virtual_targets: usize, ) -> Self { + let capacity = num_wires * degree + num_virtual_targets; Self { - forest: Vec::with_capacity(degree * num_wires + num_virtual_targets), + parents: Vec::with_capacity(capacity), num_wires, num_routed_wires, degree, } } + pub(crate) fn target_index(&self, target: Target) -> usize { + target.index(self.num_wires, self.degree) + } + /// Add a new partition with a single member. pub fn add(&mut self, t: Target) { - let index = self.forest.len(); + let index = self.parents.len(); debug_assert_eq!(self.target_index(t), index); - self.forest.push(ForestNode { - parent: index, - size: 1, - value: None, - }); + self.parents.push(index); } /// Path compression method, see https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Finding_set_representatives. pub fn find(&mut self, x_index: usize) -> usize { - let x = self.forest[x_index]; - if x.parent != x_index { - let root_index = self.find(x.parent); - self.forest[x_index].parent = root_index; + let x_parent = self.parents[x_index]; + if x_parent != x_index { + let root_index = self.find(x_parent); + self.parents[x_index] = root_index; root_index } else { x_index @@ -65,25 +65,13 @@ impl PartitionWitness { return; } - let mut x = self.forest[x_index]; - let mut y = self.forest[y_index]; - - if x.size >= y.size { - y.parent = x_index; - x.size += y.size; - } else { - x.parent = y_index; - y.size += x.size; - } - - self.forest[x_index] = x; - self.forest[y_index] = y; + self.parents[y_index] = x_index; } /// Compress all paths. After calling this, every `parent` value will point to the node's /// representative. pub(crate) fn compress_paths(&mut self) { - for i in 0..self.forest.len() { + for i in 0..self.parents.len() { self.find(i); } } @@ -96,8 +84,8 @@ impl PartitionWitness { for input in 0..self.num_routed_wires { let w = Wire { gate, input }; let t = Target::Wire(w); - let x = self.forest[self.target_index(t)]; - partition.entry(x.parent).or_default().push(w); + let x_parent = self.parents[self.target_index(t)]; + partition.entry(x_parent).or_default().push(w); } } diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 9f4303ae..19852cff 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -36,7 +36,7 @@ pub(crate) fn prove, const D: usize>( let partition_witness = timed!( timing, &format!("run {} generators", prover_data.generators.len()), - generate_partial_witness(inputs, &prover_data) + generate_partial_witness(inputs, prover_data, common_data) ); let public_inputs = partition_witness.get_targets(&prover_data.public_inputs);