diff --git a/src/iop/generator.rs b/src/iop/generator.rs index b6d89c5d..4aa5d6e1 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -30,13 +30,13 @@ pub(crate) fn generate_partial_witness( // Target::VirtualTarget { index } => degree * num_wires + index, // } // }; - let max_target_index = witness.0.len(); + let max_target_index = witness.nodes.len(); // Index generator indices by their watched targets. let mut generator_indices_by_watches = vec![Vec::new(); max_target_index]; timed!(timing, "index generators by their watched targets", { for (i, generator) in generators.iter().enumerate() { for watch in generator.watch_list() { - generator_indices_by_watches[witness.1(watch)].push(i); + generator_indices_by_watches[witness.target_index(watch)].push(i); } } }); @@ -71,7 +71,9 @@ pub(crate) fn generate_partial_witness( // Enqueue unfinished generators that were watching one of the newly populated targets. for &(watch, _) in &buffer.target_values { - for &watching_generator_idx in &generator_indices_by_watches[witness.1(watch)] { + for &watching_generator_idx in + &generator_indices_by_watches[witness.target_index(watch)] + { if !generator_is_expired[watching_generator_idx] { next_pending_generator_indices.push(watching_generator_idx); } diff --git a/src/iop/witness.rs b/src/iop/witness.rs index e65f8c40..8241c7d5 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -217,32 +217,50 @@ impl Witness for PartialWitness { } } -pub struct PartitionWitness( - pub Vec>, - pub Box usize>, -); +#[derive(Clone)] +pub struct PartitionWitness { + pub nodes: Vec>, + pub num_wires: usize, + pub num_routed_wires: usize, + pub degree: usize, +} impl Witness for PartitionWitness { fn try_get_target(&self, target: Target) -> Option { - self.0[self.0[self.1(target)].parent].value + self.nodes[self.nodes[self.target_index(target)].parent].value } fn set_target(&mut self, target: Target, value: F) { - let i = self.0[self.1(target)].parent; - self.0[i].value = Some(value); + let i = self.nodes[self.target_index(target)].parent; + self.nodes[i].value = Some(value); } } impl PartitionWitness { - pub fn full_witness(self, degree: usize, num_wires: usize) -> MatrixWitness { - let mut wire_values = vec![vec![F::ZERO; degree]; num_wires]; - // assert!(self.wire_values.len() <= degree); - for i in 0..degree { - for j in 0..num_wires { - let t = Target::Wire(Wire { gate: i, input: j }); - wire_values[j][i] = self.0[self.0[self.1(t)].parent].value.unwrap_or(F::ZERO); + pub const fn target_index(&self, target: Target) -> usize { + match target { + Target::Wire(Wire { gate, input }) => gate * self.num_wires + input, + Target::VirtualTarget { index } => self.degree * self.num_wires + index, + } + } + + pub fn full_witness(self) -> MatrixWitness { + let mut wire_values = vec![vec![]; self.num_wires]; + for j in 0..self.num_wires { + wire_values[j].reserve_exact(self.degree); + unsafe { + // After .reserve_exact(l), wire_values[i] will have capacity at least l. Hence, set_len + // will not cause the buffer to overrun. + wire_values[j].set_len(self.degree); } } + for i in 0..self.degree { + for j in 0..self.num_wires { + let t = Target::Wire(Wire { gate: i, input: j }); + wire_values[j][i] = self.try_get_target(t).unwrap_or(F::ZERO); + } + } + MatrixWitness { wire_values } } } diff --git a/src/lib.rs b/src/lib.rs index 1249e9db..56fd179f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ #![feature(destructuring_assignment)] +#![feature(const_fn_trait_bound)] pub mod field; pub mod fri; diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 1fb70bad..b9a6ec6e 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -19,12 +19,13 @@ use crate::hash::hashing::hash_n_to_hash; use crate::iop::generator::{CopyGenerator, RandomValueGenerator, WitnessGenerator}; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; +use crate::iop::witness::{PartialWitness, PartitionWitness}; use crate::plonk::circuit_data::{ CircuitConfig, CircuitData, CommonCircuitData, ProverCircuitData, ProverOnlyCircuitData, VerifierCircuitData, VerifierOnlyCircuitData, }; use crate::plonk::copy_constraint::CopyConstraint; -use crate::plonk::permutation_argument::{ForestNode, TargetPartition}; +use crate::plonk::permutation_argument::ForestNode; use crate::plonk::plonk_common::PlonkPolynomials; use crate::polynomial::polynomial::PolynomialValues; use crate::util::context_tree::ContextTree; @@ -510,16 +511,14 @@ impl, const D: usize> CircuitBuilder { &self, k_is: &[F], subgroup: &[F], - ) -> (Vec>, Vec>) { + ) -> (Vec>, PartitionWitness) { let degree = self.gate_instances.len(); let degree_log = log2_strict(degree); - let mut target_partition = TargetPartition::new(|t| match t { - Target::Wire(Wire { gate, input }) => gate * self.config.num_routed_wires + input, - Target::VirtualTarget { index } => degree * self.config.num_routed_wires + index, - }); + let mut target_partition = + PartitionWitness::new(self.config.num_wires, self.config.num_routed_wires, degree); for gate in 0..degree { - for input in 0..self.config.num_routed_wires { + for input in 0..self.config.num_wires { target_partition.add(Target::Wire(Wire { gate, input })); } } diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index f3d13fb4..379e0d61 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -11,7 +11,7 @@ use crate::hash::hash_types::{HashOut, MerkleCapTarget}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::generator::WitnessGenerator; use crate::iop::target::Target; -use crate::iop::witness::PartialWitness; +use crate::iop::witness::{PartialWitness, PartitionWitness}; use crate::plonk::copy_constraint::CopyConstraint; use crate::plonk::permutation_argument::ForestNode; use crate::plonk::proof::ProofWithPublicInputs; @@ -157,7 +157,7 @@ pub(crate) struct ProverOnlyCircuitData, const D: usize> { /// Number of virtual targets used in the circuit. pub num_virtual_targets: usize, - pub partition: Vec>, + pub partition: PartitionWitness, } /// Circuit data required by the verifier, but not the prover. diff --git a/src/plonk/permutation_argument.rs b/src/plonk/permutation_argument.rs index 3a9a93aa..ec8f3629 100644 --- a/src/plonk/permutation_argument.rs +++ b/src/plonk/permutation_argument.rs @@ -7,6 +7,7 @@ use rayon::prelude::*; use crate::field::field_types::Field; use crate::iop::target::Target; use crate::iop::wire::Wire; +use crate::iop::witness::PartitionWitness; use crate::polynomial::polynomial::PolynomialValues; /// Node in the Disjoint Set Forest. @@ -20,27 +21,21 @@ pub struct ForestNode { } /// Disjoint Set Forest data-structure following https://en.wikipedia.org/wiki/Disjoint-set_data_structure. -#[derive(Debug, Clone)] -pub struct TargetPartition usize> { - forest: Vec>, - /// Function to compute a node's index in the forest. - indices: F, -} - -impl usize> - TargetPartition -{ - pub fn new(f: F) -> Self { +impl PartitionWitness { + pub fn new(num_wires: usize, num_routed_wires: usize, degree: usize) -> Self { Self { - forest: Vec::new(), - indices: f, + nodes: vec![], + num_wires, + num_routed_wires, + degree, } } + /// Add a new partition with a single member. - pub fn add(&mut self, t: T) { - let index = self.forest.len(); - debug_assert_eq!((self.indices)(t), index); - self.forest.push(ForestNode { + pub fn add(&mut self, t: Target) { + let index = self.nodes.len(); + debug_assert_eq!(self.target_index(t), index); + self.nodes.push(ForestNode { t, parent: index, size: 1, @@ -50,10 +45,10 @@ impl usize> } /// Path compression method, see https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Finding_set_representatives. - pub fn find(&mut self, x: ForestNode) -> ForestNode { + pub fn find(&mut self, x: ForestNode) -> ForestNode { if x.parent != x.index { - let root = self.find(self.forest[x.parent]); - self.forest[x.index].parent = root.index; + let root = self.find(self.nodes[x.parent]); + self.nodes[x.index].parent = root.index; root } else { x @@ -61,9 +56,9 @@ impl usize> } /// Merge two sets. - pub fn merge(&mut self, tx: T, ty: T) { - let mut x = self.forest[(self.indices)(tx)]; - let mut y = self.forest[(self.indices)(ty)]; + pub fn merge(&mut self, tx: Target, ty: Target) { + let mut x = self.nodes[self.target_index(tx)]; + let mut y = self.nodes[self.target_index(ty)]; x = self.find(x); y = self.find(y); @@ -80,39 +75,32 @@ impl usize> y.size += x.size; } - self.forest[x.index] = x; - self.forest[y.index] = y; + self.nodes[x.index] = x; + self.nodes[y.index] = y; } } -impl usize> TargetPartition { - pub fn wire_partition(mut self) -> (WirePartitions, Vec>) { +impl PartitionWitness { + pub fn wire_partition(mut self) -> (WirePartitions, Self) { let mut partition = HashMap::<_, Vec<_>>::new(); - let nodes = self.forest.clone(); - for x in nodes { - let v = partition.entry(self.find(x).t).or_default(); - v.push(x.t); + for gate in 0..self.degree { + for input in 0..self.num_routed_wires { + let w = Wire { gate, input }; + let t = Target::Wire(w); + let x = self.nodes[self.target_index(t)]; + partition.entry(self.find(x).t).or_default().push(w); + } + } + // I'm not 100% sure this loop is needed, but I'm afraid removing it might lead to subtle bugs. + for index in 0..self.nodes.len() - self.degree * self.num_wires { + let t = Target::VirtualTarget { index }; + let x = self.nodes[self.target_index(t)]; + self.find(x); } - // let mut indices = HashMap::new(); // Here we keep just the Wire targets, filtering out everything else. - let partition = partition - .into_values() - .map(|v| { - v.into_iter() - .filter_map(|t| match t { - Target::Wire(w) => Some(w), - _ => None, - }) - .collect::>() - }) - .collect::>(); - // partition.iter().enumerate().for_each(|(i, v)| { - // v.iter().for_each(|t| { - // indices.insert(*t, i); - // }); - // }); + let partition = partition.into_values().collect::>(); - (WirePartitions { partition }, self.forest) + (WirePartitions { partition }, self) } } diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 1f44547e..d27c2443 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -35,58 +35,23 @@ pub(crate) fn prove, const D: usize>( let num_challenges = config.num_challenges; let quotient_degree = common_data.quotient_degree(); let degree = common_data.degree(); - // for i in 0..prover_data.gate_instances.len() { - // println!("{}: {}", i, prover_data.gate_instances[i].gate_ref.0.id()); - // } - let nrw = config.num_routed_wires; - let nw = config.num_wires; - let nvt = prover_data.num_virtual_targets; - let target_index = move |t: Target| -> usize { - match t { - Target::Wire(Wire { gate, input }) if input < nrw => gate * nrw + input, - Target::Wire(Wire { gate, input }) if input >= nrw => { - degree * nrw + nvt + gate * (nw - nrw) + input - nrw - } - Target::VirtualTarget { index } => degree * nrw + index, - _ => unreachable!(), - } - }; let mut partial_witness = prover_data.partition.clone(); - let n = partial_witness.len(); - timed!(timing, "fill partition", { - partial_witness.reserve_exact(degree * (config.num_wires - config.num_routed_wires)); - for i in 0..degree * (config.num_wires - config.num_routed_wires) { - partial_witness.push(ForestNode { - t: Target::Wire(Wire { gate: 0, input: 0 }), - parent: n + i, - size: 0, - index: n + i, - value: None, - }) - } + timed!( + timing, + "fill partition", for &(t, v) in &inputs.set_targets { - // println!("{:?} {} {}", t, target_index(t), partial_witness.len()); - let parent = partial_witness[target_index(t)].parent; - // println!("{} {}", parent, partial_witness.len()); - partial_witness[parent].value = Some(v); + partial_witness.set_target(t, v); } - }); - // let t = partial_witness[target_index(Target::Wire(Wire { - // gate: 14, - // input: 16, - // }))]; - // dbg!(t); - // dbg!(partial_witness[t.parent]); - // let mut partial_witness = inputs; - let mut partial_witness = PartitionWitness(partial_witness, Box::new(target_index)); + ); + timed!( timing, &format!("run {} generators", prover_data.generators.len()), generate_partial_witness( &mut partial_witness, &prover_data.generators, - config.num_wires, + num_wires, degree, prover_data.num_virtual_targets, &mut timing @@ -96,22 +61,17 @@ pub(crate) fn prove, const D: usize>( let public_inputs = partial_witness.get_targets(&prover_data.public_inputs); let public_inputs_hash = hash_n_to_hash(public_inputs.clone(), true); - // // Display the marked targets for debugging purposes. - // for m in &prover_data.marked_targets { - // m.display(&partial_witness); - // } - // - // timed!( - // timing, - // "check copy constraints", - // partial_witness - // .check_copy_constraints(&prover_data.copy_constraints, &prover_data.gate_instances)? - // ); + if cfg!(debug_assertions) { + // Display the marked targets for debugging purposes. + for m in &prover_data.marked_targets { + m.display(&partial_witness); + } + } let witness = timed!( timing, "compute full witness", - partial_witness.full_witness(degree, num_wires) + partial_witness.full_witness() ); let wires_values: Vec> = timed!( diff --git a/src/util/marking.rs b/src/util/marking.rs index 6e0ad993..5019511a 100644 --- a/src/util/marking.rs +++ b/src/util/marking.rs @@ -2,7 +2,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::hash::hash_types::HashOutTarget; use crate::iop::target::Target; -use crate::iop::witness::{PartialWitness, Witness}; +use crate::iop::witness::{PartialWitness, PartitionWitness, Witness}; /// Enum representing all types of targets, so that they can be marked. #[derive(Clone)] @@ -36,7 +36,7 @@ impl>, const D: usize> From> for Markable { impl Markable { /// Display a `Markable` by querying a partial witness. - fn print_markable>(&self, pw: &PartialWitness) { + fn print_markable>(&self, pw: &PartitionWitness) { match self { Markable::Target(t) => println!("{}", pw.get_target(*t)), Markable::ExtensionTarget(et) => println!("{}", pw.get_extension_target(*et)), @@ -55,7 +55,7 @@ pub struct MarkedTargets { impl MarkedTargets { /// Display the collection of targets along with its name by querying a partial witness. - pub fn display>(&self, pw: &PartialWitness) { + pub fn display>(&self, pw: &PartitionWitness) { println!("Values for {}:", self.name); self.targets.print_markable(pw); println!("End of values for {}", self.name);