diff --git a/src/iop/witness.rs b/src/iop/witness.rs index 8a0f2652..3e884891 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -190,7 +190,7 @@ impl Witness for PartialWitness { /// The value of a target is defined to be the value of its root in the forest. #[derive(Clone)] pub struct PartitionWitness { - pub forest: Vec>, + pub forest: Vec>, pub num_wires: usize, pub num_routed_wires: usize, pub degree: usize, diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 415c011e..db9e855b 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -484,6 +484,8 @@ impl, const D: usize> CircuitBuilder { partition_witness.merge(a, b); } + partition_witness.compress_paths(); + let wire_partition = partition_witness.wire_partition(); ( wire_partition.get_sigma_polys(degree_log, k_is, subgroup), diff --git a/src/plonk/permutation_argument.rs b/src/plonk/permutation_argument.rs index 3d7d68b8..761c1bd8 100644 --- a/src/plonk/permutation_argument.rs +++ b/src/plonk/permutation_argument.rs @@ -11,11 +11,9 @@ use crate::polynomial::polynomial::PolynomialValues; /// Node in the Disjoint Set Forest. #[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct ForestNode { - pub t: T, +pub struct ForestNode { pub parent: usize, pub size: usize, - pub index: usize, pub value: Option, } @@ -40,69 +38,70 @@ impl PartitionWitness { let index = self.forest.len(); debug_assert_eq!(self.target_index(t), index); self.forest.push(ForestNode { - t, parent: index, size: 1, - index, value: None, }); } /// Path compression method, see https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Finding_set_representatives. - pub fn find(&mut self, x: ForestNode) -> ForestNode { - if x.parent != x.index { - let root = self.find(self.forest[x.parent]); - self.forest[x.index].parent = root.index; - root + pub fn find(&mut self, x_index: usize) -> usize { + let x = self.forest[x_index]; + if x.parent != x_index { + let root_index = self.find(x.parent); + self.forest[x_index].parent = root_index; + root_index } else { - x + x_index } } /// Merge two sets. pub fn merge(&mut self, tx: Target, ty: Target) { - let mut x = self.forest[self.target_index(tx)]; - let mut y = self.forest[self.target_index(ty)]; + let x_index = self.find(self.target_index(tx)); + let y_index = self.find(self.target_index(ty)); - x = self.find(x); - y = self.find(y); - - if x == y { + if x_index == y_index { return; } + let mut x = self.forest[x_index]; + let mut y = self.forest[y_index]; + if x.size >= y.size { - y.parent = x.index; + y.parent = x_index; x.size += y.size; } else { - x.parent = y.index; + x.parent = y_index; y.size += x.size; } - self.forest[x.index] = x; - self.forest[y.index] = y; + self.forest[x_index] = x; + self.forest[y_index] = y; + } + + /// Compress all paths. After calling this, every `parent` value will point to the node's + /// representative. + pub(crate) fn compress_paths(&mut self) { + for i in 0..self.forest.len() { + self.find(i); + } } pub fn wire_partition(&mut self) -> WirePartition { let mut partition = HashMap::<_, Vec<_>>::new(); + + // Here we keep just the Wire targets, filtering out everything else. for gate in 0..self.degree { for input in 0..self.num_routed_wires { let w = Wire { gate, input }; let t = Target::Wire(w); let x = self.forest[self.target_index(t)]; - partition.entry(self.find(x).t).or_default().push(w); + partition.entry(x.parent).or_default().push(w); } } - // I'm not 100% sure this loop is needed, but I'm afraid removing it might lead to subtle bugs. - for index in 0..self.forest.len() - self.degree * self.num_wires { - let t = Target::VirtualTarget { index }; - let x = self.forest[self.target_index(t)]; - self.find(x); - } - - // Here we keep just the Wire targets, filtering out everything else. - let partition = partition.into_values().collect::>(); + let partition = partition.into_values().collect(); WirePartition { partition } } }