From fc0f8a78ce928cebe287203ad48199b229e78001 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 2 Jul 2021 14:13:57 +0200 Subject: [PATCH] First try --- src/circuit_builder.rs | 14 +-- src/gates/gmimc.rs | 2 +- src/permutation_argument.rs | 198 +++++++++++++++++++++++++----------- src/plonk_challenger.rs | 2 +- 4 files changed, 149 insertions(+), 67 deletions(-) diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index fd786dfc..8caecc25 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -16,7 +16,7 @@ use crate::gates::gate_tree::Tree; use crate::gates::noop::NoopGate; use crate::generator::{CopyGenerator, RandomValueGenerator, WitnessGenerator}; use crate::hash::hash_n_to_hash; -use crate::permutation_argument::TargetPartitions; +use crate::permutation_argument::TargetPartition; use crate::plonk_common::PlonkPolynomials; use crate::polynomial::commitment::ListPolynomialCommitment; use crate::polynomial::polynomial::PolynomialValues; @@ -359,27 +359,27 @@ impl, const D: usize> CircuitBuilder { fn sigma_vecs(&self, k_is: &[F], subgroup: &[F]) -> Vec> { let degree = self.gate_instances.len(); let degree_log = log2_strict(degree); - let mut target_partitions = TargetPartitions::new(); + let mut target_partition = TargetPartition::default(); for gate in 0..degree { for input in 0..self.config.num_routed_wires { - target_partitions.add_partition(Target::Wire(Wire { gate, input })); + target_partition.add(Target::Wire(Wire { gate, input })); } } for index in 0..self.public_input_index { - target_partitions.add_partition(Target::PublicInput { index }); + target_partition.add(Target::PublicInput { index }); } for index in 0..self.virtual_target_index { - target_partitions.add_partition(Target::VirtualTarget { index }); + target_partition.add(Target::VirtualTarget { index }); } for &(a, b) in &self.copy_constraints { - target_partitions.merge(a, b); + target_partition.merge(a, b); } - let wire_partitions = target_partitions.to_wire_partitions(); + let wire_partitions = target_partition.wire_partitions(); wire_partitions.get_sigma_polys(degree_log, k_is, subgroup) } diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index bdfade7c..9f617043 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -323,7 +323,7 @@ mod tests { use crate::gates::gmimc::{GMiMCGate, W}; use crate::generator::generate_partial_witness; use crate::gmimc::gmimc_permute_naive; - use crate::permutation_argument::TargetPartitions; + use crate::permutation_argument::TargetPartition; use crate::target::Target; use crate::wire::Wire; use crate::witness::PartialWitness; diff --git a/src/permutation_argument.rs b/src/permutation_argument.rs index cc202a95..003400de 100644 --- a/src/permutation_argument.rs +++ b/src/permutation_argument.rs @@ -1,4 +1,6 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; +use std::fmt::Debug; +use std::hash::Hash; use rayon::prelude::*; @@ -7,85 +9,153 @@ use crate::polynomial::polynomial::PolynomialValues; use crate::target::Target; use crate::wire::Wire; -#[derive(Debug, Clone)] -pub struct TargetPartitions { - partitions: Vec>, - indices: HashMap, +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct ForestNode { + t: T, + parent: usize, + size: usize, + index: usize, } -impl Default for TargetPartitions { - fn default() -> Self { - TargetPartitions::new() - } -} - -impl TargetPartitions { - pub fn new() -> Self { +impl ForestNode { + pub fn new(t: T, parent: usize, size: usize, index: usize) -> Self { Self { - partitions: Vec::new(), - indices: HashMap::new(), + t, + parent, + size, + index, } } +} - pub fn get_partition(&self, target: Target) -> &[Target] { - &self.partitions[self.indices[&target]] +#[derive(Debug, Clone)] +pub struct TargetPartition { + forest: Vec>, + indices: HashMap, +} + +impl Default for TargetPartition { + fn default() -> Self { + Self { + forest: Vec::new(), + indices: Default::default(), + } + } +} + +impl TargetPartition { + pub fn get(&self, t: T) -> ForestNode { + self.forest[self.indices[&t]] } + pub fn get_mut(&mut self, t: T) -> &mut ForestNode { + &mut self.forest[self.indices[&t]] + } + // pub fn get_partition(&self, target: Target) -> &[Target] { + // &self.partitions[self.indices[&target]] + // } + /// Add a new partition with a single member. - pub fn add_partition(&mut self, target: Target) { - let index = self.partitions.len(); - self.partitions.push(vec![target]); - self.indices.insert(target, index); + pub fn add(&mut self, t: T) { + let index = self.forest.len(); + self.forest.push(ForestNode::new(t, index, 1, index)); + self.indices.insert(t, index); + } + + /// Path halving + pub fn find(&mut self, mut x: ForestNode) -> ForestNode { + while x.parent != x.index { + let grandparent = self.forest[x.parent].parent; + x.parent = grandparent; + x = self.forest[grandparent]; + } + x } /// Merge the two partitions containing the two given targets. Does nothing if the targets are /// already members of the same partition. - pub fn merge(&mut self, a: Target, b: Target) { - let a_index = self.indices[&a]; - let b_index = self.indices[&b]; - if a_index != b_index { - // Merge a's partition into b's partition, leaving a's partition empty. - // We have to clone because Rust's borrow checker doesn't know that - // self.partitions[b_index] and self.partitions[b_index] are disjoint. - let mut a_partition = self.partitions[a_index].clone(); - let b_partition = &mut self.partitions[b_index]; - for a_sibling in &a_partition { - *self.indices.get_mut(a_sibling).unwrap() = b_index; - } - b_partition.append(&mut a_partition); + pub fn merge(&mut self, tx: T, ty: T) { + let mut x = self.get(tx); + let mut y = self.get(ty); + let index_x = x.index; + let index_y = y.index; + + x = self.forest[x.parent]; + y = self.forest[y.parent]; + + if x == y { + return; } + + if x.size < y.size { + std::mem::swap(&mut x, &mut y); + } + + y.parent = x.index; + x.size += y.size; + + self.forest[index_x] = x; + self.forest[index_y] = y; } +} +impl TargetPartition { + pub fn wire_partitions(&mut self) -> WirePartitions { + let mut partition = HashMap::<_, Vec<_>>::new(); + let nodes = self.forest.clone(); + for x in nodes { + let v = partition.entry(self.find(x).t).or_default(); + v.push(x.t); + } - pub fn to_wire_partitions(&self) -> WirePartitions { - // Here we keep just the Wire targets, filtering out everything else. - let mut partitions = Vec::new(); let mut indices = HashMap::new(); + let partition = partition + .into_values() + .map(|v| { + v.into_iter() + .filter_map(|t| match t { + Target::Wire(w) => Some(w), + _ => None, + }) + .collect::>() + }) + .collect::>(); + partition.iter().enumerate().for_each(|(i, v)| { + v.iter().for_each(|t| { + indices.insert(*t, i); + }); + }); - for old_partition in &self.partitions { - let mut new_partition = Vec::new(); - for target in old_partition { - if let Target::Wire(w) = *target { - new_partition.push(w); - } - } - partitions.push(new_partition); - } + WirePartitions { partition, indices } - for (&target, &index) in &self.indices { - if let Target::Wire(gi) = target { - indices.insert(gi, index); - } - } - - WirePartitions { - partitions, - indices, - } + // // Here we keep just the Wire targets, filtering out everything else. + // let mut partitions = Vec::new(); + // let mut indices = HashMap::new(); + // + // for old_partition in &self.partitions { + // let mut new_partition = Vec::new(); + // for target in old_partition { + // if let Target::Wire(w) = *target { + // new_partition.push(w); + // } + // } + // partitions.push(new_partition); + // } + // + // for (&target, &index) in &self.indices { + // if let Target::Wire(gi) = target { + // indices.insert(gi, index); + // } + // } + // + // WirePartitions { + // partitions, + // indices, + // } } } pub struct WirePartitions { - partitions: Vec>, + partition: Vec>, indices: HashMap, } @@ -95,7 +165,7 @@ impl WirePartitions { /// its partition, this will loop around. If the given wire has a partition all to itself, it /// is considered its own neighbor. fn get_neighbor(&self, wire: Wire) -> Wire { - let partition = &self.partitions[self.indices[&wire]]; + let partition = &self.partition[self.indices[&wire]]; let n = partition.len(); for i in 0..n { if partition[i] == wire { @@ -144,3 +214,15 @@ impl WirePartitions { sigma } } + +#[test] +fn test_part() { + let mut part = TargetPartition::default(); + part.add(1); + part.add(2); + part.add(3); + + part.merge(1, 3); + + dbg!(part); +} diff --git a/src/plonk_challenger.rs b/src/plonk_challenger.rs index 9af5e590..f9025ee8 100644 --- a/src/plonk_challenger.rs +++ b/src/plonk_challenger.rs @@ -321,7 +321,7 @@ mod tests { use crate::field::crandall_field::CrandallField; use crate::field::field::Field; use crate::generator::generate_partial_witness; - use crate::permutation_argument::TargetPartitions; + use crate::permutation_argument::TargetPartition; use crate::plonk_challenger::{Challenger, RecursiveChallenger}; use crate::target::Target; use crate::witness::PartialWitness;