This commit is contained in:
wborgeaud 2021-08-20 10:44:19 +02:00
parent c6cf5cf130
commit 6584734928
8 changed files with 100 additions and 132 deletions

View File

@ -30,13 +30,13 @@ pub(crate) fn generate_partial_witness<F: Field>(
// Target::VirtualTarget { index } => degree * num_wires + index,
// }
// };
let max_target_index = witness.0.len();
let max_target_index = witness.nodes.len();
// Index generator indices by their watched targets.
let mut generator_indices_by_watches = vec![Vec::new(); max_target_index];
timed!(timing, "index generators by their watched targets", {
for (i, generator) in generators.iter().enumerate() {
for watch in generator.watch_list() {
generator_indices_by_watches[witness.1(watch)].push(i);
generator_indices_by_watches[witness.target_index(watch)].push(i);
}
}
});
@ -71,7 +71,9 @@ pub(crate) fn generate_partial_witness<F: Field>(
// Enqueue unfinished generators that were watching one of the newly populated targets.
for &(watch, _) in &buffer.target_values {
for &watching_generator_idx in &generator_indices_by_watches[witness.1(watch)] {
for &watching_generator_idx in
&generator_indices_by_watches[witness.target_index(watch)]
{
if !generator_is_expired[watching_generator_idx] {
next_pending_generator_indices.push(watching_generator_idx);
}

View File

@ -217,32 +217,50 @@ impl<F: Field> Witness<F> for PartialWitness<F> {
}
}
pub struct PartitionWitness<F: Field>(
pub Vec<ForestNode<Target, F>>,
pub Box<dyn Fn(Target) -> usize>,
);
#[derive(Clone)]
pub struct PartitionWitness<F: Field> {
pub nodes: Vec<ForestNode<Target, F>>,
pub num_wires: usize,
pub num_routed_wires: usize,
pub degree: usize,
}
impl<F: Field> Witness<F> for PartitionWitness<F> {
fn try_get_target(&self, target: Target) -> Option<F> {
self.0[self.0[self.1(target)].parent].value
self.nodes[self.nodes[self.target_index(target)].parent].value
}
fn set_target(&mut self, target: Target, value: F) {
let i = self.0[self.1(target)].parent;
self.0[i].value = Some(value);
let i = self.nodes[self.target_index(target)].parent;
self.nodes[i].value = Some(value);
}
}
impl<F: Field> PartitionWitness<F> {
pub fn full_witness(self, degree: usize, num_wires: usize) -> MatrixWitness<F> {
let mut wire_values = vec![vec![F::ZERO; degree]; num_wires];
// assert!(self.wire_values.len() <= degree);
for i in 0..degree {
for j in 0..num_wires {
let t = Target::Wire(Wire { gate: i, input: j });
wire_values[j][i] = self.0[self.0[self.1(t)].parent].value.unwrap_or(F::ZERO);
pub const fn target_index(&self, target: Target) -> usize {
match target {
Target::Wire(Wire { gate, input }) => gate * self.num_wires + input,
Target::VirtualTarget { index } => self.degree * self.num_wires + index,
}
}
pub fn full_witness(self) -> MatrixWitness<F> {
let mut wire_values = vec![vec![]; self.num_wires];
for j in 0..self.num_wires {
wire_values[j].reserve_exact(self.degree);
unsafe {
// After .reserve_exact(l), wire_values[i] will have capacity at least l. Hence, set_len
// will not cause the buffer to overrun.
wire_values[j].set_len(self.degree);
}
}
for i in 0..self.degree {
for j in 0..self.num_wires {
let t = Target::Wire(Wire { gate: i, input: j });
wire_values[j][i] = self.try_get_target(t).unwrap_or(F::ZERO);
}
}
MatrixWitness { wire_values }
}
}

View File

@ -1,4 +1,5 @@
#![feature(destructuring_assignment)]
#![feature(const_fn_trait_bound)]
pub mod field;
pub mod fri;

View File

@ -19,12 +19,13 @@ use crate::hash::hashing::hash_n_to_hash;
use crate::iop::generator::{CopyGenerator, RandomValueGenerator, WitnessGenerator};
use crate::iop::target::{BoolTarget, Target};
use crate::iop::wire::Wire;
use crate::iop::witness::{PartialWitness, PartitionWitness};
use crate::plonk::circuit_data::{
CircuitConfig, CircuitData, CommonCircuitData, ProverCircuitData, ProverOnlyCircuitData,
VerifierCircuitData, VerifierOnlyCircuitData,
};
use crate::plonk::copy_constraint::CopyConstraint;
use crate::plonk::permutation_argument::{ForestNode, TargetPartition};
use crate::plonk::permutation_argument::ForestNode;
use crate::plonk::plonk_common::PlonkPolynomials;
use crate::polynomial::polynomial::PolynomialValues;
use crate::util::context_tree::ContextTree;
@ -510,16 +511,14 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
&self,
k_is: &[F],
subgroup: &[F],
) -> (Vec<PolynomialValues<F>>, Vec<ForestNode<Target, F>>) {
) -> (Vec<PolynomialValues<F>>, PartitionWitness<F>) {
let degree = self.gate_instances.len();
let degree_log = log2_strict(degree);
let mut target_partition = TargetPartition::new(|t| match t {
Target::Wire(Wire { gate, input }) => gate * self.config.num_routed_wires + input,
Target::VirtualTarget { index } => degree * self.config.num_routed_wires + index,
});
let mut target_partition =
PartitionWitness::new(self.config.num_wires, self.config.num_routed_wires, degree);
for gate in 0..degree {
for input in 0..self.config.num_routed_wires {
for input in 0..self.config.num_wires {
target_partition.add(Target::Wire(Wire { gate, input }));
}
}

View File

@ -11,7 +11,7 @@ use crate::hash::hash_types::{HashOut, MerkleCapTarget};
use crate::hash::merkle_tree::MerkleCap;
use crate::iop::generator::WitnessGenerator;
use crate::iop::target::Target;
use crate::iop::witness::PartialWitness;
use crate::iop::witness::{PartialWitness, PartitionWitness};
use crate::plonk::copy_constraint::CopyConstraint;
use crate::plonk::permutation_argument::ForestNode;
use crate::plonk::proof::ProofWithPublicInputs;
@ -157,7 +157,7 @@ pub(crate) struct ProverOnlyCircuitData<F: Extendable<D>, const D: usize> {
/// Number of virtual targets used in the circuit.
pub num_virtual_targets: usize,
pub partition: Vec<ForestNode<Target, F>>,
pub partition: PartitionWitness<F>,
}
/// Circuit data required by the verifier, but not the prover.

View File

@ -7,6 +7,7 @@ use rayon::prelude::*;
use crate::field::field_types::Field;
use crate::iop::target::Target;
use crate::iop::wire::Wire;
use crate::iop::witness::PartitionWitness;
use crate::polynomial::polynomial::PolynomialValues;
/// Node in the Disjoint Set Forest.
@ -20,27 +21,21 @@ pub struct ForestNode<T: Debug + Copy + Eq + PartialEq, V: Field> {
}
/// Disjoint Set Forest data-structure following https://en.wikipedia.org/wiki/Disjoint-set_data_structure.
#[derive(Debug, Clone)]
pub struct TargetPartition<T: Debug + Copy + Eq + PartialEq + Hash, V: Field, F: Fn(T) -> usize> {
forest: Vec<ForestNode<T, V>>,
/// Function to compute a node's index in the forest.
indices: F,
}
impl<T: Debug + Copy + Eq + PartialEq + Hash, V: Field, F: Fn(T) -> usize>
TargetPartition<T, V, F>
{
pub fn new(f: F) -> Self {
impl<F: Field> PartitionWitness<F> {
pub fn new(num_wires: usize, num_routed_wires: usize, degree: usize) -> Self {
Self {
forest: Vec::new(),
indices: f,
nodes: vec![],
num_wires,
num_routed_wires,
degree,
}
}
/// Add a new partition with a single member.
pub fn add(&mut self, t: T) {
let index = self.forest.len();
debug_assert_eq!((self.indices)(t), index);
self.forest.push(ForestNode {
pub fn add(&mut self, t: Target) {
let index = self.nodes.len();
debug_assert_eq!(self.target_index(t), index);
self.nodes.push(ForestNode {
t,
parent: index,
size: 1,
@ -50,10 +45,10 @@ impl<T: Debug + Copy + Eq + PartialEq + Hash, V: Field, F: Fn(T) -> usize>
}
/// Path compression method, see https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Finding_set_representatives.
pub fn find(&mut self, x: ForestNode<T, V>) -> ForestNode<T, V> {
pub fn find(&mut self, x: ForestNode<Target, F>) -> ForestNode<Target, F> {
if x.parent != x.index {
let root = self.find(self.forest[x.parent]);
self.forest[x.index].parent = root.index;
let root = self.find(self.nodes[x.parent]);
self.nodes[x.index].parent = root.index;
root
} else {
x
@ -61,9 +56,9 @@ impl<T: Debug + Copy + Eq + PartialEq + Hash, V: Field, F: Fn(T) -> usize>
}
/// Merge two sets.
pub fn merge(&mut self, tx: T, ty: T) {
let mut x = self.forest[(self.indices)(tx)];
let mut y = self.forest[(self.indices)(ty)];
pub fn merge(&mut self, tx: Target, ty: Target) {
let mut x = self.nodes[self.target_index(tx)];
let mut y = self.nodes[self.target_index(ty)];
x = self.find(x);
y = self.find(y);
@ -80,39 +75,32 @@ impl<T: Debug + Copy + Eq + PartialEq + Hash, V: Field, F: Fn(T) -> usize>
y.size += x.size;
}
self.forest[x.index] = x;
self.forest[y.index] = y;
self.nodes[x.index] = x;
self.nodes[y.index] = y;
}
}
impl<V: Field, F: Fn(Target) -> usize> TargetPartition<Target, V, F> {
pub fn wire_partition(mut self) -> (WirePartitions, Vec<ForestNode<Target, V>>) {
impl<F: Field> PartitionWitness<F> {
pub fn wire_partition(mut self) -> (WirePartitions, Self) {
let mut partition = HashMap::<_, Vec<_>>::new();
let nodes = self.forest.clone();
for x in nodes {
let v = partition.entry(self.find(x).t).or_default();
v.push(x.t);
for gate in 0..self.degree {
for input in 0..self.num_routed_wires {
let w = Wire { gate, input };
let t = Target::Wire(w);
let x = self.nodes[self.target_index(t)];
partition.entry(self.find(x).t).or_default().push(w);
}
}
// I'm not 100% sure this loop is needed, but I'm afraid removing it might lead to subtle bugs.
for index in 0..self.nodes.len() - self.degree * self.num_wires {
let t = Target::VirtualTarget { index };
let x = self.nodes[self.target_index(t)];
self.find(x);
}
// let mut indices = HashMap::new();
// Here we keep just the Wire targets, filtering out everything else.
let partition = partition
.into_values()
.map(|v| {
v.into_iter()
.filter_map(|t| match t {
Target::Wire(w) => Some(w),
_ => None,
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
// partition.iter().enumerate().for_each(|(i, v)| {
// v.iter().for_each(|t| {
// indices.insert(*t, i);
// });
// });
let partition = partition.into_values().collect::<Vec<_>>();
(WirePartitions { partition }, self.forest)
(WirePartitions { partition }, self)
}
}

View File

@ -35,58 +35,23 @@ pub(crate) fn prove<F: Extendable<D>, const D: usize>(
let num_challenges = config.num_challenges;
let quotient_degree = common_data.quotient_degree();
let degree = common_data.degree();
// for i in 0..prover_data.gate_instances.len() {
// println!("{}: {}", i, prover_data.gate_instances[i].gate_ref.0.id());
// }
let nrw = config.num_routed_wires;
let nw = config.num_wires;
let nvt = prover_data.num_virtual_targets;
let target_index = move |t: Target| -> usize {
match t {
Target::Wire(Wire { gate, input }) if input < nrw => gate * nrw + input,
Target::Wire(Wire { gate, input }) if input >= nrw => {
degree * nrw + nvt + gate * (nw - nrw) + input - nrw
}
Target::VirtualTarget { index } => degree * nrw + index,
_ => unreachable!(),
}
};
let mut partial_witness = prover_data.partition.clone();
let n = partial_witness.len();
timed!(timing, "fill partition", {
partial_witness.reserve_exact(degree * (config.num_wires - config.num_routed_wires));
for i in 0..degree * (config.num_wires - config.num_routed_wires) {
partial_witness.push(ForestNode {
t: Target::Wire(Wire { gate: 0, input: 0 }),
parent: n + i,
size: 0,
index: n + i,
value: None,
})
}
timed!(
timing,
"fill partition",
for &(t, v) in &inputs.set_targets {
// println!("{:?} {} {}", t, target_index(t), partial_witness.len());
let parent = partial_witness[target_index(t)].parent;
// println!("{} {}", parent, partial_witness.len());
partial_witness[parent].value = Some(v);
partial_witness.set_target(t, v);
}
});
// let t = partial_witness[target_index(Target::Wire(Wire {
// gate: 14,
// input: 16,
// }))];
// dbg!(t);
// dbg!(partial_witness[t.parent]);
// let mut partial_witness = inputs;
let mut partial_witness = PartitionWitness(partial_witness, Box::new(target_index));
);
timed!(
timing,
&format!("run {} generators", prover_data.generators.len()),
generate_partial_witness(
&mut partial_witness,
&prover_data.generators,
config.num_wires,
num_wires,
degree,
prover_data.num_virtual_targets,
&mut timing
@ -96,22 +61,17 @@ pub(crate) fn prove<F: Extendable<D>, const D: usize>(
let public_inputs = partial_witness.get_targets(&prover_data.public_inputs);
let public_inputs_hash = hash_n_to_hash(public_inputs.clone(), true);
// // Display the marked targets for debugging purposes.
// for m in &prover_data.marked_targets {
// m.display(&partial_witness);
// }
//
// timed!(
// timing,
// "check copy constraints",
// partial_witness
// .check_copy_constraints(&prover_data.copy_constraints, &prover_data.gate_instances)?
// );
if cfg!(debug_assertions) {
// Display the marked targets for debugging purposes.
for m in &prover_data.marked_targets {
m.display(&partial_witness);
}
}
let witness = timed!(
timing,
"compute full witness",
partial_witness.full_witness(degree, num_wires)
partial_witness.full_witness()
);
let wires_values: Vec<PolynomialValues<F>> = timed!(

View File

@ -2,7 +2,7 @@ use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::hash::hash_types::HashOutTarget;
use crate::iop::target::Target;
use crate::iop::witness::{PartialWitness, Witness};
use crate::iop::witness::{PartialWitness, PartitionWitness, Witness};
/// Enum representing all types of targets, so that they can be marked.
#[derive(Clone)]
@ -36,7 +36,7 @@ impl<M: Into<Markable<D>>, const D: usize> From<Vec<M>> for Markable<D> {
impl<const D: usize> Markable<D> {
/// Display a `Markable` by querying a partial witness.
fn print_markable<F: Extendable<D>>(&self, pw: &PartialWitness<F>) {
fn print_markable<F: Extendable<D>>(&self, pw: &PartitionWitness<F>) {
match self {
Markable::Target(t) => println!("{}", pw.get_target(*t)),
Markable::ExtensionTarget(et) => println!("{}", pw.get_extension_target(*et)),
@ -55,7 +55,7 @@ pub struct MarkedTargets<const D: usize> {
impl<const D: usize> MarkedTargets<D> {
/// Display the collection of targets along with its name by querying a partial witness.
pub fn display<F: Extendable<D>>(&self, pw: &PartialWitness<F>) {
pub fn display<F: Extendable<D>>(&self, pw: &PartitionWitness<F>) {
println!("Values for {}:", self.name);
self.targets.print_markable(pw);
println!("End of values for {}", self.name);