From fc0f8a78ce928cebe287203ad48199b229e78001 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 2 Jul 2021 14:13:57 +0200 Subject: [PATCH 1/7] First try --- src/circuit_builder.rs | 14 +-- src/gates/gmimc.rs | 2 +- src/permutation_argument.rs | 198 +++++++++++++++++++++++++----------- src/plonk_challenger.rs | 2 +- 4 files changed, 149 insertions(+), 67 deletions(-) diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index fd786dfc..8caecc25 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -16,7 +16,7 @@ use crate::gates::gate_tree::Tree; use crate::gates::noop::NoopGate; use crate::generator::{CopyGenerator, RandomValueGenerator, WitnessGenerator}; use crate::hash::hash_n_to_hash; -use crate::permutation_argument::TargetPartitions; +use crate::permutation_argument::TargetPartition; use crate::plonk_common::PlonkPolynomials; use crate::polynomial::commitment::ListPolynomialCommitment; use crate::polynomial::polynomial::PolynomialValues; @@ -359,27 +359,27 @@ impl, const D: usize> CircuitBuilder { fn sigma_vecs(&self, k_is: &[F], subgroup: &[F]) -> Vec> { let degree = self.gate_instances.len(); let degree_log = log2_strict(degree); - let mut target_partitions = TargetPartitions::new(); + let mut target_partition = TargetPartition::default(); for gate in 0..degree { for input in 0..self.config.num_routed_wires { - target_partitions.add_partition(Target::Wire(Wire { gate, input })); + target_partition.add(Target::Wire(Wire { gate, input })); } } for index in 0..self.public_input_index { - target_partitions.add_partition(Target::PublicInput { index }); + target_partition.add(Target::PublicInput { index }); } for index in 0..self.virtual_target_index { - target_partitions.add_partition(Target::VirtualTarget { index }); + target_partition.add(Target::VirtualTarget { index }); } for &(a, b) in &self.copy_constraints { - target_partitions.merge(a, b); + target_partition.merge(a, b); } - let wire_partitions = target_partitions.to_wire_partitions(); + let wire_partitions = target_partition.wire_partitions(); wire_partitions.get_sigma_polys(degree_log, k_is, subgroup) } diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index bdfade7c..9f617043 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -323,7 +323,7 @@ mod tests { use crate::gates::gmimc::{GMiMCGate, W}; use crate::generator::generate_partial_witness; use crate::gmimc::gmimc_permute_naive; - use crate::permutation_argument::TargetPartitions; + use crate::permutation_argument::TargetPartition; use crate::target::Target; use crate::wire::Wire; use crate::witness::PartialWitness; diff --git a/src/permutation_argument.rs b/src/permutation_argument.rs index cc202a95..003400de 100644 --- a/src/permutation_argument.rs +++ b/src/permutation_argument.rs @@ -1,4 +1,6 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; +use std::fmt::Debug; +use std::hash::Hash; use rayon::prelude::*; @@ -7,85 +9,153 @@ use crate::polynomial::polynomial::PolynomialValues; use crate::target::Target; use crate::wire::Wire; -#[derive(Debug, Clone)] -pub struct TargetPartitions { - partitions: Vec>, - indices: HashMap, +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct ForestNode { + t: T, + parent: usize, + size: usize, + index: usize, } -impl Default for TargetPartitions { - fn default() -> Self { - TargetPartitions::new() - } -} - -impl TargetPartitions { - pub fn new() -> Self { +impl ForestNode { + pub fn new(t: T, parent: usize, size: usize, index: usize) -> Self { Self { - partitions: Vec::new(), - indices: HashMap::new(), + t, + parent, + size, + index, } } +} - pub fn get_partition(&self, target: Target) -> &[Target] { - &self.partitions[self.indices[&target]] +#[derive(Debug, Clone)] +pub struct TargetPartition { + forest: Vec>, + indices: HashMap, +} + +impl Default for TargetPartition { + fn default() -> Self { + Self { + forest: Vec::new(), + indices: Default::default(), + } + } +} + +impl TargetPartition { + pub fn get(&self, t: T) -> ForestNode { + self.forest[self.indices[&t]] } + pub fn get_mut(&mut self, t: T) -> &mut ForestNode { + &mut self.forest[self.indices[&t]] + } + // pub fn get_partition(&self, target: Target) -> &[Target] { + // &self.partitions[self.indices[&target]] + // } + /// Add a new partition with a single member. - pub fn add_partition(&mut self, target: Target) { - let index = self.partitions.len(); - self.partitions.push(vec![target]); - self.indices.insert(target, index); + pub fn add(&mut self, t: T) { + let index = self.forest.len(); + self.forest.push(ForestNode::new(t, index, 1, index)); + self.indices.insert(t, index); + } + + /// Path halving + pub fn find(&mut self, mut x: ForestNode) -> ForestNode { + while x.parent != x.index { + let grandparent = self.forest[x.parent].parent; + x.parent = grandparent; + x = self.forest[grandparent]; + } + x } /// Merge the two partitions containing the two given targets. Does nothing if the targets are /// already members of the same partition. - pub fn merge(&mut self, a: Target, b: Target) { - let a_index = self.indices[&a]; - let b_index = self.indices[&b]; - if a_index != b_index { - // Merge a's partition into b's partition, leaving a's partition empty. - // We have to clone because Rust's borrow checker doesn't know that - // self.partitions[b_index] and self.partitions[b_index] are disjoint. - let mut a_partition = self.partitions[a_index].clone(); - let b_partition = &mut self.partitions[b_index]; - for a_sibling in &a_partition { - *self.indices.get_mut(a_sibling).unwrap() = b_index; - } - b_partition.append(&mut a_partition); + pub fn merge(&mut self, tx: T, ty: T) { + let mut x = self.get(tx); + let mut y = self.get(ty); + let index_x = x.index; + let index_y = y.index; + + x = self.forest[x.parent]; + y = self.forest[y.parent]; + + if x == y { + return; } + + if x.size < y.size { + std::mem::swap(&mut x, &mut y); + } + + y.parent = x.index; + x.size += y.size; + + self.forest[index_x] = x; + self.forest[index_y] = y; } +} +impl TargetPartition { + pub fn wire_partitions(&mut self) -> WirePartitions { + let mut partition = HashMap::<_, Vec<_>>::new(); + let nodes = self.forest.clone(); + for x in nodes { + let v = partition.entry(self.find(x).t).or_default(); + v.push(x.t); + } - pub fn to_wire_partitions(&self) -> WirePartitions { - // Here we keep just the Wire targets, filtering out everything else. - let mut partitions = Vec::new(); let mut indices = HashMap::new(); + let partition = partition + .into_values() + .map(|v| { + v.into_iter() + .filter_map(|t| match t { + Target::Wire(w) => Some(w), + _ => None, + }) + .collect::>() + }) + .collect::>(); + partition.iter().enumerate().for_each(|(i, v)| { + v.iter().for_each(|t| { + indices.insert(*t, i); + }); + }); - for old_partition in &self.partitions { - let mut new_partition = Vec::new(); - for target in old_partition { - if let Target::Wire(w) = *target { - new_partition.push(w); - } - } - partitions.push(new_partition); - } + WirePartitions { partition, indices } - for (&target, &index) in &self.indices { - if let Target::Wire(gi) = target { - indices.insert(gi, index); - } - } - - WirePartitions { - partitions, - indices, - } + // // Here we keep just the Wire targets, filtering out everything else. + // let mut partitions = Vec::new(); + // let mut indices = HashMap::new(); + // + // for old_partition in &self.partitions { + // let mut new_partition = Vec::new(); + // for target in old_partition { + // if let Target::Wire(w) = *target { + // new_partition.push(w); + // } + // } + // partitions.push(new_partition); + // } + // + // for (&target, &index) in &self.indices { + // if let Target::Wire(gi) = target { + // indices.insert(gi, index); + // } + // } + // + // WirePartitions { + // partitions, + // indices, + // } } } pub struct WirePartitions { - partitions: Vec>, + partition: Vec>, indices: HashMap, } @@ -95,7 +165,7 @@ impl WirePartitions { /// its partition, this will loop around. If the given wire has a partition all to itself, it /// is considered its own neighbor. fn get_neighbor(&self, wire: Wire) -> Wire { - let partition = &self.partitions[self.indices[&wire]]; + let partition = &self.partition[self.indices[&wire]]; let n = partition.len(); for i in 0..n { if partition[i] == wire { @@ -144,3 +214,15 @@ impl WirePartitions { sigma } } + +#[test] +fn test_part() { + let mut part = TargetPartition::default(); + part.add(1); + part.add(2); + part.add(3); + + part.merge(1, 3); + + dbg!(part); +} diff --git a/src/plonk_challenger.rs b/src/plonk_challenger.rs index 9af5e590..f9025ee8 100644 --- a/src/plonk_challenger.rs +++ b/src/plonk_challenger.rs @@ -321,7 +321,7 @@ mod tests { use crate::field::crandall_field::CrandallField; use crate::field::field::Field; use crate::generator::generate_partial_witness; - use crate::permutation_argument::TargetPartitions; + use crate::permutation_argument::TargetPartition; use crate::plonk_challenger::{Challenger, RecursiveChallenger}; use crate::target::Target; use crate::witness::PartialWitness; From d93cf693ba2c04c2bbbaf4c00fe543ed45b78491 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 2 Jul 2021 14:26:49 +0200 Subject: [PATCH 2/7] Minor --- src/permutation_argument.rs | 75 ++++++------------------------------- 1 file changed, 11 insertions(+), 64 deletions(-) diff --git a/src/permutation_argument.rs b/src/permutation_argument.rs index 003400de..6a1838a6 100644 --- a/src/permutation_argument.rs +++ b/src/permutation_argument.rs @@ -17,17 +17,6 @@ pub struct ForestNode { index: usize, } -impl ForestNode { - pub fn new(t: T, parent: usize, size: usize, index: usize) -> Self { - Self { - t, - parent, - size, - index, - } - } -} - #[derive(Debug, Clone)] pub struct TargetPartition { forest: Vec>, @@ -44,21 +33,15 @@ impl Default for TargetPartition { } impl TargetPartition { - pub fn get(&self, t: T) -> ForestNode { - self.forest[self.indices[&t]] - } - - pub fn get_mut(&mut self, t: T) -> &mut ForestNode { - &mut self.forest[self.indices[&t]] - } - // pub fn get_partition(&self, target: Target) -> &[Target] { - // &self.partitions[self.indices[&target]] - // } - /// Add a new partition with a single member. pub fn add(&mut self, t: T) { let index = self.forest.len(); - self.forest.push(ForestNode::new(t, index, 1, index)); + self.forest.push(ForestNode { + t, + parent: index, + size: 1, + index, + }); self.indices.insert(t, index); } @@ -75,10 +58,10 @@ impl TargetPartition { /// Merge the two partitions containing the two given targets. Does nothing if the targets are /// already members of the same partition. pub fn merge(&mut self, tx: T, ty: T) { - let mut x = self.get(tx); - let mut y = self.get(ty); - let index_x = x.index; - let index_y = y.index; + let index_x = self.indices[&tx]; + let index_y = self.indices[&ty]; + let mut x = self.forest[index_x]; + let mut y = self.forest[index_y]; x = self.forest[x.parent]; y = self.forest[y.parent]; @@ -108,6 +91,7 @@ impl TargetPartition { } let mut indices = HashMap::new(); + // // Here we keep just the Wire targets, filtering out everything else. let partition = partition .into_values() .map(|v| { @@ -126,31 +110,6 @@ impl TargetPartition { }); WirePartitions { partition, indices } - - // // Here we keep just the Wire targets, filtering out everything else. - // let mut partitions = Vec::new(); - // let mut indices = HashMap::new(); - // - // for old_partition in &self.partitions { - // let mut new_partition = Vec::new(); - // for target in old_partition { - // if let Target::Wire(w) = *target { - // new_partition.push(w); - // } - // } - // partitions.push(new_partition); - // } - // - // for (&target, &index) in &self.indices { - // if let Target::Wire(gi) = target { - // indices.insert(gi, index); - // } - // } - // - // WirePartitions { - // partitions, - // indices, - // } } } @@ -214,15 +173,3 @@ impl WirePartitions { sigma } } - -#[test] -fn test_part() { - let mut part = TargetPartition::default(); - part.add(1); - part.add(2); - part.add(3); - - part.merge(1, 3); - - dbg!(part); -} From b6554ba2ecb4ec7af1fc15c17cb299db061e8c3b Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 2 Jul 2021 14:37:07 +0200 Subject: [PATCH 3/7] Replace `indices: HashMap` with `indices: Fn(T)->usize` --- src/circuit_builder.rs | 8 +++++++- src/permutation_argument.rs | 20 ++++++++------------ 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index 8caecc25..4990cffe 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -359,7 +359,13 @@ impl, const D: usize> CircuitBuilder { fn sigma_vecs(&self, k_is: &[F], subgroup: &[F]) -> Vec> { let degree = self.gate_instances.len(); let degree_log = log2_strict(degree); - let mut target_partition = TargetPartition::default(); + let mut target_partition = TargetPartition::new(|t| match t { + Target::Wire(Wire { gate, input }) => gate * self.config.num_routed_wires + input, + Target::PublicInput { index } => degree * self.config.num_routed_wires + index, + Target::VirtualTarget { index } => { + degree * self.config.num_routed_wires + self.public_input_index + index + } + }); for gate in 0..degree { for input in 0..self.config.num_routed_wires { diff --git a/src/permutation_argument.rs b/src/permutation_argument.rs index 6a1838a6..f7ee9a5e 100644 --- a/src/permutation_argument.rs +++ b/src/permutation_argument.rs @@ -18,21 +18,18 @@ pub struct ForestNode { } #[derive(Debug, Clone)] -pub struct TargetPartition { +pub struct TargetPartition usize> { forest: Vec>, - indices: HashMap, + indices: F, } -impl Default for TargetPartition { - fn default() -> Self { +impl usize> TargetPartition { + pub fn new(f: F) -> Self { Self { forest: Vec::new(), - indices: Default::default(), + indices: f, } } -} - -impl TargetPartition { /// Add a new partition with a single member. pub fn add(&mut self, t: T) { let index = self.forest.len(); @@ -42,7 +39,6 @@ impl TargetPartition { size: 1, index, }); - self.indices.insert(t, index); } /// Path halving @@ -58,8 +54,8 @@ impl TargetPartition { /// Merge the two partitions containing the two given targets. Does nothing if the targets are /// already members of the same partition. pub fn merge(&mut self, tx: T, ty: T) { - let index_x = self.indices[&tx]; - let index_y = self.indices[&ty]; + let index_x = (self.indices)(tx); + let index_y = (self.indices)(ty); let mut x = self.forest[index_x]; let mut y = self.forest[index_y]; @@ -81,7 +77,7 @@ impl TargetPartition { self.forest[index_y] = y; } } -impl TargetPartition { +impl usize> TargetPartition { pub fn wire_partitions(&mut self) -> WirePartitions { let mut partition = HashMap::<_, Vec<_>>::new(); let nodes = self.forest.clone(); From 13f470e47d40a1342092be870cf1039ec23aa35c Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 2 Jul 2021 14:42:40 +0200 Subject: [PATCH 4/7] Comments --- src/permutation_argument.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/permutation_argument.rs b/src/permutation_argument.rs index f7ee9a5e..12433c8a 100644 --- a/src/permutation_argument.rs +++ b/src/permutation_argument.rs @@ -9,6 +9,7 @@ use crate::polynomial::polynomial::PolynomialValues; use crate::target::Target; use crate::wire::Wire; +/// Node in the Disjoint Set Forest. #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct ForestNode { t: T, @@ -17,9 +18,11 @@ pub struct ForestNode { index: usize, } +/// Disjoint Set Forest data-structure following https://en.wikipedia.org/wiki/Disjoint-set_data_structure. #[derive(Debug, Clone)] pub struct TargetPartition usize> { forest: Vec>, + /// Function to compute a node's index in the forest. indices: F, } @@ -33,6 +36,7 @@ impl usize> TargetPartition /// Add a new partition with a single member. pub fn add(&mut self, t: T) { let index = self.forest.len(); + debug_assert_eq!((self.indices)(t), index); self.forest.push(ForestNode { t, parent: index, @@ -41,7 +45,7 @@ impl usize> TargetPartition }); } - /// Path halving + /// Path halving method, see https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Finding_set_representatives. pub fn find(&mut self, mut x: ForestNode) -> ForestNode { while x.parent != x.index { let grandparent = self.forest[x.parent].parent; @@ -51,8 +55,7 @@ impl usize> TargetPartition x } - /// Merge the two partitions containing the two given targets. Does nothing if the targets are - /// already members of the same partition. + /// Merge two sets. pub fn merge(&mut self, tx: T, ty: T) { let index_x = (self.indices)(tx); let index_y = (self.indices)(ty); From 73c1733e6aed3125ff2c49c9125b7cb994a5285f Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 2 Jul 2021 14:45:05 +0200 Subject: [PATCH 5/7] Clippy --- src/circuit_builder.rs | 4 ++-- src/permutation_argument.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index 4990cffe..a83a4977 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -385,8 +385,8 @@ impl, const D: usize> CircuitBuilder { target_partition.merge(a, b); } - let wire_partitions = target_partition.wire_partitions(); - wire_partitions.get_sigma_polys(degree_log, k_is, subgroup) + let wire_partition = target_partition.wire_partition(); + wire_partition.get_sigma_polys(degree_log, k_is, subgroup) } /// Builds a "full circuit", with both prover and verifier data. diff --git a/src/permutation_argument.rs b/src/permutation_argument.rs index 12433c8a..e6fc07cf 100644 --- a/src/permutation_argument.rs +++ b/src/permutation_argument.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::fmt::Debug; use std::hash::Hash; @@ -81,7 +81,7 @@ impl usize> TargetPartition } } impl usize> TargetPartition { - pub fn wire_partitions(&mut self) -> WirePartitions { + pub fn wire_partition(&mut self) -> WirePartitions { let mut partition = HashMap::<_, Vec<_>>::new(); let nodes = self.forest.clone(); for x in nodes { From b7561c31a2b0540e22807c92a28c798123e865fd Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 2 Jul 2021 15:34:23 +0200 Subject: [PATCH 6/7] Fix bugs --- src/permutation_argument.rs | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/src/permutation_argument.rs b/src/permutation_argument.rs index e6fc07cf..a8b47819 100644 --- a/src/permutation_argument.rs +++ b/src/permutation_argument.rs @@ -49,7 +49,7 @@ impl usize> TargetPartition pub fn find(&mut self, mut x: ForestNode) -> ForestNode { while x.parent != x.index { let grandparent = self.forest[x.parent].parent; - x.parent = grandparent; + self.forest[x.index].parent = grandparent; x = self.forest[grandparent]; } x @@ -57,27 +57,26 @@ impl usize> TargetPartition /// Merge two sets. pub fn merge(&mut self, tx: T, ty: T) { - let index_x = (self.indices)(tx); - let index_y = (self.indices)(ty); - let mut x = self.forest[index_x]; - let mut y = self.forest[index_y]; + let mut x = self.forest[(self.indices)(tx)]; + let mut y = self.forest[(self.indices)(ty)]; - x = self.forest[x.parent]; - y = self.forest[y.parent]; + x = self.find(x); + y = self.find(y); if x == y { return; } - if x.size < y.size { - std::mem::swap(&mut x, &mut y); + if x.size >= y.size { + y.parent = x.index; + x.size += y.size; + } else { + x.parent = y.index; + y.size += x.size; } - y.parent = x.index; - x.size += y.size; - - self.forest[index_x] = x; - self.forest[index_y] = y; + self.forest[x.index] = x; + self.forest[y.index] = y; } } impl usize> TargetPartition { From 083d84139709101ba65840bdaec3e7e14e4f85ea Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 2 Jul 2021 15:44:50 +0200 Subject: [PATCH 7/7] Path halving -> Path compression --- src/permutation_argument.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/permutation_argument.rs b/src/permutation_argument.rs index a8b47819..37ecf8d2 100644 --- a/src/permutation_argument.rs +++ b/src/permutation_argument.rs @@ -45,14 +45,15 @@ impl usize> TargetPartition }); } - /// Path halving method, see https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Finding_set_representatives. + /// Path compression method, see https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Finding_set_representatives. pub fn find(&mut self, mut x: ForestNode) -> ForestNode { - while x.parent != x.index { - let grandparent = self.forest[x.parent].parent; - self.forest[x.index].parent = grandparent; - x = self.forest[grandparent]; + if x.parent != x.index { + let root = self.find(self.forest[x.parent]); + self.forest[x.index].parent = root.index; + root + } else { + x } - x } /// Merge two sets.