From c83382aaaabddddb0b0ff137388c79696a5baca6 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 1 Jul 2021 15:20:16 +0200 Subject: [PATCH] Working partial products --- src/circuit_data.rs | 2 +- src/gadgets/arithmetic_extension.rs | 4 +- src/plonk_common.rs | 78 +++++++++++++++---------- src/prover.rs | 91 ++++++++++++++++------------- src/util/partial_products.rs | 46 ++++++++++++--- 5 files changed, 143 insertions(+), 78 deletions(-) diff --git a/src/circuit_data.rs b/src/circuit_data.rs index 06a6700f..9741bdc4 100644 --- a/src/circuit_data.rs +++ b/src/circuit_data.rs @@ -202,7 +202,7 @@ impl, const D: usize> CommonCircuitData { self.num_constants..self.num_constants + self.config.num_routed_wires } - /// Range of the constants polynomials in the `constants_sigmas_commitment`. + /// Range of the `z`s polynomials in the ``. pub fn zs_range(&self) -> Range { 0..self.config.num_challenges } diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index ce53c59f..24f20778 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -446,7 +446,9 @@ mod tests { type FF = QuarticCrandallField; const D: usize = 4; - let config = CircuitConfig::large_config(); + let mut config = CircuitConfig::large_config(); + config.rate_bits = 2; + config.fri_config.rate_bits = 2; let mut builder = CircuitBuilder::::new(config); diff --git a/src/plonk_common.rs b/src/plonk_common.rs index 3b271101..b1910175 100644 --- a/src/plonk_common.rs +++ b/src/plonk_common.rs @@ -9,7 +9,7 @@ use crate::gates::gate::{GateRef, PrefixedGate}; use crate::polynomial::commitment::SALT_SIZE; use crate::polynomial::polynomial::PolynomialCoeffs; use crate::target::Target; -use crate::util::partial_products::partial_products; +use crate::util::partial_products::{check_partial_products, partial_products}; use crate::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; /// Holds the Merkle tree index and blinding flag of a set of polynomials used in FRI. @@ -124,6 +124,7 @@ pub(crate) fn eval_vanishing_poly_base, const D: usize>( alphas: &[F], z_h_on_coset: &ZeroPolyOnCoset, ) -> Vec { + let max_degree = common_data.max_filtered_constraint_degree; let constraint_terms = evaluate_gate_constraints_base(&common_data.gates, common_data.num_gate_constraints, vars); @@ -154,36 +155,50 @@ pub(crate) fn eval_vanishing_poly_base, const D: usize>( wire_value + betas[i] * s_sigma + gammas[i] }) .collect::>(); - let numerator_partial_products = - partial_products(numerator_values, common_data.max_filtered_constraint_degree); - let denominator_partial_products = partial_products( - denominator_values, - common_data.max_filtered_constraint_degree, - ); + let numerator_partial_products = partial_products(&numerator_values, max_degree); + let denominator_partial_products = partial_products(&denominator_values, max_degree); - dbg!(numerator_partial_products - .clone() - .0 - .into_iter() - .chain(denominator_partial_products.clone().0) - .zip(local_partial_products) - .map(|(a, &b)| a - b) - .collect::>(),); - vanishing_partial_products_terms.append( - &mut numerator_partial_products - .0 - .into_iter() - .chain(denominator_partial_products.0) - .zip(local_partial_products) - .map(|(a, &b)| a - b) - .collect::>(), - ); - dbg!(&numerator_partial_products.1); - dbg!(&denominator_partial_products.1); - dbg!(common_data.max_filtered_constraint_degree); - let f_prime: F = numerator_partial_products.1.into_iter().product(); - let g_prime: F = denominator_partial_products.1.into_iter().product(); - // vanishing_v_shift_terms.push(f_prime * z_x - g_prime * z_gz); + let num_prods = numerator_partial_products.0.len(); + // dbg!(numerator_partial_products + // .0 + // .iter() + // .chain(&denominator_partial_products.0) + // .zip(&local_partial_products[i * num_prods..(i + 1) * num_prods]) + // .map(|(&a, &b)| a - b) + // .collect::>(),); + // vanishing_partial_products_terms.append( + // &mut numerator_partial_products + // .0 + // .into_iter() + // .chain(denominator_partial_products.0) + // .zip(&local_partial_products[i * num_prods..(i + 1) * num_prods]) + // .map(|(a, &b)| a - b) + // .collect::>(), + // ); + vanishing_partial_products_terms.extend(check_partial_products( + &numerator_values, + &local_partial_products[2 * i * num_prods..(2 * i + 1) * num_prods], + max_degree, + )); + vanishing_partial_products_terms.extend(check_partial_products( + &denominator_values, + &local_partial_products[(2 * i + 1) * num_prods..(2 * i + 2) * num_prods], + max_degree, + )); + // dbg!(common_data.max_filtered_constraint_degree); + // dbg!(numerator_partial_products.1.len()); + // dbg!(denominator_partial_products.1.len()); + let f_prime: F = local_partial_products + [(2 * i + 1) * num_prods - numerator_partial_products.1..(2 * i + 1) * num_prods] + .iter() + .copied() + .product(); + let g_prime: F = local_partial_products + [(2 * i + 2) * num_prods - numerator_partial_products.1..(2 * i + 2) * num_prods] + .iter() + .copied() + .product(); + vanishing_v_shift_terms.push(f_prime * z_x - g_prime * z_gz); } let vanishing_terms = [ @@ -193,6 +208,9 @@ pub(crate) fn eval_vanishing_poly_base, const D: usize>( constraint_terms, ] .concat(); + // if index % 4 == 0 { + // println!("{}", vanishing_terms.iter().all(|x| x.is_zero())); + // } reduce_with_powers_multi(&vanishing_terms, alphas) } diff --git a/src/prover.rs b/src/prover.rs index de5c972c..fd0a093d 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -80,7 +80,7 @@ pub(crate) fn prove, const D: usize>( let betas = challenger.get_n_challenges(num_challenges); let gammas = challenger.get_n_challenges(num_challenges); - let partial_products = timed!( + let mut partial_products = timed!( all_wires_permutation_partial_products(&witness, &betas, &gammas, prover_data, common_data), "to compute partial products" ); @@ -90,6 +90,10 @@ pub(crate) fn prove, const D: usize>( "to compute Z's" ); + partial_products.iter_mut().for_each(|part| { + part.drain(0..2); + }); + let zs_partial_products = [plonk_z_vecs, partial_products.concat()].concat(); let plonk_zs_commitment = timed!( ListPolynomialCommitment::new( @@ -205,35 +209,50 @@ fn wires_permutation_partial_products, const D: usize>( ) -> Vec> { let degree = common_data.max_filtered_constraint_degree; let subgroup = &prover_data.subgroup; - let mut values = Vec::new(); let k_is = &common_data.k_is; - for i in 1..common_data.degree() { - let x = subgroup[i - 1]; - let s_sigmas = &prover_data.sigmas[i - 1]; - let numerator_values = (0..common_data.config.num_routed_wires) - .map(|j| { - let wire_value = witness.get_wire(i - 1, j); - let k_i = k_is[j]; - let s_id = k_i * x; - wire_value + beta * s_id + gamma - }) - .collect::>(); - let denominator_values = (0..common_data.config.num_routed_wires) - .map(|j| { - let wire_value = witness.get_wire(i - 1, j); - let s_sigma = s_sigmas[j]; - wire_value + beta * s_sigma + gamma - }) - .collect::>(); - let partials = [ - partial_products(numerator_values, degree).0, - partial_products(denominator_values, degree).0, - ] - .concat(); - values.push(partials); - } + let values = subgroup + .iter() + .enumerate() + .map(|(i, &x)| { + let s_sigmas = &prover_data.sigmas[i]; + let numerator_values = (0..common_data.config.num_routed_wires) + .map(|j| { + let wire_value = witness.get_wire(i, j); + let k_i = k_is[j]; + let s_id = k_i * x; + wire_value + beta * s_id + gamma + }) + .collect::>(); + let denominator_values = (0..common_data.config.num_routed_wires) + .map(|j| { + let wire_value = witness.get_wire(i, j); + let s_sigma = s_sigmas[j]; + wire_value + beta * s_sigma + gamma + }) + .collect::>(); + let numerator_partials = partial_products(&numerator_values, degree); + let denominator_partials = partial_products(&denominator_values, degree); + let numerator = numerator_partials.0 + [numerator_partials.0.len() - numerator_partials.1..] + .iter() + .copied() + .product(); + let denominator = denominator_partials.0 + [denominator_partials.0.len() - denominator_partials.1..] + .iter() + .copied() + .product(); + + [ + vec![numerator], + vec![denominator], + numerator_partials.0, + denominator_partials.0, + ] + .concat() + }) + .collect::>(); - values.insert(0, vec![F::ONE; values[0].len()]); transpose(&values) .into_par_iter() .map(PolynomialValues::new) @@ -255,20 +274,12 @@ fn compute_z, const D: usize>( prover_data: &ProverOnlyCircuitData, common_data: &CommonCircuitData, ) -> PolynomialValues { - let num_partials = partial_products.len() / 2; let subgroup = &prover_data.subgroup; let mut plonk_z_points = vec![F::ONE]; - let k_is = &common_data.k_is; for i in 1..common_data.degree() { let x = subgroup[i - 1]; - let numerator = partial_products[..num_partials] - .iter() - .map(|vs| vs.values[i]) - .product(); - let denominator = partial_products[num_partials..] - .iter() - .map(|vs| vs.values[i]) - .product(); + let numerator = partial_products[0].values[i - 1]; + let denominator = partial_products[1].values[i - 1]; let last = *plonk_z_points.last().unwrap(); plonk_z_points.push(last * numerator / denominator); } @@ -312,7 +323,8 @@ fn compute_quotient_polys<'a, F: Extendable, const D: usize>( ZeroPolyOnCoset::new(common_data.degree_bits, max_filtered_constraint_degree_bits); let quotient_values: Vec> = points - .into_par_iter() + // .into_par_iter() + .into_iter() .enumerate() .map(|(i, x)| { let shifted_x = F::coset_shift() * x; @@ -335,6 +347,7 @@ fn compute_quotient_polys<'a, F: Extendable, const D: usize>( local_constants, local_wires, }; + dbg!(i); let mut quotient_values = eval_vanishing_poly_base( common_data, i, diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index c31fbff6..c5426eea 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -1,8 +1,9 @@ use std::iter::Product; +use std::ops::Sub; -pub fn partial_products(v: Vec, max_degree: usize) -> (Vec, Vec) { +pub fn partial_products(v: &[T], max_degree: usize) -> (Vec, usize) { let mut res = Vec::new(); - let mut remainder = v; + let mut remainder = v.to_vec(); while remainder.len() >= max_degree { let new_partials = remainder .chunks(max_degree) @@ -13,18 +14,49 @@ pub fn partial_products(v: Vec, max_degree: usize) -> (Vec remainder = new_partials; } - (res, remainder) + (res, remainder.len()) +} + +pub fn check_partial_products>( + v: &[T], + partials: &[T], + max_degree: usize, +) -> Vec { + let mut res = Vec::new(); + let mut remainder = v.to_vec(); + let mut partials = partials.to_vec(); + while remainder.len() >= max_degree { + let products = remainder + .chunks(max_degree) + .map(|chunk| chunk.iter().copied().product()) + .collect::>(); + res.extend(products.iter().zip(&partials).map(|(&a, &b)| a - b)); + remainder = partials.drain(..products.len()).collect(); + } + + res } #[cfg(test)] mod tests { + use num::Zero; + use super::*; #[test] fn test_partial_products() { - assert_eq!( - partial_products(vec![1, 2, 3, 4, 5, 6], 2), - (vec![2, 12, 30, 24, 30], vec![24, 30]) - ); + let v = vec![1, 2, 3, 4, 5, 6]; + let p = partial_products(&v, 2); + assert_eq!(p, (vec![2, 12, 30, 24, 30, 720], 1)); + assert!(check_partial_products(&v, &p.0, 2) + .iter() + .all(|x| x.is_zero())); + + let v = vec![1, 2, 3, 4, 5, 6]; + let p = partial_products(&v, 3); + assert_eq!(p, (vec![6, 120], 2)); + assert!(check_partial_products(&v, &p.0, 3) + .iter() + .all(|x| x.is_zero())); } }