diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index a28d5121..93b440ab 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -16,7 +16,7 @@ use crate::gates::gate_tree::Tree; use crate::gates::noop::NoopGate; use crate::generator::{CopyGenerator, RandomValueGenerator, WitnessGenerator}; use crate::hash::hash_n_to_hash; -use crate::permutation_argument::TargetPartitions; +use crate::permutation_argument::TargetPartition; use crate::plonk_common::PlonkPolynomials; use crate::polynomial::commitment::ListPolynomialCommitment; use crate::polynomial::polynomial::PolynomialValues; @@ -360,28 +360,34 @@ impl, const D: usize> CircuitBuilder { fn sigma_vecs(&self, k_is: &[F], subgroup: &[F]) -> Vec> { let degree = self.gate_instances.len(); let degree_log = log2_strict(degree); - let mut target_partitions = TargetPartitions::new(); + let mut target_partition = TargetPartition::new(|t| match t { + Target::Wire(Wire { gate, input }) => gate * self.config.num_routed_wires + input, + Target::PublicInput { index } => degree * self.config.num_routed_wires + index, + Target::VirtualTarget { index } => { + degree * self.config.num_routed_wires + self.public_input_index + index + } + }); for gate in 0..degree { for input in 0..self.config.num_routed_wires { - target_partitions.add_partition(Target::Wire(Wire { gate, input })); + target_partition.add(Target::Wire(Wire { gate, input })); } } for index in 0..self.public_input_index { - target_partitions.add_partition(Target::PublicInput { index }); + target_partition.add(Target::PublicInput { index }); } for index in 0..self.virtual_target_index { - target_partitions.add_partition(Target::VirtualTarget { index }); + target_partition.add(Target::VirtualTarget { index }); } for &(a, b) in &self.copy_constraints { - target_partitions.merge(a, b); + target_partition.merge(a, b); } - let wire_partitions = target_partitions.to_wire_partitions(); - wire_partitions.get_sigma_polys(degree_log, k_is, subgroup) + let wire_partition = target_partition.wire_partition(); + wire_partition.get_sigma_polys(degree_log, k_is, subgroup) } /// Builds a "full circuit", with both prover and verifier data. diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index bdfade7c..9f617043 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -323,7 +323,7 @@ mod tests { use crate::gates::gmimc::{GMiMCGate, W}; use crate::generator::generate_partial_witness; use crate::gmimc::gmimc_permute_naive; - use crate::permutation_argument::TargetPartitions; + use crate::permutation_argument::TargetPartition; use crate::target::Target; use crate::wire::Wire; use crate::witness::PartialWitness; diff --git a/src/permutation_argument.rs b/src/permutation_argument.rs index cc202a95..37ecf8d2 100644 --- a/src/permutation_argument.rs +++ b/src/permutation_argument.rs @@ -1,4 +1,6 @@ use std::collections::HashMap; +use std::fmt::Debug; +use std::hash::Hash; use rayon::prelude::*; @@ -7,85 +9,111 @@ use crate::polynomial::polynomial::PolynomialValues; use crate::target::Target; use crate::wire::Wire; +/// Node in the Disjoint Set Forest. +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct ForestNode { + t: T, + parent: usize, + size: usize, + index: usize, +} + +/// Disjoint Set Forest data-structure following https://en.wikipedia.org/wiki/Disjoint-set_data_structure. #[derive(Debug, Clone)] -pub struct TargetPartitions { - partitions: Vec>, - indices: HashMap, +pub struct TargetPartition usize> { + forest: Vec>, + /// Function to compute a node's index in the forest. + indices: F, } -impl Default for TargetPartitions { - fn default() -> Self { - TargetPartitions::new() - } -} - -impl TargetPartitions { - pub fn new() -> Self { +impl usize> TargetPartition { + pub fn new(f: F) -> Self { Self { - partitions: Vec::new(), - indices: HashMap::new(), + forest: Vec::new(), + indices: f, } } - - pub fn get_partition(&self, target: Target) -> &[Target] { - &self.partitions[self.indices[&target]] - } - /// Add a new partition with a single member. - pub fn add_partition(&mut self, target: Target) { - let index = self.partitions.len(); - self.partitions.push(vec![target]); - self.indices.insert(target, index); + pub fn add(&mut self, t: T) { + let index = self.forest.len(); + debug_assert_eq!((self.indices)(t), index); + self.forest.push(ForestNode { + t, + parent: index, + size: 1, + index, + }); } - /// Merge the two partitions containing the two given targets. Does nothing if the targets are - /// already members of the same partition. - pub fn merge(&mut self, a: Target, b: Target) { - let a_index = self.indices[&a]; - let b_index = self.indices[&b]; - if a_index != b_index { - // Merge a's partition into b's partition, leaving a's partition empty. - // We have to clone because Rust's borrow checker doesn't know that - // self.partitions[b_index] and self.partitions[b_index] are disjoint. - let mut a_partition = self.partitions[a_index].clone(); - let b_partition = &mut self.partitions[b_index]; - for a_sibling in &a_partition { - *self.indices.get_mut(a_sibling).unwrap() = b_index; - } - b_partition.append(&mut a_partition); + /// Path compression method, see https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Finding_set_representatives. + pub fn find(&mut self, mut x: ForestNode) -> ForestNode { + if x.parent != x.index { + let root = self.find(self.forest[x.parent]); + self.forest[x.index].parent = root.index; + root + } else { + x } } - pub fn to_wire_partitions(&self) -> WirePartitions { - // Here we keep just the Wire targets, filtering out everything else. - let mut partitions = Vec::new(); + /// Merge two sets. + pub fn merge(&mut self, tx: T, ty: T) { + let mut x = self.forest[(self.indices)(tx)]; + let mut y = self.forest[(self.indices)(ty)]; + + x = self.find(x); + y = self.find(y); + + if x == y { + return; + } + + if x.size >= y.size { + y.parent = x.index; + x.size += y.size; + } else { + x.parent = y.index; + y.size += x.size; + } + + self.forest[x.index] = x; + self.forest[y.index] = y; + } +} +impl usize> TargetPartition { + pub fn wire_partition(&mut self) -> WirePartitions { + let mut partition = HashMap::<_, Vec<_>>::new(); + let nodes = self.forest.clone(); + for x in nodes { + let v = partition.entry(self.find(x).t).or_default(); + v.push(x.t); + } + let mut indices = HashMap::new(); + // // Here we keep just the Wire targets, filtering out everything else. + let partition = partition + .into_values() + .map(|v| { + v.into_iter() + .filter_map(|t| match t { + Target::Wire(w) => Some(w), + _ => None, + }) + .collect::>() + }) + .collect::>(); + partition.iter().enumerate().for_each(|(i, v)| { + v.iter().for_each(|t| { + indices.insert(*t, i); + }); + }); - for old_partition in &self.partitions { - let mut new_partition = Vec::new(); - for target in old_partition { - if let Target::Wire(w) = *target { - new_partition.push(w); - } - } - partitions.push(new_partition); - } - - for (&target, &index) in &self.indices { - if let Target::Wire(gi) = target { - indices.insert(gi, index); - } - } - - WirePartitions { - partitions, - indices, - } + WirePartitions { partition, indices } } } pub struct WirePartitions { - partitions: Vec>, + partition: Vec>, indices: HashMap, } @@ -95,7 +123,7 @@ impl WirePartitions { /// its partition, this will loop around. If the given wire has a partition all to itself, it /// is considered its own neighbor. fn get_neighbor(&self, wire: Wire) -> Wire { - let partition = &self.partitions[self.indices[&wire]]; + let partition = &self.partition[self.indices[&wire]]; let n = partition.len(); for i in 0..n { if partition[i] == wire { diff --git a/src/plonk_challenger.rs b/src/plonk_challenger.rs index 287eaa07..149dceb3 100644 --- a/src/plonk_challenger.rs +++ b/src/plonk_challenger.rs @@ -323,7 +323,7 @@ mod tests { use crate::field::crandall_field::CrandallField; use crate::field::field::Field; use crate::generator::generate_partial_witness; - use crate::permutation_argument::TargetPartitions; + use crate::permutation_argument::TargetPartition; use crate::plonk_challenger::{Challenger, RecursiveChallenger}; use crate::target::Target; use crate::witness::PartialWitness;