Remove quotients and work directly with numerators and denominators in partial products check

This commit is contained in:
wborgeaud 2021-11-10 18:13:27 +01:00
parent ff943138f3
commit 32f09ac2df
2 changed files with 64 additions and 76 deletions

View File

@ -62,31 +62,26 @@ pub(crate) fn eval_vanishing_poly<F: RichField + Extendable<D>, const D: usize>(
wire_value + s_sigma.scalar_mul(betas[i]) + gammas[i].into()
})
.collect::<Vec<_>>();
let quotient_values = (0..common_data.config.num_routed_wires)
.map(|j| numerator_values[j] / denominator_values[j])
.collect::<Vec<_>>();
// The partial products considered for this iteration of `i`.
let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];
// Check the quotient partial products.
let mut partial_product_checks =
check_partial_products(&quotient_values, current_partial_products, max_degree);
// The partial products are products of quotients, so we multiply them by the product of the
// corresponding denominators to make sure they are polynomials.
for (j, partial_product_check) in partial_product_checks.iter_mut().enumerate() {
let range = j * max_degree..(j + 1) * max_degree;
*partial_product_check *= denominator_values[range].iter().copied().product();
}
let partial_product_checks = check_partial_products(
&numerator_values,
&denominator_values,
current_partial_products,
max_degree,
);
vanishing_partial_products_terms.extend(partial_product_checks);
let quotient: F::Extension = *current_partial_products.last().unwrap()
* quotient_values[final_num_prod..].iter().copied().product();
let mut v_shift_term = quotient * z_x - z_gz;
// Need to multiply by the denominators to make sure we get a polynomial.
v_shift_term *= denominator_values[final_num_prod..]
.iter()
.copied()
.product();
let v_shift_term = *current_partial_products.last().unwrap()
* numerator_values[final_num_prod..].iter().copied().product()
* z_x
- z_gz
* denominator_values[final_num_prod..]
.iter()
.copied()
.product();
vanishing_v_shift_terms.push(v_shift_term);
}
@ -139,7 +134,6 @@ pub(crate) fn eval_vanishing_poly_base_batch<F: RichField + Extendable<D>, const
let mut numerator_values = Vec::with_capacity(num_routed_wires);
let mut denominator_values = Vec::with_capacity(num_routed_wires);
let mut quotient_values = Vec::with_capacity(num_routed_wires);
// The L_1(x) (Z(x) - 1) vanishing terms.
let mut vanishing_z_1_terms = Vec::with_capacity(num_challenges);
@ -178,37 +172,30 @@ pub(crate) fn eval_vanishing_poly_base_batch<F: RichField + Extendable<D>, const
let s_sigma = s_sigmas[j];
wire_value + betas[i] * s_sigma + gammas[i]
}));
let denominator_inverses = F::batch_multiplicative_inverse(&denominator_values);
quotient_values.extend(
(0..num_routed_wires).map(|j| numerator_values[j] * denominator_inverses[j]),
);
// The partial products considered for this iteration of `i`.
let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];
// Check the numerator partial products.
let mut partial_product_checks =
check_partial_products(&quotient_values, current_partial_products, max_degree);
// The partial products are products of quotients, so we multiply them by the product of the
// corresponding denominators to make sure they are polynomials.
for (j, partial_product_check) in partial_product_checks.iter_mut().enumerate() {
let range = j * max_degree..(j + 1) * max_degree;
*partial_product_check *= denominator_values[range].iter().copied().product();
}
let partial_product_checks = check_partial_products(
&numerator_values,
&denominator_values,
current_partial_products,
max_degree,
);
vanishing_partial_products_terms.extend(partial_product_checks);
let quotient: F = *current_partial_products.last().unwrap()
* quotient_values[final_num_prod..].iter().copied().product();
let mut v_shift_term = quotient * z_x - z_gz;
// Need to multiply by the denominators to make sure we get a polynomial.
v_shift_term *= denominator_values[final_num_prod..]
.iter()
.copied()
.product();
let v_shift_term = *current_partial_products.last().unwrap()
* numerator_values[final_num_prod..].iter().copied().product()
* z_x
- z_gz
* denominator_values[final_num_prod..]
.iter()
.copied()
.product();
vanishing_v_shift_terms.push(v_shift_term);
numerator_values.clear();
denominator_values.clear();
quotient_values.clear();
}
let vanishing_terms = vanishing_z_1_terms
@ -365,7 +352,6 @@ pub(crate) fn eval_vanishing_poly_recursively<F: RichField + Extendable<D>, cons
let mut numerator_values = Vec::new();
let mut denominator_values = Vec::new();
let mut quotient_values = Vec::new();
for j in 0..common_data.config.num_routed_wires {
let wire_value = vars.local_wires[j];
@ -378,46 +364,33 @@ pub(crate) fn eval_vanishing_poly_recursively<F: RichField + Extendable<D>, cons
let numerator = builder.mul_add_extension(beta_ext, s_ids[j], wire_value_plus_gamma);
let denominator =
builder.mul_add_extension(beta_ext, s_sigmas[j], wire_value_plus_gamma);
let quotient = builder.div_extension(numerator, denominator);
numerator_values.push(numerator);
denominator_values.push(denominator);
quotient_values.push(quotient);
}
// The partial products considered for this iteration of `i`.
let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];
// Check the quotient partial products.
let mut partial_product_checks = check_partial_products_recursively(
let partial_product_checks = check_partial_products_recursively(
builder,
&quotient_values,
&numerator_values,
&denominator_values,
current_partial_products,
max_degree,
);
// The partial products are products of quotients, so we multiply them by the product of the
// corresponding denominators to make sure they are polynomials.
for (j, partial_product_check) in partial_product_checks.iter_mut().enumerate() {
let range = j * max_degree..(j + 1) * max_degree;
*partial_product_check = builder.mul_many_extension(&{
let mut v = denominator_values[range].to_vec();
v.push(*partial_product_check);
v
});
}
vanishing_partial_products_terms.extend(partial_product_checks);
let quotient = builder.mul_many_extension(&{
let mut v = quotient_values[final_num_prod..].to_vec();
let nume_acc = builder.mul_many_extension(&{
let mut v = numerator_values[final_num_prod..].to_vec();
v.push(*current_partial_products.last().unwrap());
v
});
let mut v_shift_term = builder.mul_sub_extension(quotient, z_x, z_gz);
// Need to multiply by the denominators to make sure we get a polynomial.
v_shift_term = builder.mul_many_extension(&{
let z_gz_denominators = builder.mul_many_extension(&{
let mut v = denominator_values[final_num_prod..].to_vec();
v.push(v_shift_term);
v.push(z_gz);
v
});
let v_shift_term = builder.mul_sub_extension(nume_acc, z_x, z_gz_denominators);
vanishing_v_shift_terms.push(v_shift_term);
}

View File

@ -28,18 +28,26 @@ pub fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) {
(num_chunks, num_chunks * chunk_size)
}
/// Checks that the partial products of `v` are coherent with those in `partials` by only computing
/// Checks that the partial products of `numerators/denominators` are coherent with those in `partials` by only computing
/// products of size `max_degree` or less.
pub fn check_partial_products<F: Field>(v: &[F], partials: &[F], max_degree: usize) -> Vec<F> {
pub fn check_partial_products<F: Field>(
numerators: &[F],
denominators: &[F],
partials: &[F],
max_degree: usize,
) -> Vec<F> {
debug_assert!(max_degree > 1);
let mut partials = partials.iter();
let mut res = Vec::new();
let mut acc = F::ONE;
let chunk_size = max_degree;
for chunk in v.chunks_exact(chunk_size) {
acc *= chunk.iter().copied().product();
let new_acc = *partials.next().unwrap();
res.push(acc - new_acc);
for (nume_chunk, deno_chunk) in numerators
.chunks_exact(chunk_size)
.zip(denominators.chunks_exact(chunk_size))
{
acc *= nume_chunk.iter().copied().product();
let mut new_acc = *partials.next().unwrap();
res.push(acc - new_acc * deno_chunk.iter().copied().product());
acc = new_acc;
}
debug_assert!(partials.next().is_none());
@ -49,7 +57,8 @@ pub fn check_partial_products<F: Field>(v: &[F], partials: &[F], max_degree: usi
pub fn check_partial_products_recursively<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
v: &[ExtensionTarget<D>],
numerators: &[ExtensionTarget<D>],
denominators: &[ExtensionTarget<D>],
partials: &[ExtensionTarget<D>],
max_degree: usize,
) -> Vec<ExtensionTarget<D>> {
@ -58,11 +67,16 @@ pub fn check_partial_products_recursively<F: RichField + Extendable<D>, const D:
let mut res = Vec::new();
let mut acc = builder.one_extension();
let chunk_size = max_degree;
for chunk in v.chunks_exact(chunk_size) {
let chunk_product = builder.mul_many_extension(chunk);
for (nume_chunk, deno_chunk) in numerators
.chunks_exact(chunk_size)
.zip(denominators.chunks_exact(chunk_size))
{
let nume_product = builder.mul_many_extension(nume_chunk);
let deno_product = builder.mul_many_extension(deno_chunk);
let new_acc = *partials.next().unwrap();
// Assert that new_acc = acc * chunk_product.
res.push(builder.mul_sub_extension(acc, chunk_product, new_acc));
let new_acc_deno = builder.mul_extension(new_acc, deno_product);
// Assert that new_acc*deno_product = acc * nume_product.
res.push(builder.mul_sub_extension(acc, nume_product, new_acc_deno));
acc = new_acc;
}
debug_assert!(partials.next().is_none());
@ -78,6 +92,7 @@ mod tests {
#[test]
fn test_partial_products() {
type F = GoldilocksField;
let denominators = vec![F::ONE; 6];
let v = [1, 2, 3, 4, 5, 6]
.into_iter()
.map(|&i| F::from_canonical_u64(i))
@ -93,7 +108,7 @@ mod tests {
let nums = num_partial_products(v.len(), 2);
assert_eq!(p.len(), nums.0);
assert!(check_partial_products(&v, &p, 2)
assert!(check_partial_products(&v, &denominators, &p, 2)
.iter()
.all(|x| x.is_zero()));
assert_eq!(
@ -115,7 +130,7 @@ mod tests {
);
let nums = num_partial_products(v.len(), 3);
assert_eq!(p.len(), nums.0);
assert!(check_partial_products(&v, &p, 3)
assert!(check_partial_products(&v, &denominators, &p, 3)
.iter()
.all(|x| x.is_zero()));
assert_eq!(