From 4e361726d0a03e8a7d0f333bf405c140f0856ddb Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 8 Nov 2021 15:50:33 +0100 Subject: [PATCH 001/202] Use partial product chain --- src/plonk/prover.rs | 3 +- src/util/partial_products.rs | 88 ++++++++++++++++++------------------ 2 files changed, 44 insertions(+), 47 deletions(-) diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 031354fa..427880c3 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -268,8 +268,7 @@ fn wires_permutation_partial_products, const D: usi let quotient_partials = partial_products("ient_values, degree); // This is the final product for the quotient. - let quotient = quotient_partials - [common_data.num_partial_products.0 - common_data.num_partial_products.1..] + let quotient = quotient_partials[common_data.num_partial_products.1..] .iter() .copied() .product(); diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index 633047d0..83b0e396 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -1,5 +1,5 @@ use std::iter::Product; -use std::ops::Sub; +use std::ops::{MulAssign, Sub}; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; @@ -9,17 +9,18 @@ use crate::util::ceil_div_usize; /// Compute partial products of the original vector `v` such that all products consist of `max_degree` /// or less elements. This is done until we've computed the product `P` of all elements in the vector. -pub fn partial_products(v: &[T], max_degree: usize) -> Vec { +pub fn partial_products(v: &[T], max_degree: usize) -> Vec { + debug_assert!(max_degree > 1); let mut res = Vec::new(); - let mut remainder = v.to_vec(); - while remainder.len() > max_degree { - let new_partials = remainder - .chunks(max_degree) - // TODO: can filter out chunks of length 1. - .map(|chunk| chunk.iter().copied().product()) - .collect::>(); - res.extend_from_slice(&new_partials); - remainder = new_partials; + let mut acc = v[0]; + let chunk_size = max_degree - 1; + let num_chunks = ceil_div_usize(v.len() - 1, chunk_size) - 1; + for i in 0..num_chunks { + acc *= v[1 + i * chunk_size..1 + (i + 1) * chunk_size] + .iter() + .copied() + .product(); + res.push(acc); } res @@ -29,34 +30,33 @@ pub fn partial_products(v: &[T], max_degree: usize) -> Vec /// vector of length `n`, and `b` is the number of elements needed to compute the final product. pub fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { debug_assert!(max_degree > 1); - let mut res = 0; - let mut remainder = n; - while remainder > max_degree { - let new_partials_len = ceil_div_usize(remainder, max_degree); - res += new_partials_len; - remainder = new_partials_len; - } + let chunk_size = max_degree - 1; + let num_chunks = ceil_div_usize(n - 1, chunk_size) - 1; - (res, remainder) + (num_chunks, 1 + num_chunks * chunk_size) } /// Checks that the partial products of `v` are coherent with those in `partials` by only computing /// products of size `max_degree` or less. -pub fn check_partial_products>( +pub fn check_partial_products>( v: &[T], mut partials: &[T], max_degree: usize, ) -> Vec { + debug_assert!(max_degree > 1); + let mut partials = partials.iter(); let mut res = Vec::new(); - let mut remainder = v; - while remainder.len() > max_degree { - let products = remainder - .chunks(max_degree) - .map(|chunk| chunk.iter().copied().product::()); - let products_len = products.len(); - res.extend(products.zip(partials).map(|(a, &b)| a - b)); - (remainder, partials) = partials.split_at(products_len); + let mut acc = v[0]; + let chunk_size = max_degree - 1; + let num_chunks = ceil_div_usize(v.len() - 1, chunk_size) - 1; + for i in 0..num_chunks { + acc *= v[1 + i * chunk_size..1 + (i + 1) * chunk_size] + .iter() + .copied() + .product(); + res.push(acc - *partials.next().unwrap()); } + debug_assert!(partials.next().is_none()); res } @@ -67,22 +67,20 @@ pub fn check_partial_products_recursively, const D: partials: &[ExtensionTarget], max_degree: usize, ) -> Vec> { + debug_assert!(max_degree > 1); + let mut partials = partials.iter(); let mut res = Vec::new(); - let mut remainder = v.to_vec(); - let mut partials = partials.to_vec(); - while remainder.len() > max_degree { - let products = remainder - .chunks(max_degree) - .map(|chunk| builder.mul_many_extension(chunk)) - .collect::>(); - res.extend( - products - .iter() - .zip(&partials) - .map(|(&a, &b)| builder.sub_extension(a, b)), - ); - remainder = partials.drain(..products.len()).collect(); + let mut acc = v[0]; + let chunk_size = max_degree - 1; + let num_chunks = ceil_div_usize(v.len() - 1, chunk_size) - 1; + for i in 0..num_chunks { + let mut chunk = v[1 + i * chunk_size..1 + (i + 1) * chunk_size].to_vec(); + chunk.push(acc); + acc = builder.mul_many_extension(&chunk); + + res.push(builder.sub_extension(acc, *partials.next().unwrap())); } + debug_assert!(partials.next().is_none()); res } @@ -97,15 +95,15 @@ mod tests { fn test_partial_products() { let v = vec![1, 2, 3, 4, 5, 6]; let p = partial_products(&v, 2); - assert_eq!(p, vec![2, 12, 30, 24, 30]); + assert_eq!(p, vec![2, 6, 24, 120]); let nums = num_partial_products(v.len(), 2); assert_eq!(p.len(), nums.0); assert!(check_partial_products(&v, &p, 2) .iter() .all(|x| x.is_zero())); assert_eq!( + *p.last().unwrap() * v[nums.1..].iter().copied().product::(), v.into_iter().product::(), - p[p.len() - nums.1..].iter().copied().product(), ); let v = vec![1, 2, 3, 4, 5, 6]; @@ -117,8 +115,8 @@ mod tests { .iter() .all(|x| x.is_zero())); assert_eq!( + *p.last().unwrap() * v[nums.1..].iter().copied().product::(), v.into_iter().product::(), - p[p.len() - nums.1..].iter().copied().product(), ); } } From bd1672cbf2ac822b24c6402a592ffba4eb0e14db Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 9 Nov 2021 13:56:19 +0100 Subject: [PATCH 002/202] Working --- src/plonk/circuit_builder.rs | 2 +- src/plonk/prover.rs | 46 +++++++++++++--- src/plonk/vanishing_poly.rs | 54 ++++++++++++------- src/util/partial_products.rs | 100 +++++++++++++++++------------------ 4 files changed, 126 insertions(+), 76 deletions(-) diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 5dcde1e0..87ea4f3a 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -777,7 +777,7 @@ impl, const D: usize> CircuitBuilder { .expect("No gates?"); let num_partial_products = - num_partial_products(self.config.num_routed_wires, quotient_degree_factor); + num_partial_products(self.config.num_routed_wires, quotient_degree_factor - 1); // TODO: This should also include an encoding of gate constraints. let circuit_digest_parts = [ diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 427880c3..b925e6de 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -17,7 +17,7 @@ use crate::plonk::vanishing_poly::eval_vanishing_poly_base_batch; use crate::plonk::vars::EvaluationVarsBase; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::timed; -use crate::util::partial_products::partial_products; +use crate::util::partial_products::{check_partial_products, partial_products}; use crate::util::timing::TimingTree; use crate::util::{log2_ceil, transpose}; @@ -63,6 +63,7 @@ pub(crate) fn prove, const D: usize>( .map(|column| PolynomialValues::new(column.clone())) .collect() ); + let wires = wires_values.iter().map(|v| v.values[0]).collect::>(); let wires_commitment = timed!( timing, @@ -108,6 +109,33 @@ pub(crate) fn prove, const D: usize>( partial_products.iter_mut().for_each(|part| { part.remove(0); }); + // let part = partial_products[0].clone(); + // let v = part.iter().map(|v| v.values[0]).collect::>(); + // dbg!(); + // let numerator_values = (0..common_data.config.num_routed_wires) + // .map(|j| { + // let wire_value = wires[j]; + // let k_i = common_data.k_is[j]; + // let s_id = k_i; + // wire_value + s_id * betas[0] + gammas[0] + // }) + // .collect::>(); + // let denominator_values = (0..common_data.config.num_routed_wires) + // .map(|j| { + // let wire_value = wires[j]; + // let s_sigma = s_sigmas[j]; + // wire_value + s_sigma * betas[0] + gammas[0] + // }) + // .collect::>(); + // let quotient_values = (0..common_data.config.num_routed_wires) + // .map(|j| numerator_values[j] / denominator_values[j]) + // .collect::>(); + // + // // // The partial products considered for this iteration of `i`. + // // let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods]; + // // Check the quotient partial products. + // let mut partial_product_check = check_partial_products("ient_values, &v, quotient_degree); + // dbg!(partial_product_check); let zs_partial_products = [plonk_z_vecs, partial_products.concat()].concat(); let zs_partial_products_commitment = timed!( @@ -238,7 +266,7 @@ fn wires_permutation_partial_products, const D: usi prover_data: &ProverOnlyCircuitData, common_data: &CommonCircuitData, ) -> Vec> { - let degree = common_data.quotient_degree_factor; + let degree = common_data.quotient_degree_factor - 1; let subgroup = &prover_data.subgroup; let k_is = &common_data.k_is; let values = subgroup @@ -266,12 +294,18 @@ fn wires_permutation_partial_products, const D: usi .collect::>(); let quotient_partials = partial_products("ient_values, degree); + dbg!(check_partial_products( + "ient_values, + "ient_partials, + degree + )); // This is the final product for the quotient. - let quotient = quotient_partials[common_data.num_partial_products.1..] - .iter() - .copied() - .product(); + let quotient = *quotient_partials.last().unwrap() + * quotient_values[common_data.num_partial_products.1..] + .iter() + .copied() + .product(); // We add the quotient at the beginning of the vector to reuse them later in the computation of `Z`. [vec![quotient], quotient_partials].concat() diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 28c6a287..9bc4feb5 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -27,7 +27,7 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( gammas: &[F], alphas: &[F], ) -> Vec { - let max_degree = common_data.quotient_degree_factor; + let max_degree = common_data.quotient_degree_factor - 1; let (num_prods, final_num_prod) = common_data.num_partial_products; let constraint_terms = @@ -73,20 +73,27 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( check_partial_products("ient_values, current_partial_products, max_degree); // The first checks are of the form `q - n/d` which is a rational function not a polynomial. // We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials. - denominator_values - .chunks(max_degree) - .zip(partial_product_check.iter_mut()) - .for_each(|(d, q)| { - *q *= d.iter().copied().product(); - }); + for (j, q) in partial_product_check.iter_mut().enumerate() { + let range = j * (max_degree - 1)..(j + 1) * (max_degree - 1); + *q *= denominator_values[range].iter().copied().product(); + } + // denominator_values + // .chunks(max_degree) + // .zip(partial_product_check.iter_mut()) + // .for_each(|(d, q)| { + // *q *= d.iter().copied().product(); + // }); vanishing_partial_products_terms.extend(partial_product_check); // The quotient final product is the product of the last `final_num_prod` elements. - let quotient: F::Extension = current_partial_products[num_prods - final_num_prod..] + let quotient: F::Extension = *current_partial_products.last().unwrap() + * quotient_values[final_num_prod..].iter().copied().product(); + let mut wanted = quotient * z_x - z_gz; + wanted *= denominator_values[final_num_prod..] .iter() .copied() .product(); - vanishing_v_shift_terms.push(quotient * z_x - z_gz); + vanishing_v_shift_terms.push(wanted); } let vanishing_terms = [ @@ -124,7 +131,7 @@ pub(crate) fn eval_vanishing_poly_base_batch, const assert_eq!(partial_products_batch.len(), n); assert_eq!(s_sigmas_batch.len(), n); - let max_degree = common_data.quotient_degree_factor; + let max_degree = common_data.quotient_degree_factor - 1; let (num_prods, final_num_prod) = common_data.num_partial_products; let num_gate_constraints = common_data.num_gate_constraints; @@ -189,20 +196,31 @@ pub(crate) fn eval_vanishing_poly_base_batch, const check_partial_products("ient_values, current_partial_products, max_degree); // The first checks are of the form `q - n/d` which is a rational function not a polynomial. // We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials. - denominator_values - .chunks(max_degree) - .zip(partial_product_check.iter_mut()) - .for_each(|(d, q)| { - *q *= d.iter().copied().product(); - }); + for (j, q) in partial_product_check.iter_mut().enumerate() { + let range = j * (max_degree - 1)..(j + 1) * (max_degree - 1); + *q *= denominator_values[range].iter().copied().product(); + } + // denominator_values + // .chunks(max_degree) + // .zip(partial_product_check.iter_mut()) + // .for_each(|(d, q)| { + // *q *= d.iter().copied().product(); + // }); vanishing_partial_products_terms.extend(partial_product_check); // The quotient final product is the product of the last `final_num_prod` elements. - let quotient: F = current_partial_products[num_prods - final_num_prod..] + let quotient: F = *current_partial_products.last().unwrap() + * quotient_values[final_num_prod..].iter().copied().product(); + // let quotient: F = current_partial_products[num_prods - final_num_prod..] + // .iter() + // .copied() + // .product(); + let mut wanted = quotient * z_x - z_gz; + wanted *= denominator_values[final_num_prod..] .iter() .copied() .product(); - vanishing_v_shift_terms.push(quotient * z_x - z_gz); + vanishing_v_shift_terms.push(wanted); numerator_values.clear(); denominator_values.clear(); diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index 83b0e396..1e361101 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -3,20 +3,20 @@ use std::ops::{MulAssign, Sub}; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; -use crate::field::field_types::RichField; +use crate::field::field_types::{Field, RichField}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::ceil_div_usize; /// Compute partial products of the original vector `v` such that all products consist of `max_degree` /// or less elements. This is done until we've computed the product `P` of all elements in the vector. -pub fn partial_products(v: &[T], max_degree: usize) -> Vec { +pub fn partial_products(v: &[F], max_degree: usize) -> Vec { debug_assert!(max_degree > 1); let mut res = Vec::new(); - let mut acc = v[0]; + let mut acc = F::ONE; let chunk_size = max_degree - 1; - let num_chunks = ceil_div_usize(v.len() - 1, chunk_size) - 1; + let num_chunks = ceil_div_usize(v.len(), chunk_size) - 1; for i in 0..num_chunks { - acc *= v[1 + i * chunk_size..1 + (i + 1) * chunk_size] + acc *= v[i * chunk_size..(i + 1) * chunk_size] .iter() .copied() .product(); @@ -31,30 +31,28 @@ pub fn partial_products(v: &[T], max_degree: usiz pub fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { debug_assert!(max_degree > 1); let chunk_size = max_degree - 1; - let num_chunks = ceil_div_usize(n - 1, chunk_size) - 1; + let num_chunks = ceil_div_usize(n, chunk_size) - 1; - (num_chunks, 1 + num_chunks * chunk_size) + (num_chunks, num_chunks * chunk_size) } /// Checks that the partial products of `v` are coherent with those in `partials` by only computing /// products of size `max_degree` or less. -pub fn check_partial_products>( - v: &[T], - mut partials: &[T], - max_degree: usize, -) -> Vec { +pub fn check_partial_products(v: &[F], mut partials: &[F], max_degree: usize) -> Vec { debug_assert!(max_degree > 1); let mut partials = partials.iter(); let mut res = Vec::new(); - let mut acc = v[0]; + let mut acc = F::ONE; let chunk_size = max_degree - 1; - let num_chunks = ceil_div_usize(v.len() - 1, chunk_size) - 1; + let num_chunks = ceil_div_usize(v.len(), chunk_size) - 1; for i in 0..num_chunks { - acc *= v[1 + i * chunk_size..1 + (i + 1) * chunk_size] + acc *= v[i * chunk_size..(i + 1) * chunk_size] .iter() .copied() .product(); - res.push(acc - *partials.next().unwrap()); + let bacc = *partials.next().unwrap(); + res.push(acc - bacc); + acc = bacc; } debug_assert!(partials.next().is_none()); @@ -85,38 +83,38 @@ pub fn check_partial_products_recursively, const D: res } -#[cfg(test)] -mod tests { - use num::Zero; - - use super::*; - - #[test] - fn test_partial_products() { - let v = vec![1, 2, 3, 4, 5, 6]; - let p = partial_products(&v, 2); - assert_eq!(p, vec![2, 6, 24, 120]); - let nums = num_partial_products(v.len(), 2); - assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &p, 2) - .iter() - .all(|x| x.is_zero())); - assert_eq!( - *p.last().unwrap() * v[nums.1..].iter().copied().product::(), - v.into_iter().product::(), - ); - - let v = vec![1, 2, 3, 4, 5, 6]; - let p = partial_products(&v, 3); - assert_eq!(p, vec![6, 120]); - let nums = num_partial_products(v.len(), 3); - assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &p, 3) - .iter() - .all(|x| x.is_zero())); - assert_eq!( - *p.last().unwrap() * v[nums.1..].iter().copied().product::(), - v.into_iter().product::(), - ); - } -} +// #[cfg(test)] +// mod tests { +// use num::Zero; +// +// use super::*; +// +// #[test] +// fn test_partial_products() { +// let v = vec![1, 2, 3, 4, 5, 6]; +// let p = partial_products(&v, 2); +// assert_eq!(p, vec![2, 6, 24, 120]); +// let nums = num_partial_products(v.len(), 2); +// assert_eq!(p.len(), nums.0); +// assert!(check_partial_products(&v, &p, 2) +// .iter() +// .all(|x| x.is_zero())); +// assert_eq!( +// *p.last().unwrap() * v[nums.1..].iter().copied().product::(), +// v.into_iter().product::(), +// ); +// +// let v = vec![1, 2, 3, 4, 5, 6]; +// let p = partial_products(&v, 3); +// assert_eq!(p, vec![6, 120]); +// let nums = num_partial_products(v.len(), 3); +// assert_eq!(p.len(), nums.0); +// assert!(check_partial_products(&v, &p, 3) +// .iter() +// .all(|x| x.is_zero())); +// assert_eq!( +// *p.last().unwrap() * v[nums.1..].iter().copied().product::(), +// v.into_iter().product::(), +// ); +// } +// } From 9617c221730bf960f30e640b115a43d5bd4d4c02 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 9 Nov 2021 14:24:04 +0100 Subject: [PATCH 003/202] Increase degree --- src/plonk/circuit_builder.rs | 2 +- src/plonk/prover.rs | 35 +---------------------------------- src/plonk/vanishing_poly.rs | 24 ++++-------------------- src/util/partial_products.rs | 6 +++--- 4 files changed, 9 insertions(+), 58 deletions(-) diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 87ea4f3a..5dcde1e0 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -777,7 +777,7 @@ impl, const D: usize> CircuitBuilder { .expect("No gates?"); let num_partial_products = - num_partial_products(self.config.num_routed_wires, quotient_degree_factor - 1); + num_partial_products(self.config.num_routed_wires, quotient_degree_factor); // TODO: This should also include an encoding of gate constraints. let circuit_digest_parts = [ diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index b925e6de..22f9411c 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -63,7 +63,6 @@ pub(crate) fn prove, const D: usize>( .map(|column| PolynomialValues::new(column.clone())) .collect() ); - let wires = wires_values.iter().map(|v| v.values[0]).collect::>(); let wires_commitment = timed!( timing, @@ -109,33 +108,6 @@ pub(crate) fn prove, const D: usize>( partial_products.iter_mut().for_each(|part| { part.remove(0); }); - // let part = partial_products[0].clone(); - // let v = part.iter().map(|v| v.values[0]).collect::>(); - // dbg!(); - // let numerator_values = (0..common_data.config.num_routed_wires) - // .map(|j| { - // let wire_value = wires[j]; - // let k_i = common_data.k_is[j]; - // let s_id = k_i; - // wire_value + s_id * betas[0] + gammas[0] - // }) - // .collect::>(); - // let denominator_values = (0..common_data.config.num_routed_wires) - // .map(|j| { - // let wire_value = wires[j]; - // let s_sigma = s_sigmas[j]; - // wire_value + s_sigma * betas[0] + gammas[0] - // }) - // .collect::>(); - // let quotient_values = (0..common_data.config.num_routed_wires) - // .map(|j| numerator_values[j] / denominator_values[j]) - // .collect::>(); - // - // // // The partial products considered for this iteration of `i`. - // // let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods]; - // // Check the quotient partial products. - // let mut partial_product_check = check_partial_products("ient_values, &v, quotient_degree); - // dbg!(partial_product_check); let zs_partial_products = [plonk_z_vecs, partial_products.concat()].concat(); let zs_partial_products_commitment = timed!( @@ -266,7 +238,7 @@ fn wires_permutation_partial_products, const D: usi prover_data: &ProverOnlyCircuitData, common_data: &CommonCircuitData, ) -> Vec> { - let degree = common_data.quotient_degree_factor - 1; + let degree = common_data.quotient_degree_factor; let subgroup = &prover_data.subgroup; let k_is = &common_data.k_is; let values = subgroup @@ -294,11 +266,6 @@ fn wires_permutation_partial_products, const D: usi .collect::>(); let quotient_partials = partial_products("ient_values, degree); - dbg!(check_partial_products( - "ient_values, - "ient_partials, - degree - )); // This is the final product for the quotient. let quotient = *quotient_partials.last().unwrap() diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 9bc4feb5..b9b4d241 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -27,7 +27,7 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( gammas: &[F], alphas: &[F], ) -> Vec { - let max_degree = common_data.quotient_degree_factor - 1; + let max_degree = common_data.quotient_degree_factor; let (num_prods, final_num_prod) = common_data.num_partial_products; let constraint_terms = @@ -74,15 +74,9 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( // The first checks are of the form `q - n/d` which is a rational function not a polynomial. // We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials. for (j, q) in partial_product_check.iter_mut().enumerate() { - let range = j * (max_degree - 1)..(j + 1) * (max_degree - 1); + let range = j * max_degree..(j + 1) * max_degree; *q *= denominator_values[range].iter().copied().product(); } - // denominator_values - // .chunks(max_degree) - // .zip(partial_product_check.iter_mut()) - // .for_each(|(d, q)| { - // *q *= d.iter().copied().product(); - // }); vanishing_partial_products_terms.extend(partial_product_check); // The quotient final product is the product of the last `final_num_prod` elements. @@ -131,7 +125,7 @@ pub(crate) fn eval_vanishing_poly_base_batch, const assert_eq!(partial_products_batch.len(), n); assert_eq!(s_sigmas_batch.len(), n); - let max_degree = common_data.quotient_degree_factor - 1; + let max_degree = common_data.quotient_degree_factor; let (num_prods, final_num_prod) = common_data.num_partial_products; let num_gate_constraints = common_data.num_gate_constraints; @@ -197,24 +191,14 @@ pub(crate) fn eval_vanishing_poly_base_batch, const // The first checks are of the form `q - n/d` which is a rational function not a polynomial. // We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials. for (j, q) in partial_product_check.iter_mut().enumerate() { - let range = j * (max_degree - 1)..(j + 1) * (max_degree - 1); + let range = j * max_degree..(j + 1) * max_degree; *q *= denominator_values[range].iter().copied().product(); } - // denominator_values - // .chunks(max_degree) - // .zip(partial_product_check.iter_mut()) - // .for_each(|(d, q)| { - // *q *= d.iter().copied().product(); - // }); vanishing_partial_products_terms.extend(partial_product_check); // The quotient final product is the product of the last `final_num_prod` elements. let quotient: F = *current_partial_products.last().unwrap() * quotient_values[final_num_prod..].iter().copied().product(); - // let quotient: F = current_partial_products[num_prods - final_num_prod..] - // .iter() - // .copied() - // .product(); let mut wanted = quotient * z_x - z_gz; wanted *= denominator_values[final_num_prod..] .iter() diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index 1e361101..bc5fce45 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -13,7 +13,7 @@ pub fn partial_products(v: &[F], max_degree: usize) -> Vec { debug_assert!(max_degree > 1); let mut res = Vec::new(); let mut acc = F::ONE; - let chunk_size = max_degree - 1; + let chunk_size = max_degree; let num_chunks = ceil_div_usize(v.len(), chunk_size) - 1; for i in 0..num_chunks { acc *= v[i * chunk_size..(i + 1) * chunk_size] @@ -30,7 +30,7 @@ pub fn partial_products(v: &[F], max_degree: usize) -> Vec { /// vector of length `n`, and `b` is the number of elements needed to compute the final product. pub fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { debug_assert!(max_degree > 1); - let chunk_size = max_degree - 1; + let chunk_size = max_degree; let num_chunks = ceil_div_usize(n, chunk_size) - 1; (num_chunks, num_chunks * chunk_size) @@ -43,7 +43,7 @@ pub fn check_partial_products(v: &[F], mut partials: &[F], max_degree: let mut partials = partials.iter(); let mut res = Vec::new(); let mut acc = F::ONE; - let chunk_size = max_degree - 1; + let chunk_size = max_degree; let num_chunks = ceil_div_usize(v.len(), chunk_size) - 1; for i in 0..num_chunks { acc *= v[i * chunk_size..(i + 1) * chunk_size] From 7cf965ded511011c39f0126a69e57d308f39cdec Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 9 Nov 2021 15:18:43 +0100 Subject: [PATCH 004/202] All tests pass --- src/plonk/vanishing_poly.rs | 37 +++++++++++++++++++++++++++--------- src/util/partial_products.rs | 25 ++++++++++++------------ 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index b9b4d241..7b9d0c6c 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -396,20 +396,39 @@ pub(crate) fn eval_vanishing_poly_recursively, cons ); // The first checks are of the form `q - n/d` which is a rational function not a polynomial. // We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials. - denominator_values - .chunks(max_degree) - .zip(partial_product_check.iter_mut()) - .for_each(|(d, q)| { - let mut v = d.to_vec(); + // denominator_values + // .chunks(max_degree) + // .zip(partial_product_check.iter_mut()) + // .for_each(|(d, q)| { + // let mut v = d.to_vec(); + // v.push(*q); + // *q = builder.mul_many_extension(&v); + // }); + for (j, q) in partial_product_check.iter_mut().enumerate() { + let range = j * max_degree..(j + 1) * max_degree; + *q = builder.mul_many_extension(&{ + let mut v = denominator_values[range].to_vec(); v.push(*q); - *q = builder.mul_many_extension(&v); + v }); + } vanishing_partial_products_terms.extend(partial_product_check); // The quotient final product is the product of the last `final_num_prod` elements. - let quotient = - builder.mul_many_extension(¤t_partial_products[num_prods - final_num_prod..]); - vanishing_v_shift_terms.push(builder.mul_sub_extension(quotient, z_x, z_gz)); + // let quotient = + // builder.mul_many_extension(¤t_partial_products[num_prods - final_num_prod..]); + let quotient = builder.mul_many_extension(&{ + let mut v = quotient_values[final_num_prod..].to_vec(); + v.push(*current_partial_products.last().unwrap()); + v + }); + let mut wanted = builder.mul_sub_extension(quotient, z_x, z_gz); + wanted = builder.mul_many_extension(&{ + let mut v = denominator_values[final_num_prod..].to_vec(); + v.push(wanted); + v + }); + vanishing_v_shift_terms.push(wanted); } let vanishing_terms = [ diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index bc5fce45..398ec35f 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -14,7 +14,7 @@ pub fn partial_products(v: &[F], max_degree: usize) -> Vec { let mut res = Vec::new(); let mut acc = F::ONE; let chunk_size = max_degree; - let num_chunks = ceil_div_usize(v.len(), chunk_size) - 1; + let num_chunks = v.len() / chunk_size; for i in 0..num_chunks { acc *= v[i * chunk_size..(i + 1) * chunk_size] .iter() @@ -31,7 +31,7 @@ pub fn partial_products(v: &[F], max_degree: usize) -> Vec { pub fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { debug_assert!(max_degree > 1); let chunk_size = max_degree; - let num_chunks = ceil_div_usize(n, chunk_size) - 1; + let num_chunks = n / chunk_size; (num_chunks, num_chunks * chunk_size) } @@ -44,15 +44,15 @@ pub fn check_partial_products(v: &[F], mut partials: &[F], max_degree: let mut res = Vec::new(); let mut acc = F::ONE; let chunk_size = max_degree; - let num_chunks = ceil_div_usize(v.len(), chunk_size) - 1; + let num_chunks = v.len() / chunk_size; for i in 0..num_chunks { acc *= v[i * chunk_size..(i + 1) * chunk_size] .iter() .copied() .product(); - let bacc = *partials.next().unwrap(); - res.push(acc - bacc); - acc = bacc; + let new_acc = *partials.next().unwrap(); + res.push(acc - new_acc); + acc = new_acc; } debug_assert!(partials.next().is_none()); @@ -68,15 +68,16 @@ pub fn check_partial_products_recursively, const D: debug_assert!(max_degree > 1); let mut partials = partials.iter(); let mut res = Vec::new(); - let mut acc = v[0]; - let chunk_size = max_degree - 1; - let num_chunks = ceil_div_usize(v.len() - 1, chunk_size) - 1; + let mut acc = builder.one_extension(); + let chunk_size = max_degree; + let num_chunks = v.len() / chunk_size; for i in 0..num_chunks { - let mut chunk = v[1 + i * chunk_size..1 + (i + 1) * chunk_size].to_vec(); + let mut chunk = v[i * chunk_size..(i + 1) * chunk_size].to_vec(); chunk.push(acc); acc = builder.mul_many_extension(&chunk); - - res.push(builder.sub_extension(acc, *partials.next().unwrap())); + let new_acc = *partials.next().unwrap(); + res.push(builder.sub_extension(acc, new_acc)); + acc = new_acc; } debug_assert!(partials.next().is_none()); From abc706ee26a94fe95767703950612d89eb8e8a13 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 9 Nov 2021 17:18:15 +0100 Subject: [PATCH 005/202] Fix partial product test --- src/util/partial_products.rs | 93 ++++++++++++++++++++++-------------- 1 file changed, 57 insertions(+), 36 deletions(-) diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index 398ec35f..38177de8 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -27,7 +27,7 @@ pub fn partial_products(v: &[F], max_degree: usize) -> Vec { } /// Returns a tuple `(a,b)`, where `a` is the length of the output of `partial_products()` on a -/// vector of length `n`, and `b` is the number of elements needed to compute the final product. +/// vector of length `n`, and `b` is the number of original elements consumed in `partial_products()`. pub fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { debug_assert!(max_degree > 1); let chunk_size = max_degree; @@ -84,38 +84,59 @@ pub fn check_partial_products_recursively, const D: res } -// #[cfg(test)] -// mod tests { -// use num::Zero; -// -// use super::*; -// -// #[test] -// fn test_partial_products() { -// let v = vec![1, 2, 3, 4, 5, 6]; -// let p = partial_products(&v, 2); -// assert_eq!(p, vec![2, 6, 24, 120]); -// let nums = num_partial_products(v.len(), 2); -// assert_eq!(p.len(), nums.0); -// assert!(check_partial_products(&v, &p, 2) -// .iter() -// .all(|x| x.is_zero())); -// assert_eq!( -// *p.last().unwrap() * v[nums.1..].iter().copied().product::(), -// v.into_iter().product::(), -// ); -// -// let v = vec![1, 2, 3, 4, 5, 6]; -// let p = partial_products(&v, 3); -// assert_eq!(p, vec![6, 120]); -// let nums = num_partial_products(v.len(), 3); -// assert_eq!(p.len(), nums.0); -// assert!(check_partial_products(&v, &p, 3) -// .iter() -// .all(|x| x.is_zero())); -// assert_eq!( -// *p.last().unwrap() * v[nums.1..].iter().copied().product::(), -// v.into_iter().product::(), -// ); -// } -// } +#[cfg(test)] +mod tests { + use num::Zero; + + use super::*; + use crate::field::goldilocks_field::GoldilocksField; + + #[test] + fn test_partial_products() { + type F = GoldilocksField; + let v = [1, 2, 3, 4, 5, 6] + .into_iter() + .map(|&i| F::from_canonical_u64(i)) + .collect::>(); + let p = partial_products(&v, 2); + assert_eq!( + p, + [2, 24, 720] + .into_iter() + .map(|&i| F::from_canonical_u64(i)) + .collect::>() + ); + + let nums = num_partial_products(v.len(), 2); + assert_eq!(p.len(), nums.0); + assert!(check_partial_products(&v, &p, 2) + .iter() + .all(|x| x.is_zero())); + assert_eq!( + *p.last().unwrap() * v[nums.1..].iter().copied().product::(), + v.into_iter().product::(), + ); + + let v = [1, 2, 3, 4, 5, 6] + .into_iter() + .map(|&i| F::from_canonical_u64(i)) + .collect::>(); + let p = partial_products(&v, 3); + assert_eq!( + p, + [6, 720] + .into_iter() + .map(|&i| F::from_canonical_u64(i)) + .collect::>() + ); + let nums = num_partial_products(v.len(), 3); + assert_eq!(p.len(), nums.0); + assert!(check_partial_products(&v, &p, 3) + .iter() + .all(|x| x.is_zero())); + assert_eq!( + *p.last().unwrap() * v[nums.1..].iter().copied().product::(), + v.into_iter().product::(), + ); + } +} From 067f81e24f8b1bc15aab8ab2b75494400fcd338e Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 9 Nov 2021 17:25:22 +0100 Subject: [PATCH 006/202] Comments and cleaning --- src/plonk/vanishing_poly.rs | 74 ++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 42 deletions(-) diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 7b9d0c6c..4976eaba 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -69,25 +69,25 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( // The partial products considered for this iteration of `i`. let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods]; // Check the quotient partial products. - let mut partial_product_check = + let mut partial_product_checks = check_partial_products("ient_values, current_partial_products, max_degree); - // The first checks are of the form `q - n/d` which is a rational function not a polynomial. - // We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials. - for (j, q) in partial_product_check.iter_mut().enumerate() { + // The partial products are products of quotients, so we multiply them by the product of the + // corresponding denominators to make sure they are polynomials. + for (j, partial_product_check) in partial_product_checks.iter_mut().enumerate() { let range = j * max_degree..(j + 1) * max_degree; - *q *= denominator_values[range].iter().copied().product(); + *partial_product_check *= denominator_values[range].iter().copied().product(); } - vanishing_partial_products_terms.extend(partial_product_check); + vanishing_partial_products_terms.extend(partial_product_checks); - // The quotient final product is the product of the last `final_num_prod` elements. let quotient: F::Extension = *current_partial_products.last().unwrap() * quotient_values[final_num_prod..].iter().copied().product(); - let mut wanted = quotient * z_x - z_gz; - wanted *= denominator_values[final_num_prod..] + let mut v_shift_term = quotient * z_x - z_gz; + // Need to multiply by the denominators to make sure we get a polynomial. + v_shift_term *= denominator_values[final_num_prod..] .iter() .copied() .product(); - vanishing_v_shift_terms.push(wanted); + vanishing_v_shift_terms.push(v_shift_term); } let vanishing_terms = [ @@ -186,25 +186,25 @@ pub(crate) fn eval_vanishing_poly_base_batch, const // The partial products considered for this iteration of `i`. let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods]; // Check the numerator partial products. - let mut partial_product_check = + let mut partial_product_checks = check_partial_products("ient_values, current_partial_products, max_degree); - // The first checks are of the form `q - n/d` which is a rational function not a polynomial. - // We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials. - for (j, q) in partial_product_check.iter_mut().enumerate() { + // The partial products are products of quotients, so we multiply them by the product of the + // corresponding denominators to make sure they are polynomials. + for (j, partial_product_check) in partial_product_checks.iter_mut().enumerate() { let range = j * max_degree..(j + 1) * max_degree; - *q *= denominator_values[range].iter().copied().product(); + *partial_product_check *= denominator_values[range].iter().copied().product(); } - vanishing_partial_products_terms.extend(partial_product_check); + vanishing_partial_products_terms.extend(partial_product_checks); - // The quotient final product is the product of the last `final_num_prod` elements. let quotient: F = *current_partial_products.last().unwrap() * quotient_values[final_num_prod..].iter().copied().product(); - let mut wanted = quotient * z_x - z_gz; - wanted *= denominator_values[final_num_prod..] + let mut v_shift_term = quotient * z_x - z_gz; + // Need to multiply by the denominators to make sure we get a polynomial. + v_shift_term *= denominator_values[final_num_prod..] .iter() .copied() .product(); - vanishing_v_shift_terms.push(wanted); + vanishing_v_shift_terms.push(v_shift_term); numerator_values.clear(); denominator_values.clear(); @@ -388,47 +388,37 @@ pub(crate) fn eval_vanishing_poly_recursively, cons // The partial products considered for this iteration of `i`. let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods]; // Check the quotient partial products. - let mut partial_product_check = check_partial_products_recursively( + let mut partial_product_checks = check_partial_products_recursively( builder, "ient_values, current_partial_products, max_degree, ); - // The first checks are of the form `q - n/d` which is a rational function not a polynomial. - // We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials. - // denominator_values - // .chunks(max_degree) - // .zip(partial_product_check.iter_mut()) - // .for_each(|(d, q)| { - // let mut v = d.to_vec(); - // v.push(*q); - // *q = builder.mul_many_extension(&v); - // }); - for (j, q) in partial_product_check.iter_mut().enumerate() { + // The partial products are products of quotients, so we multiply them by the product of the + // corresponding denominators to make sure they are polynomials. + for (j, partial_product_check) in partial_product_checks.iter_mut().enumerate() { let range = j * max_degree..(j + 1) * max_degree; - *q = builder.mul_many_extension(&{ + *partial_product_check = builder.mul_many_extension(&{ let mut v = denominator_values[range].to_vec(); - v.push(*q); + v.push(*partial_product_check); v }); } - vanishing_partial_products_terms.extend(partial_product_check); + vanishing_partial_products_terms.extend(partial_product_checks); - // The quotient final product is the product of the last `final_num_prod` elements. - // let quotient = - // builder.mul_many_extension(¤t_partial_products[num_prods - final_num_prod..]); let quotient = builder.mul_many_extension(&{ let mut v = quotient_values[final_num_prod..].to_vec(); v.push(*current_partial_products.last().unwrap()); v }); - let mut wanted = builder.mul_sub_extension(quotient, z_x, z_gz); - wanted = builder.mul_many_extension(&{ + let mut v_shift_term = builder.mul_sub_extension(quotient, z_x, z_gz); + // Need to multiply by the denominators to make sure we get a polynomial. + v_shift_term = builder.mul_many_extension(&{ let mut v = denominator_values[final_num_prod..].to_vec(); - v.push(wanted); + v.push(v_shift_term); v }); - vanishing_v_shift_terms.push(wanted); + vanishing_v_shift_terms.push(v_shift_term); } let vanishing_terms = [ From 3717ff701e9fd0540c10cba9c806e4549de9bce7 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 9 Nov 2021 17:33:14 +0100 Subject: [PATCH 007/202] Minor --- src/plonk/circuit_data.rs | 4 ++-- src/plonk/prover.rs | 2 +- src/util/partial_products.rs | 8 +------- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 9ba05a87..869543af 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -190,8 +190,8 @@ pub struct CommonCircuitData, const D: usize> { /// The `{k_i}` valued used in `S_ID_i` in Plonk's permutation argument. pub(crate) k_is: Vec, - /// The number of partial products needed to compute the `Z` polynomials and the number - /// of partial products needed to compute the final product. + /// The number of partial products needed to compute the `Z` polynomials and + /// the number of original elements consumed in `partial_products()`. pub(crate) num_partial_products: (usize, usize), /// A digest of the "circuit" (i.e. the instance, minus public inputs), which can be used to diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 22f9411c..6c57217c 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -17,7 +17,7 @@ use crate::plonk::vanishing_poly::eval_vanishing_poly_base_batch; use crate::plonk::vars::EvaluationVarsBase; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::timed; -use crate::util::partial_products::{check_partial_products, partial_products}; +use crate::util::partial_products::partial_products; use crate::util::timing::TimingTree; use crate::util::{log2_ceil, transpose}; diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index 38177de8..1b3821b6 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -1,11 +1,7 @@ -use std::iter::Product; -use std::ops::{MulAssign, Sub}; - use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::util::ceil_div_usize; /// Compute partial products of the original vector `v` such that all products consist of `max_degree` /// or less elements. This is done until we've computed the product `P` of all elements in the vector. @@ -38,7 +34,7 @@ pub fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { /// Checks that the partial products of `v` are coherent with those in `partials` by only computing /// products of size `max_degree` or less. -pub fn check_partial_products(v: &[F], mut partials: &[F], max_degree: usize) -> Vec { +pub fn check_partial_products(v: &[F], partials: &[F], max_degree: usize) -> Vec { debug_assert!(max_degree > 1); let mut partials = partials.iter(); let mut res = Vec::new(); @@ -86,8 +82,6 @@ pub fn check_partial_products_recursively, const D: #[cfg(test)] mod tests { - use num::Zero; - use super::*; use crate::field::goldilocks_field::GoldilocksField; From 168f572804a45528424f63510fe157b2d16db282 Mon Sep 17 00:00:00 2001 From: Jakub Nabaglo Date: Tue, 9 Nov 2021 14:52:05 -0800 Subject: [PATCH 008/202] Fix rustfmt failures on main (#348) --- benches/hashing.rs | 2 +- src/field/packed_avx2/mod.rs | 14 +++++++------- src/gadgets/hash.rs | 6 +++--- src/gates/poseidon.rs | 10 +++++----- src/gates/poseidon_mds.rs | 10 +++++----- src/hash/poseidon.rs | 6 +++--- 6 files changed, 24 insertions(+), 24 deletions(-) diff --git a/benches/hashing.rs b/benches/hashing.rs index 5669e50b..c229972e 100644 --- a/benches/hashing.rs +++ b/benches/hashing.rs @@ -19,7 +19,7 @@ pub(crate) fn bench_gmimc, const WIDTH: usize>(c: &mut Criterion pub(crate) fn bench_poseidon, const WIDTH: usize>(c: &mut Criterion) where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { c.bench_function(&format!("poseidon<{}, {}>", type_name::(), WIDTH), |b| { b.iter_batched( diff --git a/src/field/packed_avx2/mod.rs b/src/field/packed_avx2/mod.rs index eddbb5c9..20eecba7 100644 --- a/src/field/packed_avx2/mod.rs +++ b/src/field/packed_avx2/mod.rs @@ -34,7 +34,7 @@ mod tests { fn test_add() where - [(); PackedPrimeField::::WIDTH]: , + [(); PackedPrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); @@ -52,7 +52,7 @@ mod tests { fn test_mul() where - [(); PackedPrimeField::::WIDTH]: , + [(); PackedPrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); @@ -70,7 +70,7 @@ mod tests { fn test_square() where - [(); PackedPrimeField::::WIDTH]: , + [(); PackedPrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); @@ -86,7 +86,7 @@ mod tests { fn test_neg() where - [(); PackedPrimeField::::WIDTH]: , + [(); PackedPrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); @@ -102,7 +102,7 @@ mod tests { fn test_sub() where - [(); PackedPrimeField::::WIDTH]: , + [(); PackedPrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); @@ -120,7 +120,7 @@ mod tests { fn test_interleave_is_involution() where - [(); PackedPrimeField::::WIDTH]: , + [(); PackedPrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); @@ -144,7 +144,7 @@ mod tests { fn test_interleave() where - [(); PackedPrimeField::::WIDTH]: , + [(); PackedPrimeField::::WIDTH]:, { let in_a: [F; 4] = [ F::from_noncanonical_u64(00), diff --git a/src/gadgets/hash.rs b/src/gadgets/hash.rs index 99da9e1e..db4cb1e8 100644 --- a/src/gadgets/hash.rs +++ b/src/gadgets/hash.rs @@ -15,7 +15,7 @@ impl, const D: usize> CircuitBuilder { pub fn permute(&mut self, inputs: [Target; W]) -> [Target; W] where F: GMiMC + Poseidon, - [(); W - 1]: , + [(); W - 1]:, { // We don't want to swap any inputs, so set that wire to 0. let _false = self._false(); @@ -31,7 +31,7 @@ impl, const D: usize> CircuitBuilder { ) -> [Target; W] where F: GMiMC + Poseidon, - [(); W - 1]: , + [(); W - 1]:, { match HASH_FAMILY { HashFamily::GMiMC => self.gmimc_permute_swapped(inputs, swap), @@ -88,7 +88,7 @@ impl, const D: usize> CircuitBuilder { ) -> [Target; W] where F: Poseidon, - [(); W - 1]: , + [(); W - 1]:, { let gate_type = PoseidonGate::::new(); let gate = self.add_gate(gate_type, vec![]); diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 1f5f746d..6e1eb69a 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -26,7 +26,7 @@ pub struct PoseidonGate< const D: usize, const WIDTH: usize, > where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { _phantom: PhantomData, } @@ -34,7 +34,7 @@ pub struct PoseidonGate< impl + Poseidon, const D: usize, const WIDTH: usize> PoseidonGate where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { pub fn new() -> Self { PoseidonGate { @@ -91,7 +91,7 @@ where impl + Poseidon, const D: usize, const WIDTH: usize> Gate for PoseidonGate where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { fn id(&self) -> String { format!("{:?}", self, WIDTH) @@ -396,7 +396,7 @@ struct PoseidonGenerator< const D: usize, const WIDTH: usize, > where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { gate_index: usize, _phantom: PhantomData, @@ -405,7 +405,7 @@ struct PoseidonGenerator< impl + Poseidon, const D: usize, const WIDTH: usize> SimpleGenerator for PoseidonGenerator where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { fn dependencies(&self) -> Vec { (0..WIDTH) diff --git a/src/gates/poseidon_mds.rs b/src/gates/poseidon_mds.rs index 8a42b588..a127df68 100644 --- a/src/gates/poseidon_mds.rs +++ b/src/gates/poseidon_mds.rs @@ -21,7 +21,7 @@ pub struct PoseidonMdsGate< const D: usize, const WIDTH: usize, > where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { _phantom: PhantomData, } @@ -29,7 +29,7 @@ pub struct PoseidonMdsGate< impl + Poseidon, const D: usize, const WIDTH: usize> PoseidonMdsGate where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { pub fn new() -> Self { PoseidonMdsGate { @@ -116,7 +116,7 @@ where impl + Poseidon, const D: usize, const WIDTH: usize> Gate for PoseidonMdsGate where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { fn id(&self) -> String { format!("{:?}", self, WIDTH) @@ -207,7 +207,7 @@ where #[derive(Clone, Debug)] struct PoseidonMdsGenerator where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { gate_index: usize, } @@ -215,7 +215,7 @@ where impl + Poseidon, const D: usize, const WIDTH: usize> SimpleGenerator for PoseidonMdsGenerator where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { fn dependencies(&self) -> Vec { (0..WIDTH) diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index 9a52060c..9e4dd7f4 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -147,7 +147,7 @@ pub const ALL_ROUND_CONSTANTS: [u64; MAX_WIDTH * N_ROUNDS] = [ pub trait Poseidon: PrimeField where // magic to get const generic expressions to work - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { // Total number of round constants required: width of the input // times number of rounds. @@ -634,7 +634,7 @@ pub(crate) mod test_helpers { test_vectors: Vec<([u64; WIDTH], [u64; WIDTH])>, ) where F: Poseidon, - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { for (input_, expected_output_) in test_vectors.into_iter() { let mut input = [F::ZERO; WIDTH]; @@ -652,7 +652,7 @@ pub(crate) mod test_helpers { pub(crate) fn check_consistency() where F: Poseidon, - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { let mut input = [F::ZERO; WIDTH]; for i in 0..WIDTH { From 9711127599e2678771c6aca156d3c7667fd6a348 Mon Sep 17 00:00:00 2001 From: Jakub Nabaglo Date: Tue, 9 Nov 2021 15:14:41 -0800 Subject: [PATCH 009/202] Use Jemalloc (#347) --- Cargo.toml | 3 +++ src/lib.rs | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index 71970b41..4c182fa2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,9 @@ serde = { version = "1.0", features = ["derive"] } serde_cbor = "0.11.1" static_assertions = "1.1.0" +[target.'cfg(not(target_env = "msvc"))'.dependencies] +jemallocator = "0.3.2" + [dev-dependencies] criterion = "0.3.5" tynm = "0.1.6" diff --git a/src/lib.rs b/src/lib.rs index c72f783c..3ed9f747 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,3 +13,11 @@ pub mod iop; pub mod plonk; pub mod polynomial; pub mod util; + +// Set up Jemalloc +#[cfg(not(target_env = "msvc"))] +use jemallocator::Jemalloc; + +#[cfg(not(target_env = "msvc"))] +#[global_allocator] +static GLOBAL: Jemalloc = Jemalloc; From 7054fcdaf9a51787e5b4696cf90e5a3b7607cc7c Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 1 Oct 2021 14:20:21 -0700 Subject: [PATCH 010/202] initial --- src/gadgets/mod.rs | 1 + src/gadgets/nonnative.rs | 81 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+) create mode 100644 src/gadgets/nonnative.rs diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index aa18fbeb..5ec494c7 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -3,6 +3,7 @@ pub mod arithmetic_extension; pub mod hash; pub mod insert; pub mod interpolation; +pub mod nonnative; pub mod permutation; pub mod polynomial; pub mod random_access; diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs new file mode 100644 index 00000000..6763a1a5 --- /dev/null +++ b/src/gadgets/nonnative.rs @@ -0,0 +1,81 @@ +use std::collections::BTreeMap; +use std::marker::PhantomData; + +use crate::field::field_types::RichField; +use crate::field::{extension_field::Extendable, field_types::Field}; +use crate::gates::switch::SwitchGate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::util::bimap::bimap_from_lists; + +pub struct U32Target(Target); + +pub struct NonNativeTarget { + /// The modulus of the field F' being represented. + modulus: BigUInt, + /// These F elements are assumed to contain 32-bit values. + limbs: Vec, +} + +impl, const D: usize> CircuitBuilder { + pub fn add_mul_u32(&mut self, x: U32Target, y: U32Target, z:U32Target) -> (U32Target, U32Target) { + + } + + pub fn add_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { + + } + + pub fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { + + } + + pub fn add_nonnative(&mut self, a: NonNativeTarget, b: NonNativeTarget) -> NonNativeTarget { + let modulus = a.modulus; + let num_limbs = a.limbs.len(); + debug_assert!(b.modulus == modulus); + debug_assert!(b.limbs.len() == num_limbs); + + let mut combined_limbs = self.add_virtual_targets(num_limbs + 1); + let mut carry = self.zero(); + for i in 0..num_limbs { + let gate = ComparisonGate::new(bits, num_chunks); + let gate_index = self.add_gate(gate.clone(), vec![]); + + self.connect(Target::wire(gate_index, gate.wire_first_input()), lhs); + self.connect(Target::wire(gate_index, gate.wire_second_input()), rhs); + } + } + + pub fn reduce_add_result(&mut self, limbs: Vec, modulus: BigUInt) -> Vec { + todo!() + } + + pub fn mul_nonnative(&mut self, a: NonNativeTarget, b: NonNativeTarget) -> NonNativeTarget { + let modulus = a.modulus; + let num_limbs = a.limbs.len(); + debug_assert!(b.modulus == modulus); + debug_assert!(b.limbs.len() == num_limbs); + + let mut combined_limbs = self.add_virtual_targets(2 * num_limbs - 1); + for i in 0..num_limbs { + for j in 0..num_limbs { + let sum = builder.add(a.limbs[i], b.limbs[j]); + combined_limbs[i + j] = builder.add(combined_limbs[i + j], sum); + } + } + + let reduced_limbs = self.reduce(combined_limbs, modulus); + + NonNativeTarget { + modulus, + limbs: reduced_limbs, + } + } + + pub fn reduce_mul_result(&mut self, limbs: Vec, modulus: BigUInt) -> Vec { + todo!() + } +} From d334a924b4f6e793c47b0a7c30f4d7cd6bbb0e8c Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 9 Nov 2021 18:10:47 -0800 Subject: [PATCH 011/202] merge new circuit builder stuff --- src/gadgets/arithmetic_u32.rs | 67 +++++++++++++++++++++++++++++++++++ src/gadgets/mod.rs | 1 + src/gadgets/nonnative.rs | 24 +++---------- src/gates/arithmetic_u32.rs | 2 +- src/plonk/circuit_builder.rs | 4 +++ 5 files changed, 77 insertions(+), 21 deletions(-) create mode 100644 src/gadgets/arithmetic_u32.rs diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs new file mode 100644 index 00000000..4a8cff42 --- /dev/null +++ b/src/gadgets/arithmetic_u32.rs @@ -0,0 +1,67 @@ +use std::collections::BTreeMap; +use std::marker::PhantomData; + +use crate::field::field_types::RichField; +use crate::field::{extension_field::Extendable, field_types::Field}; +use crate::gates::arithmetic_u32::U32ArithmeticGate; +use crate::gates::switch::SwitchGate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::util::bimap::bimap_from_lists; + +pub struct U32Target(Target); + +impl, const D: usize> CircuitBuilder { + pub fn add_virtual_u32_target(&self) -> U32Target { + U32Target(self.add_virtual_target()) + } + + pub fn zero_u32(&self) -> U32Target { + U32Target(self.zero()) + } + + pub fn one_u32(&self) -> U32Target { + U32Target(self.one()) + } + + pub fn add_mul_u32( + &mut self, + x: U32Target, + y: U32Target, + z: U32Target, + ) -> (U32Target, U32Target) { + let (gate_index, copy) = match self.current_u32_arithmetic_gate { + None => { + let gate = U32ArithmeticGate { + _phantom: PhantomData, + }; + let gate_index = self.add_gate(gate.clone(), vec![]); + (gate_index, 0) + }, + Some((gate_index, copy) => (gate_index, copy), + }; + + let output_low = self.add_virtual_u32_target(); + let output_high = self.add_virtual_u32_target(); + + self.connect(Target::wire(gate_index, gate.wire_ith_multiplicand_0(copy)), x); + self.connect(Target::wire(gate_index, gate.wire_ith_multiplicand_1(copy)), y); + self.connect(Target::wire(gate_index, gate.wire_ith_addend(copy)), z); + self.connect(Target::wire(gate_index, gate.wire_ith_output_low_half(copy)), output_low); + self.connect(Target::wire(gate_index, gate.wire_ith_output_high_half(copy)), output_high); + + self.current_u32_arithmetic_gate = Some((gate_index, 0)); + + (output_low, output_high) + } + + pub fn add_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { + self.add_mul_u32(a, self.one_u32(), b) + } + + pub fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { + self.add_mul_u32(a, b, self.zero_u32()) + } +} diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index 5ec494c7..e38646e3 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -1,5 +1,6 @@ pub mod arithmetic; pub mod arithmetic_extension; +pub mod arithmetic_u32; pub mod hash; pub mod insert; pub mod interpolation; diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 6763a1a5..1407360e 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -3,6 +3,8 @@ use std::marker::PhantomData; use crate::field::field_types::RichField; use crate::field::{extension_field::Extendable, field_types::Field}; +use crate::gadgets::arithmetic_u32::U32Target; +use crate::gates::arithmetic_u32::U32ArithmeticGate; use crate::gates::switch::SwitchGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::Target; @@ -10,8 +12,6 @@ use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::bimap::bimap_from_lists; -pub struct U32Target(Target); - pub struct NonNativeTarget { /// The modulus of the field F' being represented. modulus: BigUInt, @@ -20,18 +20,6 @@ pub struct NonNativeTarget { } impl, const D: usize> CircuitBuilder { - pub fn add_mul_u32(&mut self, x: U32Target, y: U32Target, z:U32Target) -> (U32Target, U32Target) { - - } - - pub fn add_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { - - } - - pub fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { - - } - pub fn add_nonnative(&mut self, a: NonNativeTarget, b: NonNativeTarget) -> NonNativeTarget { let modulus = a.modulus; let num_limbs = a.limbs.len(); @@ -41,11 +29,7 @@ impl, const D: usize> CircuitBuilder { let mut combined_limbs = self.add_virtual_targets(num_limbs + 1); let mut carry = self.zero(); for i in 0..num_limbs { - let gate = ComparisonGate::new(bits, num_chunks); - let gate_index = self.add_gate(gate.clone(), vec![]); - - self.connect(Target::wire(gate_index, gate.wire_first_input()), lhs); - self.connect(Target::wire(gate_index, gate.wire_second_input()), rhs); + } } @@ -68,7 +52,7 @@ impl, const D: usize> CircuitBuilder { } let reduced_limbs = self.reduce(combined_limbs, modulus); - + NonNativeTarget { modulus, limbs: reduced_limbs, diff --git a/src/gates/arithmetic_u32.rs b/src/gates/arithmetic_u32.rs index 6564a876..c05cf72f 100644 --- a/src/gates/arithmetic_u32.rs +++ b/src/gates/arithmetic_u32.rs @@ -17,7 +17,7 @@ use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; pub const NUM_U32_ARITHMETIC_OPS: usize = 3; /// A gate to perform a basic mul-add on 32-bit values (we assume they are range-checked beforehand). -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct U32ArithmeticGate, const D: usize> { _phantom: PhantomData, } diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 5dcde1e0..e6020ad8 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -87,6 +87,9 @@ pub struct CircuitBuilder, const D: usize> { // of switches pub(crate) current_switch_gates: Vec, usize, usize)>>, + // The `U32ArithmeticGate` currently being filled (so new u32 arithmetic operations will be added to this gate before creating a new one) + pub(crate) current_u32_arithmetic_gate: Option<(usize, usize)>, + /// An available `ConstantGate` instance, if any. free_constant: Option<(usize, usize)>, } @@ -109,6 +112,7 @@ impl, const D: usize> CircuitBuilder { free_arithmetic: HashMap::new(), free_random_access: HashMap::new(), current_switch_gates: Vec::new(), + current_u32_arithmetic_gate: None, free_constant: None, }; builder.check_config(); From ffb544e4a543be7e9fcad96a3645dd3d79942e96 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 4 Oct 2021 14:17:19 -0700 Subject: [PATCH 012/202] initial non-native add --- src/gadgets/arithmetic_u32.rs | 23 +++++++++++++++++------ src/gadgets/nonnative.rs | 22 +++++++++++++++------- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index 4a8cff42..d4ff5896 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -18,6 +18,10 @@ impl, const D: usize> CircuitBuilder { U32Target(self.add_virtual_target()) } + pub fn add_virtual_u32_targets(&self, n: usize) -> Vec { + self.add_virtual_targets(n).iter().cloned().map(U32Target).collect() + } + pub fn zero_u32(&self) -> U32Target { U32Target(self.zero()) } @@ -40,17 +44,17 @@ impl, const D: usize> CircuitBuilder { let gate_index = self.add_gate(gate.clone(), vec![]); (gate_index, 0) }, - Some((gate_index, copy) => (gate_index, copy), + Some((gate_index, copy)) => (gate_index, copy), }; let output_low = self.add_virtual_u32_target(); let output_high = self.add_virtual_u32_target(); - self.connect(Target::wire(gate_index, gate.wire_ith_multiplicand_0(copy)), x); - self.connect(Target::wire(gate_index, gate.wire_ith_multiplicand_1(copy)), y); - self.connect(Target::wire(gate_index, gate.wire_ith_addend(copy)), z); - self.connect(Target::wire(gate_index, gate.wire_ith_output_low_half(copy)), output_low); - self.connect(Target::wire(gate_index, gate.wire_ith_output_high_half(copy)), output_high); + self.connect(Target::wire(gate_index, U32ArithmeticGate::::wire_ith_multiplicand_0(copy)), x.0); + self.connect(Target::wire(gate_index, U32ArithmeticGate::::wire_ith_multiplicand_1(copy)), y.0); + self.connect(Target::wire(gate_index, U32ArithmeticGate::::wire_ith_addend(copy)), z.0); + self.connect(Target::wire(gate_index, U32ArithmeticGate::::wire_ith_output_low_half(copy)), output_low.0); + self.connect(Target::wire(gate_index, U32ArithmeticGate::::wire_ith_output_high_half(copy)), output_high.0); self.current_u32_arithmetic_gate = Some((gate_index, 0)); @@ -61,6 +65,13 @@ impl, const D: usize> CircuitBuilder { self.add_mul_u32(a, self.one_u32(), b) } + pub fn add_three_u32(&mut self, a: U32Target, b: U32Target, c: U32Target) -> (U32Target, U32Target) { + let (init_low, carry1) = self.add_u32(a, b); + let (final_low, carry2) = self.add_u32(c, init_low); + let (combined_carry, _zero) = self.add_u32(carry1, carry2); + (final_low, combined_carry) + } + pub fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { self.add_mul_u32(a, b, self.zero_u32()) } diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 1407360e..01d60c85 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -1,3 +1,4 @@ +use num::bigint::BigUint; use std::collections::BTreeMap; use std::marker::PhantomData; @@ -14,7 +15,7 @@ use crate::util::bimap::bimap_from_lists; pub struct NonNativeTarget { /// The modulus of the field F' being represented. - modulus: BigUInt, + modulus: BigUint, /// These F elements are assumed to contain 32-bit values. limbs: Vec, } @@ -26,14 +27,21 @@ impl, const D: usize> CircuitBuilder { debug_assert!(b.modulus == modulus); debug_assert!(b.limbs.len() == num_limbs); - let mut combined_limbs = self.add_virtual_targets(num_limbs + 1); - let mut carry = self.zero(); + let mut combined_limbs = self.add_virtual_u32_targets(num_limbs + 1); + let mut carry = self.zero_u32(); for i in 0..num_limbs { - + let (new_limb, carry) = self.add_three_u32(carry, a.limbs[i], b.limbs[i]); + combined_limbs[i] = new_limb; + } + combined_limbs[num_limbs] = carry; + + NonNativeTarget { + modulus, + limbs: combined_limbs, } } - pub fn reduce_add_result(&mut self, limbs: Vec, modulus: BigUInt) -> Vec { + pub fn reduce_add_result(&mut self, limbs: Vec, modulus: BigUint) -> Vec { todo!() } @@ -46,8 +54,8 @@ impl, const D: usize> CircuitBuilder { let mut combined_limbs = self.add_virtual_targets(2 * num_limbs - 1); for i in 0..num_limbs { for j in 0..num_limbs { - let sum = builder.add(a.limbs[i], b.limbs[j]); - combined_limbs[i + j] = builder.add(combined_limbs[i + j], sum); + let sum = self.add(a.limbs[i], b.limbs[j]); + combined_limbs[i + j] = self.add(combined_limbs[i + j], sum); } } From e48e0a4a58e90231d6b7860be38aa7e01121863c Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 4 Oct 2021 14:17:28 -0700 Subject: [PATCH 013/202] fmt --- src/gadgets/arithmetic_u32.rs | 52 +++++++++++++++++++++++++++++------ src/gadgets/nonnative.rs | 5 ++-- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index d4ff5896..d0fd195f 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -19,7 +19,11 @@ impl, const D: usize> CircuitBuilder { } pub fn add_virtual_u32_targets(&self, n: usize) -> Vec { - self.add_virtual_targets(n).iter().cloned().map(U32Target).collect() + self.add_virtual_targets(n) + .iter() + .cloned() + .map(U32Target) + .collect() } pub fn zero_u32(&self) -> U32Target { @@ -43,18 +47,45 @@ impl, const D: usize> CircuitBuilder { }; let gate_index = self.add_gate(gate.clone(), vec![]); (gate_index, 0) - }, + } Some((gate_index, copy)) => (gate_index, copy), }; let output_low = self.add_virtual_u32_target(); let output_high = self.add_virtual_u32_target(); - self.connect(Target::wire(gate_index, U32ArithmeticGate::::wire_ith_multiplicand_0(copy)), x.0); - self.connect(Target::wire(gate_index, U32ArithmeticGate::::wire_ith_multiplicand_1(copy)), y.0); - self.connect(Target::wire(gate_index, U32ArithmeticGate::::wire_ith_addend(copy)), z.0); - self.connect(Target::wire(gate_index, U32ArithmeticGate::::wire_ith_output_low_half(copy)), output_low.0); - self.connect(Target::wire(gate_index, U32ArithmeticGate::::wire_ith_output_high_half(copy)), output_high.0); + self.connect( + Target::wire( + gate_index, + U32ArithmeticGate::::wire_ith_multiplicand_0(copy), + ), + x.0, + ); + self.connect( + Target::wire( + gate_index, + U32ArithmeticGate::::wire_ith_multiplicand_1(copy), + ), + y.0, + ); + self.connect( + Target::wire(gate_index, U32ArithmeticGate::::wire_ith_addend(copy)), + z.0, + ); + self.connect( + Target::wire( + gate_index, + U32ArithmeticGate::::wire_ith_output_low_half(copy), + ), + output_low.0, + ); + self.connect( + Target::wire( + gate_index, + U32ArithmeticGate::::wire_ith_output_high_half(copy), + ), + output_high.0, + ); self.current_u32_arithmetic_gate = Some((gate_index, 0)); @@ -65,7 +96,12 @@ impl, const D: usize> CircuitBuilder { self.add_mul_u32(a, self.one_u32(), b) } - pub fn add_three_u32(&mut self, a: U32Target, b: U32Target, c: U32Target) -> (U32Target, U32Target) { + pub fn add_three_u32( + &mut self, + a: U32Target, + b: U32Target, + c: U32Target, + ) -> (U32Target, U32Target) { let (init_low, carry1) = self.add_u32(a, b); let (final_low, carry2) = self.add_u32(c, init_low); let (combined_carry, _zero) = self.add_u32(carry1, carry2); diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 01d60c85..725b43a4 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -1,7 +1,8 @@ -use num::bigint::BigUint; use std::collections::BTreeMap; use std::marker::PhantomData; +use num::bigint::BigUint; + use crate::field::field_types::RichField; use crate::field::{extension_field::Extendable, field_types::Field}; use crate::gadgets::arithmetic_u32::U32Target; @@ -34,7 +35,7 @@ impl, const D: usize> CircuitBuilder { combined_limbs[i] = new_limb; } combined_limbs[num_limbs] = carry; - + NonNativeTarget { modulus, limbs: combined_limbs, From f71adac40b3beb29a1c37d37dcaad05da840b6df Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 4 Oct 2021 14:18:32 -0700 Subject: [PATCH 014/202] fix --- src/gadgets/nonnative.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 725b43a4..da82ee9e 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -36,9 +36,10 @@ impl, const D: usize> CircuitBuilder { } combined_limbs[num_limbs] = carry; + let reduced_limbs = self.reduce_add_result(combined_limbs, modulus); NonNativeTarget { modulus, - limbs: combined_limbs, + limbs: reduced_limbs, } } @@ -60,7 +61,7 @@ impl, const D: usize> CircuitBuilder { } } - let reduced_limbs = self.reduce(combined_limbs, modulus); + let reduced_limbs = self.reduce_mul_result(combined_limbs, modulus); NonNativeTarget { modulus, @@ -68,7 +69,7 @@ impl, const D: usize> CircuitBuilder { } } - pub fn reduce_mul_result(&mut self, limbs: Vec, modulus: BigUInt) -> Vec { + pub fn reduce_mul_result(&mut self, limbs: Vec, modulus: BigUint) -> Vec { todo!() } } From 34eacdada6ce6be8d04041346f37be370d790054 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 4 Oct 2021 16:23:21 -0700 Subject: [PATCH 015/202] progress --- src/gadgets/arithmetic_u32.rs | 21 +++++++++++++++------ src/gadgets/nonnative.rs | 6 +++--- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index d0fd195f..c5289e7c 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use crate::field::field_types::RichField; use crate::field::{extension_field::Extendable, field_types::Field}; -use crate::gates::arithmetic_u32::U32ArithmeticGate; +use crate::gates::arithmetic_u32::{NUM_U32_ARITHMETIC_OPS, U32ArithmeticGate}; use crate::gates::switch::SwitchGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::Target; @@ -34,7 +34,8 @@ impl, const D: usize> CircuitBuilder { U32Target(self.one()) } - pub fn add_mul_u32( + // Returns x * y + z. + pub fn mul_add_u32( &mut self, x: U32Target, y: U32Target, @@ -45,7 +46,7 @@ impl, const D: usize> CircuitBuilder { let gate = U32ArithmeticGate { _phantom: PhantomData, }; - let gate_index = self.add_gate(gate.clone(), vec![]); + let gate_index = self.add_gate(gate, vec![]); (gate_index, 0) } Some((gate_index, copy)) => (gate_index, copy), @@ -87,13 +88,21 @@ impl, const D: usize> CircuitBuilder { output_high.0, ); - self.current_u32_arithmetic_gate = Some((gate_index, 0)); + if copy == NUM_U32_ARITHMETIC_OPS - 1 { + let gate = U32ArithmeticGate { + _phantom: PhantomData, + }; + let gate_index = self.add_gate(gate, vec![]); + self.current_u32_arithmetic_gate = Some((gate_index, 0)); + } else { + self.current_u32_arithmetic_gate = Some((gate_index, copy + 1)); + } (output_low, output_high) } pub fn add_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { - self.add_mul_u32(a, self.one_u32(), b) + self.mul_add_u32(a, self.one_u32(), b) } pub fn add_three_u32( @@ -109,6 +118,6 @@ impl, const D: usize> CircuitBuilder { } pub fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { - self.add_mul_u32(a, b, self.zero_u32()) + self.mul_add_u32(a, b, self.zero_u32()) } } diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index da82ee9e..6da8ab5e 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -43,7 +43,7 @@ impl, const D: usize> CircuitBuilder { } } - pub fn reduce_add_result(&mut self, limbs: Vec, modulus: BigUint) -> Vec { + pub fn reduce_add_result(&mut self, limbs: Vec, modulus: BigUint) -> Vec { todo!() } @@ -56,7 +56,7 @@ impl, const D: usize> CircuitBuilder { let mut combined_limbs = self.add_virtual_targets(2 * num_limbs - 1); for i in 0..num_limbs { for j in 0..num_limbs { - let sum = self.add(a.limbs[i], b.limbs[j]); + let sum = self.add_u32(a.limbs[i], b.limbs[j]); combined_limbs[i + j] = self.add(combined_limbs[i + j], sum); } } @@ -69,7 +69,7 @@ impl, const D: usize> CircuitBuilder { } } - pub fn reduce_mul_result(&mut self, limbs: Vec, modulus: BigUint) -> Vec { + pub fn reduce_mul_result(&mut self, limbs: Vec, modulus: BigUint) -> Vec { todo!() } } From 6b294c1d97b33784600ce9a029a836255db6eb79 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 4 Oct 2021 16:23:34 -0700 Subject: [PATCH 016/202] fmt --- src/gadgets/arithmetic_u32.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index c5289e7c..7c5fe218 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use crate::field::field_types::RichField; use crate::field::{extension_field::Extendable, field_types::Field}; -use crate::gates::arithmetic_u32::{NUM_U32_ARITHMETIC_OPS, U32ArithmeticGate}; +use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; use crate::gates::switch::SwitchGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::Target; From ff943138f325621b2b97aaddf429d690555a56e0 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 10 Nov 2021 09:38:47 +0100 Subject: [PATCH 017/202] Apply suggestions from code review Co-authored-by: Daniel Lubarov --- src/util/partial_products.rs | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index 1b3821b6..37b51825 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -10,12 +10,8 @@ pub fn partial_products(v: &[F], max_degree: usize) -> Vec { let mut res = Vec::new(); let mut acc = F::ONE; let chunk_size = max_degree; - let num_chunks = v.len() / chunk_size; - for i in 0..num_chunks { - acc *= v[i * chunk_size..(i + 1) * chunk_size] - .iter() - .copied() - .product(); + for chunk in v.chunks_exact(chunk_size) { + acc *= chunk.iter().copied().product(); res.push(acc); } @@ -40,12 +36,8 @@ pub fn check_partial_products(v: &[F], partials: &[F], max_degree: usi let mut res = Vec::new(); let mut acc = F::ONE; let chunk_size = max_degree; - let num_chunks = v.len() / chunk_size; - for i in 0..num_chunks { - acc *= v[i * chunk_size..(i + 1) * chunk_size] - .iter() - .copied() - .product(); + for chunk in v.chunks_exact(chunk_size) { + acc *= chunk.iter().copied().product(); let new_acc = *partials.next().unwrap(); res.push(acc - new_acc); acc = new_acc; @@ -66,13 +58,11 @@ pub fn check_partial_products_recursively, const D: let mut res = Vec::new(); let mut acc = builder.one_extension(); let chunk_size = max_degree; - let num_chunks = v.len() / chunk_size; - for i in 0..num_chunks { - let mut chunk = v[i * chunk_size..(i + 1) * chunk_size].to_vec(); - chunk.push(acc); - acc = builder.mul_many_extension(&chunk); + for chunk in v.chunks_exact(chunk_size) { + let chunk_product = builder.mul_many_extension(chunk); let new_acc = *partials.next().unwrap(); - res.push(builder.sub_extension(acc, new_acc)); + // Assert that new_acc = acc * chunk_product. + res.push(builder.mul_sub_extension(acc, chunk_product, new_acc)); acc = new_acc; } debug_assert!(partials.next().is_none()); From 32f09ac2dfc026f4c3e9771cdaa293f59e21de0c Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 10 Nov 2021 18:13:27 +0100 Subject: [PATCH 018/202] Remove quotients and work directly with numerators and denominators in partial products check --- src/plonk/vanishing_poly.rs | 99 +++++++++++++----------------------- src/util/partial_products.rs | 41 ++++++++++----- 2 files changed, 64 insertions(+), 76 deletions(-) diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 4976eaba..43e52994 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -62,31 +62,26 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( wire_value + s_sigma.scalar_mul(betas[i]) + gammas[i].into() }) .collect::>(); - let quotient_values = (0..common_data.config.num_routed_wires) - .map(|j| numerator_values[j] / denominator_values[j]) - .collect::>(); // The partial products considered for this iteration of `i`. let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods]; // Check the quotient partial products. - let mut partial_product_checks = - check_partial_products("ient_values, current_partial_products, max_degree); - // The partial products are products of quotients, so we multiply them by the product of the - // corresponding denominators to make sure they are polynomials. - for (j, partial_product_check) in partial_product_checks.iter_mut().enumerate() { - let range = j * max_degree..(j + 1) * max_degree; - *partial_product_check *= denominator_values[range].iter().copied().product(); - } + let partial_product_checks = check_partial_products( + &numerator_values, + &denominator_values, + current_partial_products, + max_degree, + ); vanishing_partial_products_terms.extend(partial_product_checks); - let quotient: F::Extension = *current_partial_products.last().unwrap() - * quotient_values[final_num_prod..].iter().copied().product(); - let mut v_shift_term = quotient * z_x - z_gz; - // Need to multiply by the denominators to make sure we get a polynomial. - v_shift_term *= denominator_values[final_num_prod..] - .iter() - .copied() - .product(); + let v_shift_term = *current_partial_products.last().unwrap() + * numerator_values[final_num_prod..].iter().copied().product() + * z_x + - z_gz + * denominator_values[final_num_prod..] + .iter() + .copied() + .product(); vanishing_v_shift_terms.push(v_shift_term); } @@ -139,7 +134,6 @@ pub(crate) fn eval_vanishing_poly_base_batch, const let mut numerator_values = Vec::with_capacity(num_routed_wires); let mut denominator_values = Vec::with_capacity(num_routed_wires); - let mut quotient_values = Vec::with_capacity(num_routed_wires); // The L_1(x) (Z(x) - 1) vanishing terms. let mut vanishing_z_1_terms = Vec::with_capacity(num_challenges); @@ -178,37 +172,30 @@ pub(crate) fn eval_vanishing_poly_base_batch, const let s_sigma = s_sigmas[j]; wire_value + betas[i] * s_sigma + gammas[i] })); - let denominator_inverses = F::batch_multiplicative_inverse(&denominator_values); - quotient_values.extend( - (0..num_routed_wires).map(|j| numerator_values[j] * denominator_inverses[j]), - ); // The partial products considered for this iteration of `i`. let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods]; // Check the numerator partial products. - let mut partial_product_checks = - check_partial_products("ient_values, current_partial_products, max_degree); - // The partial products are products of quotients, so we multiply them by the product of the - // corresponding denominators to make sure they are polynomials. - for (j, partial_product_check) in partial_product_checks.iter_mut().enumerate() { - let range = j * max_degree..(j + 1) * max_degree; - *partial_product_check *= denominator_values[range].iter().copied().product(); - } + let partial_product_checks = check_partial_products( + &numerator_values, + &denominator_values, + current_partial_products, + max_degree, + ); vanishing_partial_products_terms.extend(partial_product_checks); - let quotient: F = *current_partial_products.last().unwrap() - * quotient_values[final_num_prod..].iter().copied().product(); - let mut v_shift_term = quotient * z_x - z_gz; - // Need to multiply by the denominators to make sure we get a polynomial. - v_shift_term *= denominator_values[final_num_prod..] - .iter() - .copied() - .product(); + let v_shift_term = *current_partial_products.last().unwrap() + * numerator_values[final_num_prod..].iter().copied().product() + * z_x + - z_gz + * denominator_values[final_num_prod..] + .iter() + .copied() + .product(); vanishing_v_shift_terms.push(v_shift_term); numerator_values.clear(); denominator_values.clear(); - quotient_values.clear(); } let vanishing_terms = vanishing_z_1_terms @@ -365,7 +352,6 @@ pub(crate) fn eval_vanishing_poly_recursively, cons let mut numerator_values = Vec::new(); let mut denominator_values = Vec::new(); - let mut quotient_values = Vec::new(); for j in 0..common_data.config.num_routed_wires { let wire_value = vars.local_wires[j]; @@ -378,46 +364,33 @@ pub(crate) fn eval_vanishing_poly_recursively, cons let numerator = builder.mul_add_extension(beta_ext, s_ids[j], wire_value_plus_gamma); let denominator = builder.mul_add_extension(beta_ext, s_sigmas[j], wire_value_plus_gamma); - let quotient = builder.div_extension(numerator, denominator); - numerator_values.push(numerator); denominator_values.push(denominator); - quotient_values.push(quotient); } // The partial products considered for this iteration of `i`. let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods]; // Check the quotient partial products. - let mut partial_product_checks = check_partial_products_recursively( + let partial_product_checks = check_partial_products_recursively( builder, - "ient_values, + &numerator_values, + &denominator_values, current_partial_products, max_degree, ); - // The partial products are products of quotients, so we multiply them by the product of the - // corresponding denominators to make sure they are polynomials. - for (j, partial_product_check) in partial_product_checks.iter_mut().enumerate() { - let range = j * max_degree..(j + 1) * max_degree; - *partial_product_check = builder.mul_many_extension(&{ - let mut v = denominator_values[range].to_vec(); - v.push(*partial_product_check); - v - }); - } vanishing_partial_products_terms.extend(partial_product_checks); - let quotient = builder.mul_many_extension(&{ - let mut v = quotient_values[final_num_prod..].to_vec(); + let nume_acc = builder.mul_many_extension(&{ + let mut v = numerator_values[final_num_prod..].to_vec(); v.push(*current_partial_products.last().unwrap()); v }); - let mut v_shift_term = builder.mul_sub_extension(quotient, z_x, z_gz); - // Need to multiply by the denominators to make sure we get a polynomial. - v_shift_term = builder.mul_many_extension(&{ + let z_gz_denominators = builder.mul_many_extension(&{ let mut v = denominator_values[final_num_prod..].to_vec(); - v.push(v_shift_term); + v.push(z_gz); v }); + let v_shift_term = builder.mul_sub_extension(nume_acc, z_x, z_gz_denominators); vanishing_v_shift_terms.push(v_shift_term); } diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index 37b51825..c3c4659a 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -28,18 +28,26 @@ pub fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { (num_chunks, num_chunks * chunk_size) } -/// Checks that the partial products of `v` are coherent with those in `partials` by only computing +/// Checks that the partial products of `numerators/denominators` are coherent with those in `partials` by only computing /// products of size `max_degree` or less. -pub fn check_partial_products(v: &[F], partials: &[F], max_degree: usize) -> Vec { +pub fn check_partial_products( + numerators: &[F], + denominators: &[F], + partials: &[F], + max_degree: usize, +) -> Vec { debug_assert!(max_degree > 1); let mut partials = partials.iter(); let mut res = Vec::new(); let mut acc = F::ONE; let chunk_size = max_degree; - for chunk in v.chunks_exact(chunk_size) { - acc *= chunk.iter().copied().product(); - let new_acc = *partials.next().unwrap(); - res.push(acc - new_acc); + for (nume_chunk, deno_chunk) in numerators + .chunks_exact(chunk_size) + .zip(denominators.chunks_exact(chunk_size)) + { + acc *= nume_chunk.iter().copied().product(); + let mut new_acc = *partials.next().unwrap(); + res.push(acc - new_acc * deno_chunk.iter().copied().product()); acc = new_acc; } debug_assert!(partials.next().is_none()); @@ -49,7 +57,8 @@ pub fn check_partial_products(v: &[F], partials: &[F], max_degree: usi pub fn check_partial_products_recursively, const D: usize>( builder: &mut CircuitBuilder, - v: &[ExtensionTarget], + numerators: &[ExtensionTarget], + denominators: &[ExtensionTarget], partials: &[ExtensionTarget], max_degree: usize, ) -> Vec> { @@ -58,11 +67,16 @@ pub fn check_partial_products_recursively, const D: let mut res = Vec::new(); let mut acc = builder.one_extension(); let chunk_size = max_degree; - for chunk in v.chunks_exact(chunk_size) { - let chunk_product = builder.mul_many_extension(chunk); + for (nume_chunk, deno_chunk) in numerators + .chunks_exact(chunk_size) + .zip(denominators.chunks_exact(chunk_size)) + { + let nume_product = builder.mul_many_extension(nume_chunk); + let deno_product = builder.mul_many_extension(deno_chunk); let new_acc = *partials.next().unwrap(); - // Assert that new_acc = acc * chunk_product. - res.push(builder.mul_sub_extension(acc, chunk_product, new_acc)); + let new_acc_deno = builder.mul_extension(new_acc, deno_product); + // Assert that new_acc*deno_product = acc * nume_product. + res.push(builder.mul_sub_extension(acc, nume_product, new_acc_deno)); acc = new_acc; } debug_assert!(partials.next().is_none()); @@ -78,6 +92,7 @@ mod tests { #[test] fn test_partial_products() { type F = GoldilocksField; + let denominators = vec![F::ONE; 6]; let v = [1, 2, 3, 4, 5, 6] .into_iter() .map(|&i| F::from_canonical_u64(i)) @@ -93,7 +108,7 @@ mod tests { let nums = num_partial_products(v.len(), 2); assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &p, 2) + assert!(check_partial_products(&v, &denominators, &p, 2) .iter() .all(|x| x.is_zero())); assert_eq!( @@ -115,7 +130,7 @@ mod tests { ); let nums = num_partial_products(v.len(), 3); assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &p, 3) + assert!(check_partial_products(&v, &denominators, &p, 3) .iter() .all(|x| x.is_zero())); assert_eq!( From 3084367133ad03c7bc74973975dc1d313a338fda Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 10 Nov 2021 18:36:35 +0100 Subject: [PATCH 019/202] Start accumulator at Z(x) --- src/plonk/prover.rs | 14 ++++++++------ src/plonk/vanishing_poly.rs | 7 ++++--- src/util/partial_products.rs | 8 ++++---- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 6c57217c..be356d9f 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -100,7 +100,7 @@ pub(crate) fn prove, const D: usize>( let plonk_z_vecs = timed!( timing, "compute Z's", - compute_zs(&partial_products, common_data) + compute_zs(&mut partial_products, common_data) ); // The first polynomial in `partial_products` represent the final product used in the @@ -286,24 +286,26 @@ fn wires_permutation_partial_products, const D: usi } fn compute_zs, const D: usize>( - partial_products: &[Vec>], + partial_products: &mut [Vec>], common_data: &CommonCircuitData, ) -> Vec> { (0..common_data.config.num_challenges) - .map(|i| compute_z(&partial_products[i], common_data)) + .map(|i| compute_z(&mut partial_products[i], common_data)) .collect() } /// Compute the `Z` polynomial by reusing the computations done in `wires_permutation_partial_products`. fn compute_z, const D: usize>( - partial_products: &[PolynomialValues], + partial_products: &mut [PolynomialValues], common_data: &CommonCircuitData, ) -> PolynomialValues { let mut plonk_z_points = vec![F::ONE]; for i in 1..common_data.degree() { - let quotient = partial_products[0].values[i - 1]; let last = *plonk_z_points.last().unwrap(); - plonk_z_points.push(last * quotient); + for q in partial_products.iter_mut() { + q.values[i - 1] *= last; + } + plonk_z_points.push(partial_products[0].values[i - 1]); } plonk_z_points.into() } diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 43e52994..899d69a6 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -70,13 +70,13 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( &numerator_values, &denominator_values, current_partial_products, + z_x, max_degree, ); vanishing_partial_products_terms.extend(partial_product_checks); let v_shift_term = *current_partial_products.last().unwrap() * numerator_values[final_num_prod..].iter().copied().product() - * z_x - z_gz * denominator_values[final_num_prod..] .iter() @@ -180,13 +180,13 @@ pub(crate) fn eval_vanishing_poly_base_batch, const &numerator_values, &denominator_values, current_partial_products, + z_x, max_degree, ); vanishing_partial_products_terms.extend(partial_product_checks); let v_shift_term = *current_partial_products.last().unwrap() * numerator_values[final_num_prod..].iter().copied().product() - * z_x - z_gz * denominator_values[final_num_prod..] .iter() @@ -376,6 +376,7 @@ pub(crate) fn eval_vanishing_poly_recursively, cons &numerator_values, &denominator_values, current_partial_products, + z_x, max_degree, ); vanishing_partial_products_terms.extend(partial_product_checks); @@ -390,7 +391,7 @@ pub(crate) fn eval_vanishing_poly_recursively, cons v.push(z_gz); v }); - let v_shift_term = builder.mul_sub_extension(nume_acc, z_x, z_gz_denominators); + let v_shift_term = builder.sub_extension(nume_acc, z_gz_denominators); vanishing_v_shift_terms.push(v_shift_term); } diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index c3c4659a..92419a56 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -34,12 +34,12 @@ pub fn check_partial_products( numerators: &[F], denominators: &[F], partials: &[F], + mut acc: F, max_degree: usize, ) -> Vec { debug_assert!(max_degree > 1); let mut partials = partials.iter(); let mut res = Vec::new(); - let mut acc = F::ONE; let chunk_size = max_degree; for (nume_chunk, deno_chunk) in numerators .chunks_exact(chunk_size) @@ -60,12 +60,12 @@ pub fn check_partial_products_recursively, const D: numerators: &[ExtensionTarget], denominators: &[ExtensionTarget], partials: &[ExtensionTarget], + mut acc: ExtensionTarget, max_degree: usize, ) -> Vec> { debug_assert!(max_degree > 1); let mut partials = partials.iter(); let mut res = Vec::new(); - let mut acc = builder.one_extension(); let chunk_size = max_degree; for (nume_chunk, deno_chunk) in numerators .chunks_exact(chunk_size) @@ -108,7 +108,7 @@ mod tests { let nums = num_partial_products(v.len(), 2); assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &denominators, &p, 2) + assert!(check_partial_products(&v, &denominators, &p, F::ONE, 2) .iter() .all(|x| x.is_zero())); assert_eq!( @@ -130,7 +130,7 @@ mod tests { ); let nums = num_partial_products(v.len(), 3); assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &denominators, &p, 3) + assert!(check_partial_products(&v, &denominators, &p, F::ONE, 3) .iter() .all(|x| x.is_zero())); assert_eq!( From 8440a0f5cbe6cb16eacc649ef0eb2c77b3493250 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 09:53:09 -0800 Subject: [PATCH 020/202] merge --- src/plonk/circuit_builder.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index e6020ad8..4879d144 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -87,7 +87,7 @@ pub struct CircuitBuilder, const D: usize> { // of switches pub(crate) current_switch_gates: Vec, usize, usize)>>, - // The `U32ArithmeticGate` currently being filled (so new u32 arithmetic operations will be added to this gate before creating a new one) + /// The `U32ArithmeticGate` currently being filled (so new u32 arithmetic operations will be added to this gate before creating a new one) pub(crate) current_u32_arithmetic_gate: Option<(usize, usize)>, /// An available `ConstantGate` instance, if any. From ebcfde1d81892fc16aba4a6cac864c56a232b3a4 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 7 Oct 2021 09:25:57 -0700 Subject: [PATCH 021/202] updates --- src/gadgets/arithmetic_u32.rs | 42 ++++++++++------------------------- src/gates/arithmetic_u32.rs | 6 ++--- 2 files changed, 14 insertions(+), 34 deletions(-) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index 7c5fe218..213cdffe 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -1,15 +1,10 @@ -use std::collections::BTreeMap; use std::marker::PhantomData; use crate::field::field_types::RichField; -use crate::field::{extension_field::Extendable, field_types::Field}; +use crate::field::extension_field::Extendable; use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; -use crate::gates::switch::SwitchGate; -use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::Target; -use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::util::bimap::bimap_from_lists; pub struct U32Target(Target); @@ -20,8 +15,7 @@ impl, const D: usize> CircuitBuilder { pub fn add_virtual_u32_targets(&self, n: usize) -> Vec { self.add_virtual_targets(n) - .iter() - .cloned() + .into_iter() .map(U32Target) .collect() } @@ -52,9 +46,6 @@ impl, const D: usize> CircuitBuilder { Some((gate_index, copy)) => (gate_index, copy), }; - let output_low = self.add_virtual_u32_target(); - let output_high = self.add_virtual_u32_target(); - self.connect( Target::wire( gate_index, @@ -73,27 +64,18 @@ impl, const D: usize> CircuitBuilder { Target::wire(gate_index, U32ArithmeticGate::::wire_ith_addend(copy)), z.0, ); - self.connect( - Target::wire( - gate_index, - U32ArithmeticGate::::wire_ith_output_low_half(copy), - ), - output_low.0, - ); - self.connect( - Target::wire( - gate_index, - U32ArithmeticGate::::wire_ith_output_high_half(copy), - ), - output_high.0, - ); + + let output_low = U32Target(Target::wire( + gate_index, + U32ArithmeticGate::::wire_ith_output_low_half(copy), + )); + let output_high = U32Target(Target::wire( + gate_index, + U32ArithmeticGate::::wire_ith_output_high_half(copy), + )); if copy == NUM_U32_ARITHMETIC_OPS - 1 { - let gate = U32ArithmeticGate { - _phantom: PhantomData, - }; - let gate_index = self.add_gate(gate, vec![]); - self.current_u32_arithmetic_gate = Some((gate_index, 0)); + self.current_u32_arithmetic_gate = None; } else { self.current_u32_arithmetic_gate = Some((gate_index, copy + 1)); } diff --git a/src/gates/arithmetic_u32.rs b/src/gates/arithmetic_u32.rs index c05cf72f..b56af0e8 100644 --- a/src/gates/arithmetic_u32.rs +++ b/src/gates/arithmetic_u32.rs @@ -309,8 +309,7 @@ impl, const D: usize> SimpleGenerator .take(num_limbs) .collect(); let output_limbs_f: Vec<_> = output_limbs_u64 - .iter() - .cloned() + .into_iter() .map(F::from_canonical_u64) .collect(); @@ -385,8 +384,7 @@ mod tests { output /= limb_base; } let mut output_limbs_f: Vec<_> = output_limbs - .iter() - .cloned() + .into_iter() .map(F::from_canonical_u64) .collect(); From 912204d6858f6871592750f475ffe637dd908849 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 09:53:27 -0800 Subject: [PATCH 022/202] merge --- src/gadgets/arithmetic_u32.rs | 2 +- src/gadgets/nonnative.rs | 40 +++++++++++++++++++---------------- src/plonk/circuit_builder.rs | 6 ++++++ 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index 213cdffe..489a9d5b 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -6,7 +6,7 @@ use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; -pub struct U32Target(Target); +pub struct U32Target(pub Target); impl, const D: usize> CircuitBuilder { pub fn add_virtual_u32_target(&self) -> U32Target { diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 6da8ab5e..cf232632 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -13,19 +13,23 @@ use crate::iop::target::Target; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::bimap::bimap_from_lists; - -pub struct NonNativeTarget { - /// The modulus of the field F' being represented. - modulus: BigUint, + +pub struct ForeignFieldTarget { /// These F elements are assumed to contain 32-bit values. limbs: Vec, + _phantom: PhantomData, } impl, const D: usize> CircuitBuilder { - pub fn add_nonnative(&mut self, a: NonNativeTarget, b: NonNativeTarget) -> NonNativeTarget { - let modulus = a.modulus; + pub fn order_u32_limbs(&self) -> Vec { + let modulus = FF::order(); + let limbs = modulus.to_u32_digits(); + limbs.iter().map(|&limb| self.constant_u32(F::from_canonical_u32(limb))).collect() + } + + // Add two `ForeignFieldTarget`s, which we assume are both normalized. + pub fn add_nonnative(&mut self, a: ForeignFieldTarget, b: ForeignFieldTarget) -> ForeignFieldTarget { let num_limbs = a.limbs.len(); - debug_assert!(b.modulus == modulus); debug_assert!(b.limbs.len() == num_limbs); let mut combined_limbs = self.add_virtual_u32_targets(num_limbs + 1); @@ -36,21 +40,21 @@ impl, const D: usize> CircuitBuilder { } combined_limbs[num_limbs] = carry; - let reduced_limbs = self.reduce_add_result(combined_limbs, modulus); - NonNativeTarget { - modulus, + let reduced_limbs = self.reduce_add_result::(combined_limbs); + ForeignFieldTarget { limbs: reduced_limbs, + _phantom: PhantomData, } } - pub fn reduce_add_result(&mut self, limbs: Vec, modulus: BigUint) -> Vec { + pub fn reduce_add_result(&mut self, limbs: Vec) -> Vec { + let modulus = FF::order(); + todo!() } - pub fn mul_nonnative(&mut self, a: NonNativeTarget, b: NonNativeTarget) -> NonNativeTarget { - let modulus = a.modulus; + pub fn mul_nonnative(&mut self, a: ForeignFieldTarget, b: ForeignFieldTarget) -> ForeignFieldTarget { let num_limbs = a.limbs.len(); - debug_assert!(b.modulus == modulus); debug_assert!(b.limbs.len() == num_limbs); let mut combined_limbs = self.add_virtual_targets(2 * num_limbs - 1); @@ -61,15 +65,15 @@ impl, const D: usize> CircuitBuilder { } } - let reduced_limbs = self.reduce_mul_result(combined_limbs, modulus); + let reduced_limbs = self.reduce_mul_result::(combined_limbs); - NonNativeTarget { - modulus, + ForeignFieldTarget { limbs: reduced_limbs, + _phantom: PhantomData, } } - pub fn reduce_mul_result(&mut self, limbs: Vec, modulus: BigUint) -> Vec { + pub fn reduce_mul_result(&mut self, limbs: Vec) -> Vec { todo!() } } diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 4879d144..019bc71e 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -13,6 +13,7 @@ use crate::field::field_types::{Field, RichField}; use crate::fri::commitment::PolynomialBatchCommitment; use crate::fri::{FriConfig, FriParams}; use crate::gadgets::arithmetic_extension::ArithmeticOperation; +use crate::gadgets::arithmetic_u32::U32Target; use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::gates::constant::ConstantGate; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; @@ -349,6 +350,11 @@ impl, const D: usize> CircuitBuilder { } } + /// Returns a U32Target for the value `c`, which is assumed to be at most 32 bits. + pub fn constant_u32(&mut self, c: F) -> U32Target { + U32Target(self.constant(c)) + } + /// If the given target is a constant (i.e. it was created by the `constant(F)` method), returns /// its constant value. Otherwise, returns `None`. pub fn target_as_constant(&self, target: Target) -> Option { From 7e8c021b46403178f7c90ae363630db885296ab7 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 12 Oct 2021 11:41:34 -0700 Subject: [PATCH 023/202] comparison gate --- src/gadgets/arithmetic_u32.rs | 19 +- src/gadgets/mod.rs | 1 + src/gadgets/nonnative.rs | 13 +- src/gadgets/sorting.rs | 6 +- src/gates/arithmetic_u32.rs | 6 + src/gates/assert_le.rs | 601 ++++++++++++++++++++++++++++++++++ src/gates/comparison.rs | 154 ++++++--- src/gates/mod.rs | 1 + 8 files changed, 739 insertions(+), 62 deletions(-) create mode 100644 src/gates/assert_le.rs diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index 489a9d5b..cb8334db 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -6,25 +6,26 @@ use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; +#[derive(Clone)] pub struct U32Target(pub Target); impl, const D: usize> CircuitBuilder { - pub fn add_virtual_u32_target(&self) -> U32Target { + pub fn add_virtual_u32_target(&mut self) -> U32Target { U32Target(self.add_virtual_target()) } - pub fn add_virtual_u32_targets(&self, n: usize) -> Vec { + pub fn add_virtual_u32_targets(&mut self, n: usize) -> Vec { self.add_virtual_targets(n) .into_iter() .map(U32Target) .collect() } - pub fn zero_u32(&self) -> U32Target { + pub fn zero_u32(&mut self) -> U32Target { U32Target(self.zero()) } - pub fn one_u32(&self) -> U32Target { + pub fn one_u32(&mut self) -> U32Target { U32Target(self.one()) } @@ -37,9 +38,7 @@ impl, const D: usize> CircuitBuilder { ) -> (U32Target, U32Target) { let (gate_index, copy) = match self.current_u32_arithmetic_gate { None => { - let gate = U32ArithmeticGate { - _phantom: PhantomData, - }; + let gate = U32ArithmeticGate::new(); let gate_index = self.add_gate(gate, vec![]); (gate_index, 0) } @@ -84,7 +83,8 @@ impl, const D: usize> CircuitBuilder { } pub fn add_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { - self.mul_add_u32(a, self.one_u32(), b) + let one = self.one_u32(); + self.mul_add_u32(a, one, b) } pub fn add_three_u32( @@ -100,6 +100,7 @@ impl, const D: usize> CircuitBuilder { } pub fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { - self.mul_add_u32(a, b, self.zero_u32()) + let zero = self.zero_u32(); + self.mul_add_u32(a, b, zero) } } diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index e38646e3..9fb572c9 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -4,6 +4,7 @@ pub mod arithmetic_u32; pub mod hash; pub mod insert; pub mod interpolation; +pub mod multiple_comparison; pub mod nonnative; pub mod permutation; pub mod polynomial; diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index cf232632..418aae82 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -21,7 +21,7 @@ pub struct ForeignFieldTarget { } impl, const D: usize> CircuitBuilder { - pub fn order_u32_limbs(&self) -> Vec { + pub fn order_u32_limbs(&mut self) -> Vec { let modulus = FF::order(); let limbs = modulus.to_u32_digits(); limbs.iter().map(|&limb| self.constant_u32(F::from_canonical_u32(limb))).collect() @@ -35,7 +35,7 @@ impl, const D: usize> CircuitBuilder { let mut combined_limbs = self.add_virtual_u32_targets(num_limbs + 1); let mut carry = self.zero_u32(); for i in 0..num_limbs { - let (new_limb, carry) = self.add_three_u32(carry, a.limbs[i], b.limbs[i]); + let (new_limb, carry) = self.add_three_u32(carry.clone(), a.limbs[i].clone(), b.limbs[i].clone()); combined_limbs[i] = new_limb; } combined_limbs[num_limbs] = carry; @@ -48,8 +48,6 @@ impl, const D: usize> CircuitBuilder { } pub fn reduce_add_result(&mut self, limbs: Vec) -> Vec { - let modulus = FF::order(); - todo!() } @@ -57,11 +55,11 @@ impl, const D: usize> CircuitBuilder { let num_limbs = a.limbs.len(); debug_assert!(b.limbs.len() == num_limbs); - let mut combined_limbs = self.add_virtual_targets(2 * num_limbs - 1); + /*let mut combined_limbs = self.add_virtual_u32_targets(2 * num_limbs - 1); for i in 0..num_limbs { for j in 0..num_limbs { let sum = self.add_u32(a.limbs[i], b.limbs[j]); - combined_limbs[i + j] = self.add(combined_limbs[i + j], sum); + combined_limbs[i + j] = self.add_u32(combined_limbs[i + j], sum); } } @@ -70,7 +68,8 @@ impl, const D: usize> CircuitBuilder { ForeignFieldTarget { limbs: reduced_limbs, _phantom: PhantomData, - } + }*/ + todo!() } pub fn reduce_mul_result(&mut self, limbs: Vec) -> Vec { diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 72dcf273..2c52db23 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -4,7 +4,7 @@ use itertools::izip; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; -use crate::gates::comparison::ComparisonGate; +use crate::gates::assert_le::AssertLessThanGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::{PartitionWitness, Witness}; @@ -40,9 +40,9 @@ impl, const D: usize> CircuitBuilder { self.assert_permutation(a_chunks, b_chunks); } - /// Add a ComparisonGate to assert that `lhs` is less than `rhs`, where their values are at most `bits` bits. + /// Add an AssertLessThanGate to assert that `lhs` is less than `rhs`, where their values are at most `bits` bits. pub fn assert_le(&mut self, lhs: Target, rhs: Target, bits: usize, num_chunks: usize) { - let gate = ComparisonGate::new(bits, num_chunks); + let gate = AssertLessThanGate::new(bits, num_chunks); let gate_index = self.add_gate(gate.clone(), vec![]); self.connect(Target::wire(gate_index, gate.wire_first_input()), lhs); diff --git a/src/gates/arithmetic_u32.rs b/src/gates/arithmetic_u32.rs index b56af0e8..2bbbda6e 100644 --- a/src/gates/arithmetic_u32.rs +++ b/src/gates/arithmetic_u32.rs @@ -23,6 +23,12 @@ pub struct U32ArithmeticGate, const D: usize> { } impl, const D: usize> U32ArithmeticGate { + pub fn new() -> Self { + Self { + _phantom: PhantomData, + } + } + pub fn wire_ith_multiplicand_0(i: usize) -> usize { debug_assert!(i < NUM_U32_ARITHMETIC_OPS); 5 * i diff --git a/src/gates/assert_le.rs b/src/gates/assert_le.rs new file mode 100644 index 00000000..796f94b3 --- /dev/null +++ b/src/gates/assert_le.rs @@ -0,0 +1,601 @@ +use std::marker::PhantomData; + +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::field_types::{Field, PrimeField, RichField}; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::wire::Wire; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recursive}; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::util::{bits_u64, ceil_div_usize}; + +/// A gate for checking that one value is less than or equal to another. +#[derive(Clone, Debug)] +pub struct AssertLessThanGate, const D: usize> { + pub(crate) num_bits: usize, + pub(crate) num_chunks: usize, + _phantom: PhantomData, +} + +impl, const D: usize> AssertLessThanGate { + pub fn new(num_bits: usize, num_chunks: usize) -> Self { + debug_assert!(num_bits < bits_u64(F::ORDER)); + Self { + num_bits, + num_chunks, + _phantom: PhantomData, + } + } + + pub fn chunk_bits(&self) -> usize { + ceil_div_usize(self.num_bits, self.num_chunks) + } + + pub fn wire_first_input(&self) -> usize { + 0 + } + + pub fn wire_second_input(&self) -> usize { + 1 + } + + pub fn wire_most_significant_diff(&self) -> usize { + 2 + } + + pub fn wire_first_chunk_val(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 3 + chunk + } + + pub fn wire_second_chunk_val(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 3 + self.num_chunks + chunk + } + + pub fn wire_equality_dummy(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 3 + 2 * self.num_chunks + chunk + } + + pub fn wire_chunks_equal(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 3 + 3 * self.num_chunks + chunk + } + + pub fn wire_intermediate_value(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 3 + 4 * self.num_chunks + chunk + } +} + +impl, const D: usize> Gate for AssertLessThanGate { + fn id(&self) -> String { + format!("{:?}", self, D) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let first_input = vars.local_wires[self.wire_first_input()]; + let second_input = vars.local_wires[self.wire_second_input()]; + + // Get chunks and assert that they match + let first_chunks: Vec = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) + .collect(); + let second_chunks: Vec = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) + .collect(); + + let first_chunks_combined = reduce_with_powers( + &first_chunks, + F::Extension::from_canonical_usize(1 << self.chunk_bits()), + ); + let second_chunks_combined = reduce_with_powers( + &second_chunks, + F::Extension::from_canonical_usize(1 << self.chunk_bits()), + ); + + constraints.push(first_chunks_combined - first_input); + constraints.push(second_chunks_combined - second_input); + + let chunk_size = 1 << self.chunk_bits(); + + let mut most_significant_diff_so_far = F::Extension::ZERO; + + for i in 0..self.num_chunks { + // Range-check the chunks to be less than `chunk_size`. + let first_product = (0..chunk_size) + .map(|x| first_chunks[i] - F::Extension::from_canonical_usize(x)) + .product(); + let second_product = (0..chunk_size) + .map(|x| second_chunks[i] - F::Extension::from_canonical_usize(x)) + .product(); + constraints.push(first_product); + constraints.push(second_product); + + let difference = second_chunks[i] - first_chunks[i]; + let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; + let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; + + // Two constraints to assert that `chunks_equal` is valid. + constraints.push(difference * equality_dummy - (F::Extension::ONE - chunks_equal)); + constraints.push(chunks_equal * difference); + + // Update `most_significant_diff_so_far`. + let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; + constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far); + most_significant_diff_so_far = + intermediate_value + (F::Extension::ONE - chunks_equal) * difference; + } + + let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; + constraints.push(most_significant_diff - most_significant_diff_so_far); + + // Range check `most_significant_diff` to be less than `chunk_size`. + let product = (0..chunk_size) + .map(|x| most_significant_diff - F::Extension::from_canonical_usize(x)) + .product(); + constraints.push(product); + + constraints + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let first_input = vars.local_wires[self.wire_first_input()]; + let second_input = vars.local_wires[self.wire_second_input()]; + + // Get chunks and assert that they match + let first_chunks: Vec = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) + .collect(); + let second_chunks: Vec = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) + .collect(); + + let first_chunks_combined = reduce_with_powers( + &first_chunks, + F::from_canonical_usize(1 << self.chunk_bits()), + ); + let second_chunks_combined = reduce_with_powers( + &second_chunks, + F::from_canonical_usize(1 << self.chunk_bits()), + ); + + constraints.push(first_chunks_combined - first_input); + constraints.push(second_chunks_combined - second_input); + + let chunk_size = 1 << self.chunk_bits(); + + let mut most_significant_diff_so_far = F::ZERO; + + for i in 0..self.num_chunks { + // Range-check the chunks to be less than `chunk_size`. + let first_product = (0..chunk_size) + .map(|x| first_chunks[i] - F::from_canonical_usize(x)) + .product(); + let second_product = (0..chunk_size) + .map(|x| second_chunks[i] - F::from_canonical_usize(x)) + .product(); + constraints.push(first_product); + constraints.push(second_product); + + let difference = second_chunks[i] - first_chunks[i]; + let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; + let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; + + // Two constraints to assert that `chunks_equal` is valid. + constraints.push(difference * equality_dummy - (F::ONE - chunks_equal)); + constraints.push(chunks_equal * difference); + + // Update `most_significant_diff_so_far`. + let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; + constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far); + most_significant_diff_so_far = + intermediate_value + (F::ONE - chunks_equal) * difference; + } + + let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; + constraints.push(most_significant_diff - most_significant_diff_so_far); + + // Range check `most_significant_diff` to be less than `chunk_size`. + let product = (0..chunk_size) + .map(|x| most_significant_diff - F::from_canonical_usize(x)) + .product(); + constraints.push(product); + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let first_input = vars.local_wires[self.wire_first_input()]; + let second_input = vars.local_wires[self.wire_second_input()]; + + // Get chunks and assert that they match + let first_chunks: Vec> = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) + .collect(); + let second_chunks: Vec> = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) + .collect(); + + let chunk_base = builder.constant(F::from_canonical_usize(1 << self.chunk_bits())); + let first_chunks_combined = + reduce_with_powers_ext_recursive(builder, &first_chunks, chunk_base); + let second_chunks_combined = + reduce_with_powers_ext_recursive(builder, &second_chunks, chunk_base); + + constraints.push(builder.sub_extension(first_chunks_combined, first_input)); + constraints.push(builder.sub_extension(second_chunks_combined, second_input)); + + let chunk_size = 1 << self.chunk_bits(); + + let mut most_significant_diff_so_far = builder.zero_extension(); + + let one = builder.one_extension(); + // Find the chosen chunk. + for i in 0..self.num_chunks { + // Range-check the chunks to be less than `chunk_size`. + let mut first_product = one; + let mut second_product = one; + for x in 0..chunk_size { + let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x)); + let first_diff = builder.sub_extension(first_chunks[i], x_f); + let second_diff = builder.sub_extension(second_chunks[i], x_f); + first_product = builder.mul_extension(first_product, first_diff); + second_product = builder.mul_extension(second_product, second_diff); + } + constraints.push(first_product); + constraints.push(second_product); + + let difference = builder.sub_extension(second_chunks[i], first_chunks[i]); + let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; + let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; + + // Two constraints to assert that `chunks_equal` is valid. + let diff_times_equal = builder.mul_extension(difference, equality_dummy); + let not_equal = builder.sub_extension(one, chunks_equal); + constraints.push(builder.sub_extension(diff_times_equal, not_equal)); + constraints.push(builder.mul_extension(chunks_equal, difference)); + + // Update `most_significant_diff_so_far`. + let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; + let old_diff = builder.mul_extension(chunks_equal, most_significant_diff_so_far); + constraints.push(builder.sub_extension(intermediate_value, old_diff)); + + let not_equal = builder.sub_extension(one, chunks_equal); + let new_diff = builder.mul_extension(not_equal, difference); + most_significant_diff_so_far = builder.add_extension(intermediate_value, new_diff); + } + + let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; + constraints + .push(builder.sub_extension(most_significant_diff, most_significant_diff_so_far)); + + // Range check `most_significant_diff` to be less than `chunk_size`. + let mut product = builder.one_extension(); + for x in 0..chunk_size { + let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x)); + let diff = builder.sub_extension(most_significant_diff, x_f); + product = builder.mul_extension(product, diff); + } + constraints.push(product); + + constraints + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + let gen = AssertLessThanGenerator:: { + gate_index, + gate: self.clone(), + }; + vec![Box::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + self.wire_intermediate_value(self.num_chunks - 1) + 1 + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 1 << self.chunk_bits() + } + + fn num_constraints(&self) -> usize { + 4 + 5 * self.num_chunks + } +} + +#[derive(Debug)] +struct AssertLessThanGenerator, const D: usize> { + gate_index: usize, + gate: AssertLessThanGate, +} + +impl, const D: usize> SimpleGenerator + for AssertLessThanGenerator +{ + fn dependencies(&self) -> Vec { + let local_target = |input| Target::wire(self.gate_index, input); + + let mut deps = Vec::new(); + deps.push(local_target(self.gate.wire_first_input())); + deps.push(local_target(self.gate.wire_second_input())); + deps + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |input| Wire { + gate: self.gate_index, + input, + }; + + let get_local_wire = |input| witness.get_wire(local_wire(input)); + + let first_input = get_local_wire(self.gate.wire_first_input()); + let second_input = get_local_wire(self.gate.wire_second_input()); + + let first_input_u64 = first_input.to_canonical_u64(); + let second_input_u64 = second_input.to_canonical_u64(); + + debug_assert!(first_input_u64 < second_input_u64); + + let chunk_size = 1 << self.gate.chunk_bits(); + let first_input_chunks: Vec = (0..self.gate.num_chunks) + .scan(first_input_u64, |acc, _| { + let tmp = *acc % chunk_size; + *acc /= chunk_size; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + let second_input_chunks: Vec = (0..self.gate.num_chunks) + .scan(second_input_u64, |acc, _| { + let tmp = *acc % chunk_size; + *acc /= chunk_size; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + + let chunks_equal: Vec = (0..self.gate.num_chunks) + .map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i])) + .collect(); + let equality_dummies: Vec = first_input_chunks + .iter() + .zip(second_input_chunks.iter()) + .map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) }) + .collect(); + + let mut most_significant_diff_so_far = F::ZERO; + let mut intermediate_values = Vec::new(); + for i in 0..self.gate.num_chunks { + if first_input_chunks[i] != second_input_chunks[i] { + most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i]; + intermediate_values.push(F::ZERO); + } else { + intermediate_values.push(most_significant_diff_so_far); + } + } + let most_significant_diff = most_significant_diff_so_far; + + out_buffer.set_wire( + local_wire(self.gate.wire_most_significant_diff()), + most_significant_diff, + ); + for i in 0..self.gate.num_chunks { + out_buffer.set_wire( + local_wire(self.gate.wire_first_chunk_val(i)), + first_input_chunks[i], + ); + out_buffer.set_wire( + local_wire(self.gate.wire_second_chunk_val(i)), + second_input_chunks[i], + ); + out_buffer.set_wire( + local_wire(self.gate.wire_equality_dummy(i)), + equality_dummies[i], + ); + out_buffer.set_wire(local_wire(self.gate.wire_chunks_equal(i)), chunks_equal[i]); + out_buffer.set_wire( + local_wire(self.gate.wire_intermediate_value(i)), + intermediate_values[i], + ); + } + } +} + +#[cfg(test)] +mod tests { + use std::marker::PhantomData; + + use anyhow::Result; + use rand::Rng; + + use crate::field::crandall_field::CrandallField; + use crate::field::extension_field::quartic::QuarticExtension; + use crate::field::field_types::{Field, PrimeField}; + use crate::gates::assert_le::AssertLessThanGate; + use crate::gates::gate::Gate; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::hash::hash_types::HashOut; + use crate::plonk::vars::EvaluationVars; + + #[test] + fn wire_indices() { + type AG = AssertLessThanGate; + let num_bits = 40; + let num_chunks = 5; + + let gate = AG { + num_bits, + num_chunks, + _phantom: PhantomData, + }; + + assert_eq!(gate.wire_first_input(), 0); + assert_eq!(gate.wire_second_input(), 1); + assert_eq!(gate.wire_most_significant_diff(), 2); + assert_eq!(gate.wire_first_chunk_val(0), 3); + assert_eq!(gate.wire_first_chunk_val(4), 7); + assert_eq!(gate.wire_second_chunk_val(0), 8); + assert_eq!(gate.wire_second_chunk_val(4), 12); + assert_eq!(gate.wire_equality_dummy(0), 13); + assert_eq!(gate.wire_equality_dummy(4), 17); + assert_eq!(gate.wire_chunks_equal(0), 18); + assert_eq!(gate.wire_chunks_equal(4), 22); + assert_eq!(gate.wire_intermediate_value(0), 23); + assert_eq!(gate.wire_intermediate_value(4), 27); + } + + #[test] + fn low_degree() { + let num_bits = 40; + let num_chunks = 5; + + test_low_degree::(AssertLessThanGate::<_, 4>::new(num_bits, num_chunks)) + } + + #[test] + fn eval_fns() -> Result<()> { + let num_bits = 40; + let num_chunks = 5; + + test_eval_fns::(AssertLessThanGate::<_, 4>::new(num_bits, num_chunks)) + } + + #[test] + fn test_gate_constraint() { + type F = CrandallField; + type FF = QuarticExtension; + const D: usize = 4; + + let num_bits = 40; + let num_chunks = 5; + let chunk_bits = num_bits / num_chunks; + + // Returns the local wires for an AssertLessThanGate given the two inputs. + let get_wires = |first_input: F, second_input: F| -> Vec { + let mut v = Vec::new(); + + let first_input_u64 = first_input.to_canonical_u64(); + let second_input_u64 = second_input.to_canonical_u64(); + + let chunk_size = 1 << chunk_bits; + let mut first_input_chunks: Vec = (0..num_chunks) + .scan(first_input_u64, |acc, _| { + let tmp = *acc % chunk_size; + *acc /= chunk_size; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + let mut second_input_chunks: Vec = (0..num_chunks) + .scan(second_input_u64, |acc, _| { + let tmp = *acc % chunk_size; + *acc /= chunk_size; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + + let mut chunks_equal: Vec = (0..num_chunks) + .map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i])) + .collect(); + let mut equality_dummies: Vec = first_input_chunks + .iter() + .zip(second_input_chunks.iter()) + .map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) }) + .collect(); + + let mut most_significant_diff_so_far = F::ZERO; + let mut intermediate_values = Vec::new(); + for i in 0..num_chunks { + if first_input_chunks[i] != second_input_chunks[i] { + most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i]; + intermediate_values.push(F::ZERO); + } else { + intermediate_values.push(most_significant_diff_so_far); + } + } + let most_significant_diff = most_significant_diff_so_far; + + v.push(first_input); + v.push(second_input); + v.push(most_significant_diff); + v.append(&mut first_input_chunks); + v.append(&mut second_input_chunks); + v.append(&mut equality_dummies); + v.append(&mut chunks_equal); + v.append(&mut intermediate_values); + + v.iter().map(|&x| x.into()).collect::>() + }; + + let mut rng = rand::thread_rng(); + let max: u64 = 1 << num_bits - 1; + let first_input_u64 = rng.gen_range(0..max); + let second_input_u64 = { + let mut val = rng.gen_range(0..max); + while val < first_input_u64 { + val = rng.gen_range(0..max); + } + val + }; + + let first_input = F::from_canonical_u64(first_input_u64); + let second_input = F::from_canonical_u64(second_input_u64); + + let less_than_gate = AssertLessThanGate:: { + num_bits, + num_chunks, + _phantom: PhantomData, + }; + let less_than_vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(first_input, second_input), + public_inputs_hash: &HashOut::rand(), + }; + assert!( + less_than_gate + .eval_unfiltered(less_than_vars) + .iter() + .all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + + let equal_gate = AssertLessThanGate:: { + num_bits, + num_chunks, + _phantom: PhantomData, + }; + let equal_vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(first_input, first_input), + public_inputs_hash: &HashOut::rand(), + }; + assert!( + equal_gate + .eval_unfiltered(equal_vars) + .iter() + .all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } +} diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index 988086d0..7d7feb27 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -43,33 +43,42 @@ impl, const D: usize> ComparisonGate { 1 } - pub fn wire_most_significant_diff(&self) -> usize { + pub fn wire_result_bool(&self) -> usize { 2 } + pub fn wire_most_significant_diff(&self) -> usize { + 3 + } + pub fn wire_first_chunk_val(&self, chunk: usize) -> usize { debug_assert!(chunk < self.num_chunks); - 3 + chunk + 4 + chunk } pub fn wire_second_chunk_val(&self, chunk: usize) -> usize { debug_assert!(chunk < self.num_chunks); - 3 + self.num_chunks + chunk + 4 + self.num_chunks + chunk } pub fn wire_equality_dummy(&self, chunk: usize) -> usize { debug_assert!(chunk < self.num_chunks); - 3 + 2 * self.num_chunks + chunk + 4 + 2 * self.num_chunks + chunk } pub fn wire_chunks_equal(&self, chunk: usize) -> usize { debug_assert!(chunk < self.num_chunks); - 3 + 3 * self.num_chunks + chunk + 4 + 3 * self.num_chunks + chunk } pub fn wire_intermediate_value(&self, chunk: usize) -> usize { debug_assert!(chunk < self.num_chunks); - 3 + 4 * self.num_chunks + chunk + 4 + 4 * self.num_chunks + chunk + } + + /// The `bit_index`th bit of 2^n - 1 + most_significant_diff. + pub fn wire_most_significant_diff_bit(&self, bit_index: usize) -> usize { + 4 + 5 * self.num_chunks + bit_index } } @@ -137,11 +146,19 @@ impl, const D: usize> Gate for ComparisonGate let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; constraints.push(most_significant_diff - most_significant_diff_so_far); - // Range check `most_significant_diff` to be less than `chunk_size`. - let product = (0..chunk_size) - .map(|x| most_significant_diff - F::Extension::from_canonical_usize(x)) - .product(); - constraints.push(product); + let most_significant_diff_bits: Vec = (0..self.chunk_bits() + 1) + .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) + .collect(); + let bits_combined = reduce_with_powers( + &most_significant_diff_bits, + F::Extension::TWO, + ); + let two_n_minus_1 = F::Extension::from_canonical_u64(1 << self.chunk_bits()) - F::Extension::ONE; + constraints.push((two_n_minus_1 + most_significant_diff) - bits_combined); + + // Iff first < second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1. + let result_bool = vars.local_wires[self.wire_result_bool()]; + constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]); constraints } @@ -205,11 +222,19 @@ impl, const D: usize> Gate for ComparisonGate let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; constraints.push(most_significant_diff - most_significant_diff_so_far); - // Range check `most_significant_diff` to be less than `chunk_size`. - let product = (0..chunk_size) - .map(|x| most_significant_diff - F::from_canonical_usize(x)) - .product(); - constraints.push(product); + let most_significant_diff_bits: Vec = (0..self.chunk_bits() + 1) + .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) + .collect(); + let bits_combined = reduce_with_powers( + &most_significant_diff_bits, + F::TWO, + ); + let two_n_minus_1 = F::from_canonical_u64(1 << self.chunk_bits()) - F::ONE; + constraints.push((two_n_minus_1 + most_significant_diff) - bits_combined); + + // Iff first < second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1. + let result_bool = vars.local_wires[self.wire_result_bool()]; + constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]); constraints } @@ -232,11 +257,11 @@ impl, const D: usize> Gate for ComparisonGate .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) .collect(); - let chunk_base = builder.constant(F::from_canonical_usize(1 << self.chunk_bits())); - let first_chunks_combined = - reduce_with_powers_ext_recursive(builder, &first_chunks, chunk_base); - let second_chunks_combined = - reduce_with_powers_ext_recursive(builder, &second_chunks, chunk_base); + let chunk_base = builder.constant(F::from_canonical_usize(1 << self.chunk_bits())); + let first_chunks_combined = + reduce_with_powers_ext_recursive(builder, &first_chunks, chunk_base); + let second_chunks_combined = + reduce_with_powers_ext_recursive(builder, &second_chunks, chunk_base); constraints.push(builder.sub_extension(first_chunks_combined, first_input)); constraints.push(builder.sub_extension(second_chunks_combined, second_input)); @@ -285,14 +310,22 @@ impl, const D: usize> Gate for ComparisonGate constraints .push(builder.sub_extension(most_significant_diff, most_significant_diff_so_far)); - // Range check `most_significant_diff` to be less than `chunk_size`. - let mut product = builder.one_extension(); - for x in 0..chunk_size { - let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x)); - let diff = builder.sub_extension(most_significant_diff, x_f); - product = builder.mul_extension(product, diff); - } - constraints.push(product); + let most_significant_diff_bits: Vec> = (0..self.chunk_bits() + 1) + .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) + .collect(); + let two = builder.two(); + let bits_combined = reduce_with_powers_ext_recursive( + builder, + &most_significant_diff_bits, + two, + ); + let two_n_minus_1 = builder.constant_extension(F::Extension::from_canonical_u64(1 << self.chunk_bits()) - F::Extension::ONE); + let sum = builder.add_extension(two_n_minus_1, most_significant_diff); + constraints.push(builder.sub_extension(sum, bits_combined)); + + // Iff first < second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1. + let result_bool = vars.local_wires[self.wire_result_bool()]; + constraints.push(builder.sub_extension(result_bool, most_significant_diff_bits[self.chunk_bits()])); constraints } @@ -310,7 +343,7 @@ impl, const D: usize> Gate for ComparisonGate } fn num_wires(&self) -> usize { - self.wire_intermediate_value(self.num_chunks - 1) + 1 + 4 + 5 * self.num_chunks + (self.chunk_bits() + 1) } fn num_constants(&self) -> usize { @@ -322,7 +355,7 @@ impl, const D: usize> Gate for ComparisonGate } fn num_constraints(&self) -> usize { - 4 + 5 * self.num_chunks + 5 + 5 * self.num_chunks } } @@ -358,7 +391,7 @@ impl, const D: usize> SimpleGenerator let first_input_u64 = first_input.to_canonical_u64(); let second_input_u64 = second_input.to_canonical_u64(); - debug_assert!(first_input_u64 < second_input_u64); + let result = F::from_canonical_usize((first_input_u64 < second_input_u64) as usize); let chunk_size = 1 << self.gate.chunk_bits(); let first_input_chunks: Vec = (0..self.gate.num_chunks) @@ -397,6 +430,19 @@ impl, const D: usize> SimpleGenerator } let most_significant_diff = most_significant_diff_so_far; + let two_n_plus_msd = ((1 << self.gate.chunk_bits()) - 1) as u64 + most_significant_diff.to_canonical_u64(); + let msd_bits: Vec = (0..self.gate.chunk_bits() + 1) + .scan(two_n_plus_msd, |acc, _| { + let tmp = *acc % 2; + *acc /= 2; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + + out_buffer.set_wire( + local_wire(self.gate.wire_result_bool()), + result, + ); out_buffer.set_wire( local_wire(self.gate.wire_most_significant_diff()), most_significant_diff, @@ -420,6 +466,12 @@ impl, const D: usize> SimpleGenerator intermediate_values[i], ); } + for i in 0..self.gate.chunk_bits() + 1 { + out_buffer.set_wire( + local_wire(self.gate.wire_most_significant_diff_bit(i)), + msd_bits[i], + ); + } } } @@ -453,17 +505,20 @@ mod tests { assert_eq!(gate.wire_first_input(), 0); assert_eq!(gate.wire_second_input(), 1); - assert_eq!(gate.wire_most_significant_diff(), 2); - assert_eq!(gate.wire_first_chunk_val(0), 3); - assert_eq!(gate.wire_first_chunk_val(4), 7); - assert_eq!(gate.wire_second_chunk_val(0), 8); - assert_eq!(gate.wire_second_chunk_val(4), 12); - assert_eq!(gate.wire_equality_dummy(0), 13); - assert_eq!(gate.wire_equality_dummy(4), 17); - assert_eq!(gate.wire_chunks_equal(0), 18); - assert_eq!(gate.wire_chunks_equal(4), 22); - assert_eq!(gate.wire_intermediate_value(0), 23); - assert_eq!(gate.wire_intermediate_value(4), 27); + assert_eq!(gate.wire_result_bool(), 2); + assert_eq!(gate.wire_most_significant_diff(), 3); + assert_eq!(gate.wire_first_chunk_val(0), 4); + assert_eq!(gate.wire_first_chunk_val(4), 8); + assert_eq!(gate.wire_second_chunk_val(0), 9); + assert_eq!(gate.wire_second_chunk_val(4), 13); + assert_eq!(gate.wire_equality_dummy(0), 14); + assert_eq!(gate.wire_equality_dummy(4), 18); + assert_eq!(gate.wire_chunks_equal(0), 19); + assert_eq!(gate.wire_chunks_equal(4), 23); + assert_eq!(gate.wire_intermediate_value(0), 24); + assert_eq!(gate.wire_intermediate_value(4), 28); + assert_eq!(gate.wire_most_significant_diff_bit(0), 29); + assert_eq!(gate.wire_most_significant_diff_bit(8), 37); } #[test] @@ -499,6 +554,8 @@ mod tests { let first_input_u64 = first_input.to_canonical_u64(); let second_input_u64 = second_input.to_canonical_u64(); + let result_bool = F::from_bool(first_input_u64 < second_input_u64); + let chunk_size = 1 << chunk_bits; let mut first_input_chunks: Vec = (0..num_chunks) .scan(first_input_u64, |acc, _| { @@ -536,14 +593,25 @@ mod tests { } let most_significant_diff = most_significant_diff_so_far; + let two_n_min_1_plus_msd = ((1 << chunk_bits) - 1) as u64 + most_significant_diff.to_canonical_u64(); + let mut msd_bits: Vec = (0..chunk_bits + 1) + .scan(two_n_min_1_plus_msd, |acc, _| { + let tmp = *acc % 2; + *acc /= 2; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + v.push(first_input); v.push(second_input); + v.push(result_bool); v.push(most_significant_diff); v.append(&mut first_input_chunks); v.append(&mut second_input_chunks); v.append(&mut equality_dummies); v.append(&mut chunks_equal); v.append(&mut intermediate_values); + v.append(&mut msd_bits); v.iter().map(|&x| x.into()).collect::>() }; diff --git a/src/gates/mod.rs b/src/gates/mod.rs index 76066285..33d14ca3 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -4,6 +4,7 @@ pub mod arithmetic; pub mod arithmetic_u32; pub mod base_sum; +pub mod assert_le; pub mod comparison; pub mod constant; pub mod exponentiation; From 0ff6e6e0a0976984a3bf84f84546dcf8cd0a8055 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 12 Oct 2021 11:41:52 -0700 Subject: [PATCH 024/202] fmt --- src/gadgets/arithmetic_u32.rs | 2 +- src/gadgets/nonnative.rs | 22 ++++++++++++---- src/gates/assert_le.rs | 4 ++- src/gates/comparison.rs | 49 ++++++++++++++++------------------- src/gates/mod.rs | 2 +- 5 files changed, 44 insertions(+), 35 deletions(-) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index cb8334db..2b83b03d 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; -use crate::field::field_types::RichField; use crate::field::extension_field::Extendable; +use crate::field::field_types::RichField; use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 418aae82..9762bd34 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -13,7 +13,7 @@ use crate::iop::target::Target; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::bimap::bimap_from_lists; - + pub struct ForeignFieldTarget { /// These F elements are assumed to contain 32-bit values. limbs: Vec, @@ -24,18 +24,26 @@ impl, const D: usize> CircuitBuilder { pub fn order_u32_limbs(&mut self) -> Vec { let modulus = FF::order(); let limbs = modulus.to_u32_digits(); - limbs.iter().map(|&limb| self.constant_u32(F::from_canonical_u32(limb))).collect() + limbs + .iter() + .map(|&limb| self.constant_u32(F::from_canonical_u32(limb))) + .collect() } // Add two `ForeignFieldTarget`s, which we assume are both normalized. - pub fn add_nonnative(&mut self, a: ForeignFieldTarget, b: ForeignFieldTarget) -> ForeignFieldTarget { + pub fn add_nonnative( + &mut self, + a: ForeignFieldTarget, + b: ForeignFieldTarget, + ) -> ForeignFieldTarget { let num_limbs = a.limbs.len(); debug_assert!(b.limbs.len() == num_limbs); let mut combined_limbs = self.add_virtual_u32_targets(num_limbs + 1); let mut carry = self.zero_u32(); for i in 0..num_limbs { - let (new_limb, carry) = self.add_three_u32(carry.clone(), a.limbs[i].clone(), b.limbs[i].clone()); + let (new_limb, carry) = + self.add_three_u32(carry.clone(), a.limbs[i].clone(), b.limbs[i].clone()); combined_limbs[i] = new_limb; } combined_limbs[num_limbs] = carry; @@ -51,7 +59,11 @@ impl, const D: usize> CircuitBuilder { todo!() } - pub fn mul_nonnative(&mut self, a: ForeignFieldTarget, b: ForeignFieldTarget) -> ForeignFieldTarget { + pub fn mul_nonnative( + &mut self, + a: ForeignFieldTarget, + b: ForeignFieldTarget, + ) -> ForeignFieldTarget { let num_limbs = a.limbs.len(); debug_assert!(b.limbs.len() == num_limbs); diff --git a/src/gates/assert_le.rs b/src/gates/assert_le.rs index 796f94b3..46432c03 100644 --- a/src/gates/assert_le.rs +++ b/src/gates/assert_le.rs @@ -471,7 +471,9 @@ mod tests { let num_bits = 40; let num_chunks = 5; - test_low_degree::(AssertLessThanGate::<_, 4>::new(num_bits, num_chunks)) + test_low_degree::(AssertLessThanGate::<_, 4>::new( + num_bits, num_chunks, + )) } #[test] diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index 7d7feb27..e076a154 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -149,11 +149,9 @@ impl, const D: usize> Gate for ComparisonGate let most_significant_diff_bits: Vec = (0..self.chunk_bits() + 1) .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) .collect(); - let bits_combined = reduce_with_powers( - &most_significant_diff_bits, - F::Extension::TWO, - ); - let two_n_minus_1 = F::Extension::from_canonical_u64(1 << self.chunk_bits()) - F::Extension::ONE; + let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::Extension::TWO); + let two_n_minus_1 = + F::Extension::from_canonical_u64(1 << self.chunk_bits()) - F::Extension::ONE; constraints.push((two_n_minus_1 + most_significant_diff) - bits_combined); // Iff first < second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1. @@ -225,10 +223,7 @@ impl, const D: usize> Gate for ComparisonGate let most_significant_diff_bits: Vec = (0..self.chunk_bits() + 1) .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) .collect(); - let bits_combined = reduce_with_powers( - &most_significant_diff_bits, - F::TWO, - ); + let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO); let two_n_minus_1 = F::from_canonical_u64(1 << self.chunk_bits()) - F::ONE; constraints.push((two_n_minus_1 + most_significant_diff) - bits_combined); @@ -257,11 +252,11 @@ impl, const D: usize> Gate for ComparisonGate .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) .collect(); - let chunk_base = builder.constant(F::from_canonical_usize(1 << self.chunk_bits())); - let first_chunks_combined = - reduce_with_powers_ext_recursive(builder, &first_chunks, chunk_base); - let second_chunks_combined = - reduce_with_powers_ext_recursive(builder, &second_chunks, chunk_base); + let chunk_base = builder.constant(F::from_canonical_usize(1 << self.chunk_bits())); + let first_chunks_combined = + reduce_with_powers_ext_recursive(builder, &first_chunks, chunk_base); + let second_chunks_combined = + reduce_with_powers_ext_recursive(builder, &second_chunks, chunk_base); constraints.push(builder.sub_extension(first_chunks_combined, first_input)); constraints.push(builder.sub_extension(second_chunks_combined, second_input)); @@ -314,18 +309,19 @@ impl, const D: usize> Gate for ComparisonGate .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) .collect(); let two = builder.two(); - let bits_combined = reduce_with_powers_ext_recursive( - builder, - &most_significant_diff_bits, - two, + let bits_combined = + reduce_with_powers_ext_recursive(builder, &most_significant_diff_bits, two); + let two_n_minus_1 = builder.constant_extension( + F::Extension::from_canonical_u64(1 << self.chunk_bits()) - F::Extension::ONE, ); - let two_n_minus_1 = builder.constant_extension(F::Extension::from_canonical_u64(1 << self.chunk_bits()) - F::Extension::ONE); let sum = builder.add_extension(two_n_minus_1, most_significant_diff); constraints.push(builder.sub_extension(sum, bits_combined)); // Iff first < second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1. let result_bool = vars.local_wires[self.wire_result_bool()]; - constraints.push(builder.sub_extension(result_bool, most_significant_diff_bits[self.chunk_bits()])); + constraints.push( + builder.sub_extension(result_bool, most_significant_diff_bits[self.chunk_bits()]), + ); constraints } @@ -430,7 +426,8 @@ impl, const D: usize> SimpleGenerator } let most_significant_diff = most_significant_diff_so_far; - let two_n_plus_msd = ((1 << self.gate.chunk_bits()) - 1) as u64 + most_significant_diff.to_canonical_u64(); + let two_n_plus_msd = + ((1 << self.gate.chunk_bits()) - 1) as u64 + most_significant_diff.to_canonical_u64(); let msd_bits: Vec = (0..self.gate.chunk_bits() + 1) .scan(two_n_plus_msd, |acc, _| { let tmp = *acc % 2; @@ -439,10 +436,7 @@ impl, const D: usize> SimpleGenerator }) .collect(); - out_buffer.set_wire( - local_wire(self.gate.wire_result_bool()), - result, - ); + out_buffer.set_wire(local_wire(self.gate.wire_result_bool()), result); out_buffer.set_wire( local_wire(self.gate.wire_most_significant_diff()), most_significant_diff, @@ -468,7 +462,7 @@ impl, const D: usize> SimpleGenerator } for i in 0..self.gate.chunk_bits() + 1 { out_buffer.set_wire( - local_wire(self.gate.wire_most_significant_diff_bit(i)), + local_wire(self.gate.wire_most_significant_diff_bit(i)), msd_bits[i], ); } @@ -593,7 +587,8 @@ mod tests { } let most_significant_diff = most_significant_diff_so_far; - let two_n_min_1_plus_msd = ((1 << chunk_bits) - 1) as u64 + most_significant_diff.to_canonical_u64(); + let two_n_min_1_plus_msd = + ((1 << chunk_bits) - 1) as u64 + most_significant_diff.to_canonical_u64(); let mut msd_bits: Vec = (0..chunk_bits + 1) .scan(two_n_min_1_plus_msd, |acc, _| { let tmp = *acc % 2; diff --git a/src/gates/mod.rs b/src/gates/mod.rs index 33d14ca3..96bfc4a1 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -3,8 +3,8 @@ pub mod arithmetic; pub mod arithmetic_u32; -pub mod base_sum; pub mod assert_le; +pub mod base_sum; pub mod comparison; pub mod constant; pub mod exponentiation; From 26959d11c907b289ff47d344721233e57aa935e1 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 12 Oct 2021 11:50:00 -0700 Subject: [PATCH 025/202] range-check the bits --- src/gates/comparison.rs | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index e076a154..083f923a 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -149,6 +149,14 @@ impl, const D: usize> Gate for ComparisonGate let most_significant_diff_bits: Vec = (0..self.chunk_bits() + 1) .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) .collect(); + + // Range-check the bits. + for i in 0..most_significant_diff_bits.len() { + constraints.push( + most_significant_diff_bits[i] * (F::Extension::ONE - most_significant_diff_bits[i]), + ); + } + let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::Extension::TWO); let two_n_minus_1 = F::Extension::from_canonical_u64(1 << self.chunk_bits()) - F::Extension::ONE; @@ -223,6 +231,13 @@ impl, const D: usize> Gate for ComparisonGate let most_significant_diff_bits: Vec = (0..self.chunk_bits() + 1) .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) .collect(); + + // Range-check the bits. + for i in 0..most_significant_diff_bits.len() { + constraints + .push(most_significant_diff_bits[i] * (F::ONE - most_significant_diff_bits[i])); + } + let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO); let two_n_minus_1 = F::from_canonical_u64(1 << self.chunk_bits()) - F::ONE; constraints.push((two_n_minus_1 + most_significant_diff) - bits_combined); @@ -308,6 +323,14 @@ impl, const D: usize> Gate for ComparisonGate let most_significant_diff_bits: Vec> = (0..self.chunk_bits() + 1) .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) .collect(); + + // Range-check the bits. + for i in 0..most_significant_diff_bits.len() { + let this_bit = most_significant_diff_bits[i]; + let inverse = builder.sub_extension(one, this_bit); + constraints.push(builder.mul_extension(this_bit, inverse)); + } + let two = builder.two(); let bits_combined = reduce_with_powers_ext_recursive(builder, &most_significant_diff_bits, two); @@ -351,7 +374,7 @@ impl, const D: usize> Gate for ComparisonGate } fn num_constraints(&self) -> usize { - 5 + 5 * self.num_chunks + 6 + 5 * self.num_chunks + self.chunk_bits() } } From 6dd14eb27ade6ba91e5222e83ac42caaa2e6824d Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 12 Oct 2021 13:31:05 -0700 Subject: [PATCH 026/202] comparison gate should also be <= --- src/gates/comparison.rs | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index 083f923a..9783a42c 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -158,11 +158,10 @@ impl, const D: usize> Gate for ComparisonGate } let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::Extension::TWO); - let two_n_minus_1 = - F::Extension::from_canonical_u64(1 << self.chunk_bits()) - F::Extension::ONE; - constraints.push((two_n_minus_1 + most_significant_diff) - bits_combined); + let two_n = F::Extension::from_canonical_u64(1 << self.chunk_bits()); + constraints.push((two_n + most_significant_diff) - bits_combined); - // Iff first < second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1. + // Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1. let result_bool = vars.local_wires[self.wire_result_bool()]; constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]); @@ -239,10 +238,10 @@ impl, const D: usize> Gate for ComparisonGate } let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO); - let two_n_minus_1 = F::from_canonical_u64(1 << self.chunk_bits()) - F::ONE; - constraints.push((two_n_minus_1 + most_significant_diff) - bits_combined); + let two_n = F::from_canonical_u64(1 << self.chunk_bits()); + constraints.push((two_n + most_significant_diff) - bits_combined); - // Iff first < second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1. + // Iff first <= second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1. let result_bool = vars.local_wires[self.wire_result_bool()]; constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]); @@ -334,13 +333,12 @@ impl, const D: usize> Gate for ComparisonGate let two = builder.two(); let bits_combined = reduce_with_powers_ext_recursive(builder, &most_significant_diff_bits, two); - let two_n_minus_1 = builder.constant_extension( - F::Extension::from_canonical_u64(1 << self.chunk_bits()) - F::Extension::ONE, - ); - let sum = builder.add_extension(two_n_minus_1, most_significant_diff); + let two_n = + builder.constant_extension(F::Extension::from_canonical_u64(1 << self.chunk_bits())); + let sum = builder.add_extension(two_n, most_significant_diff); constraints.push(builder.sub_extension(sum, bits_combined)); - // Iff first < second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1. + // Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1. let result_bool = vars.local_wires[self.wire_result_bool()]; constraints.push( builder.sub_extension(result_bool, most_significant_diff_bits[self.chunk_bits()]), @@ -410,7 +408,7 @@ impl, const D: usize> SimpleGenerator let first_input_u64 = first_input.to_canonical_u64(); let second_input_u64 = second_input.to_canonical_u64(); - let result = F::from_canonical_usize((first_input_u64 < second_input_u64) as usize); + let result = F::from_canonical_usize((first_input_u64 <= second_input_u64) as usize); let chunk_size = 1 << self.gate.chunk_bits(); let first_input_chunks: Vec = (0..self.gate.num_chunks) @@ -450,7 +448,7 @@ impl, const D: usize> SimpleGenerator let most_significant_diff = most_significant_diff_so_far; let two_n_plus_msd = - ((1 << self.gate.chunk_bits()) - 1) as u64 + most_significant_diff.to_canonical_u64(); + (1 << self.gate.chunk_bits()) as u64 + most_significant_diff.to_canonical_u64(); let msd_bits: Vec = (0..self.gate.chunk_bits() + 1) .scan(two_n_plus_msd, |acc, _| { let tmp = *acc % 2; @@ -571,7 +569,7 @@ mod tests { let first_input_u64 = first_input.to_canonical_u64(); let second_input_u64 = second_input.to_canonical_u64(); - let result_bool = F::from_bool(first_input_u64 < second_input_u64); + let result_bool = F::from_bool(first_input_u64 <= second_input_u64); let chunk_size = 1 << chunk_bits; let mut first_input_chunks: Vec = (0..num_chunks) @@ -610,10 +608,10 @@ mod tests { } let most_significant_diff = most_significant_diff_so_far; - let two_n_min_1_plus_msd = - ((1 << chunk_bits) - 1) as u64 + most_significant_diff.to_canonical_u64(); + let two_n_plus_msd = + (1 << chunk_bits) as u64 + most_significant_diff.to_canonical_u64(); let mut msd_bits: Vec = (0..chunk_bits + 1) - .scan(two_n_min_1_plus_msd, |acc, _| { + .scan(two_n_plus_msd, |acc, _| { let tmp = *acc % 2; *acc /= 2; Some(F::from_canonical_u64(tmp)) From bdfe124b0cceda186fcd62f5d44bc6398e14442d Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 12 Oct 2021 14:20:39 -0700 Subject: [PATCH 027/202] multiple comparison --- src/gadgets/multiple_comparison.rs | 61 ++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 src/gadgets/multiple_comparison.rs diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs new file mode 100644 index 00000000..5291323d --- /dev/null +++ b/src/gadgets/multiple_comparison.rs @@ -0,0 +1,61 @@ +use std::marker::PhantomData; + +use itertools::izip; + +use crate::field::extension_field::Extendable; +use crate::field::field_types::{Field, RichField}; +use crate::gates::comparison::ComparisonGate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator}; +use crate::iop::target::{BoolTarget, Target}; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::util::ceil_div_usize; + +impl, const D: usize> CircuitBuilder { + /// Returns true if a is less than or equal to b, considered as limbs of a large value. + pub fn compare_lists(&mut self, a: Vec, b: Vec, num_bits: usize) -> BoolTarget { + assert_eq!( + a.len(), + b.len(), + "Permutation must have same number of inputs and outputs" + ); + let n = a.len(); + + let chunk_size = 4; + let num_chunks = ceil_div_usize(num_bits, 4); + + let one = self.one(); + let mut result = self.one(); + for i in 0..n { + let a_le_b_gate = ComparisonGate::new(num_bits, num_chunks); + let a_le_b_gate_index = self.add_gate(a_le_b_gate.clone(), vec![]); + self.connect( + Target::wire(a_le_b_gate_index, a_le_b_gate.wire_first_input()), + a[i], + ); + self.connect( + Target::wire(a_le_b_gate_index, a_le_b_gate.wire_second_input()), + b[i], + ); + let a_le_b_result = Target::wire(a_le_b_gate_index, a_le_b_gate.wire_result_bool()); + + let b_le_a_gate = ComparisonGate::new(num_bits, num_chunks); + let b_le_a_gate_index = self.add_gate(b_le_a_gate.clone(), vec![]); + self.connect( + Target::wire(b_le_a_gate_index, b_le_a_gate.wire_first_input()), + b[i], + ); + self.connect( + Target::wire(b_le_a_gate_index, b_le_a_gate.wire_second_input()), + a[i], + ); + let b_le_a_result = Target::wire(b_le_a_gate_index, b_le_a_gate.wire_result_bool()); + + let these_limbs_equal = self.mul(a_le_b_result, b_le_a_result); + let these_limbs_less_than = self.sub(one, b_le_a_result); + result = self.mul_add(these_limbs_equal, result, these_limbs_less_than); + } + + BoolTarget::new_unsafe(result) + } +} From 3fff08aa80e790b4fcb34b3d2e024ae515f263f3 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 13 Oct 2021 14:01:42 -0700 Subject: [PATCH 028/202] U32 subtraction gate --- src/gates/mod.rs | 1 + src/gates/subtraction_u32.rs | 427 +++++++++++++++++++++++++++++++++++ 2 files changed, 428 insertions(+) create mode 100644 src/gates/subtraction_u32.rs diff --git a/src/gates/mod.rs b/src/gates/mod.rs index 96bfc4a1..b1a6028e 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -19,6 +19,7 @@ pub(crate) mod poseidon_mds; pub(crate) mod public_input; pub mod random_access; pub mod reducing; +pub mod subtraction_u32; pub mod switch; #[cfg(test)] diff --git a/src/gates/subtraction_u32.rs b/src/gates/subtraction_u32.rs new file mode 100644 index 00000000..de79b67f --- /dev/null +++ b/src/gates/subtraction_u32.rs @@ -0,0 +1,427 @@ +use std::marker::PhantomData; + +use itertools::unfold; + +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::field_types::{Field, RichField}; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::wire::Wire; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// Number of arithmetic operations performed by an arithmetic gate. +pub const NUM_U32_SUBTRACTION_OPS: usize = 3; + +/// A gate to perform a subtraction . +#[derive(Clone, Debug)] +pub struct U32SubtractionGate, const D: usize> { + _phantom: PhantomData, +} + +impl, const D: usize> U32SubtractionGate { + pub fn new() -> Self { + Self { + _phantom: PhantomData, + } + } + + pub fn wire_ith_input_x(i: usize) -> usize { + debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + 5 * i + } + pub fn wire_ith_input_y(i: usize) -> usize { + debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + 5 * i + 1 + } + pub fn wire_ith_input_borrow(i: usize) -> usize { + debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + 5 * i + 2 + } + + pub fn wire_ith_output_result(i: usize) -> usize { + debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + 5 * i + 3 + } + pub fn wire_ith_output_borrow(i: usize) -> usize { + debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + 5 * i + 4 + } + + // We have limbs ony for the first half of the output. + pub fn limb_bits() -> usize { + 2 + } + pub fn num_limbs() -> usize { + 32 / Self::limb_bits() + } + + pub fn wire_ith_output_jth_limb(i: usize, j: usize) -> usize { + debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + debug_assert!(j < Self::num_limbs()); + 5 * NUM_U32_SUBTRACTION_OPS + Self::num_limbs() * i + j + } +} + +impl, const D: usize> Gate for U32SubtractionGate { + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + for i in 0..NUM_U32_SUBTRACTION_OPS { + let input_x = vars.local_wires[Self::wire_ith_input_x(i)]; + let input_y = vars.local_wires[Self::wire_ith_input_y(i)]; + let input_borrow = vars.local_wires[Self::wire_ith_input_borrow(i)]; + + let result_initial = input_x - input_y - input_borrow; + let base = F::Extension::from_canonical_u64(1 << 32u64); + + let output_result = vars.local_wires[Self::wire_ith_output_result(i)]; + let output_borrow = vars.local_wires[Self::wire_ith_output_borrow(i)]; + + constraints.push(output_result - (result_initial + base * output_borrow)); + + // Range-check output_result to be at most 32 bits. + let mut combined_limbs = F::Extension::ZERO; + let limb_base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits()); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let product = (0..max_limb) + .map(|x| this_limb - F::Extension::from_canonical_usize(x)) + .product(); + constraints.push(product); + + combined_limbs = limb_base * combined_limbs + this_limb; + } + constraints.push(combined_limbs - output_result); + + // Range-check output_borrow to be one bit. + constraints.push(output_borrow * (F::Extension::ONE - output_borrow)); + } + + constraints + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + for i in 0..NUM_U32_SUBTRACTION_OPS { + let input_x = vars.local_wires[Self::wire_ith_input_x(i)]; + let input_y = vars.local_wires[Self::wire_ith_input_y(i)]; + let input_borrow = vars.local_wires[Self::wire_ith_input_borrow(i)]; + + let result_initial = input_x - input_y - input_borrow; + let base = F::from_canonical_u64(1 << 32u64); + + let output_result = vars.local_wires[Self::wire_ith_output_result(i)]; + let output_borrow = vars.local_wires[Self::wire_ith_output_borrow(i)]; + + constraints.push(output_result - (result_initial + base * output_borrow)); + + // Range-check output_result to be at most 32 bits. + let mut combined_limbs = F::ZERO; + let limb_base = F::from_canonical_u64(1u64 << Self::limb_bits()); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let product = (0..max_limb) + .map(|x| this_limb - F::from_canonical_usize(x)) + .product(); + constraints.push(product); + + combined_limbs = limb_base * combined_limbs + this_limb; + } + constraints.push(combined_limbs - output_result); + + // Range-check output_borrow to be one bit. + constraints.push(output_borrow * (F::ONE - output_borrow)); + } + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + for i in 0..NUM_U32_SUBTRACTION_OPS { + let input_x = vars.local_wires[Self::wire_ith_input_x(i)]; + let input_y = vars.local_wires[Self::wire_ith_input_y(i)]; + let input_borrow = vars.local_wires[Self::wire_ith_input_borrow(i)]; + + let diff = builder.sub_extension(input_x, input_y); + let result_initial = builder.sub_extension(diff, input_borrow); + let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << 32u64)); + + let output_result = vars.local_wires[Self::wire_ith_output_result(i)]; + let output_borrow = vars.local_wires[Self::wire_ith_output_borrow(i)]; + + let computed_output = builder.mul_add_extension(base, output_borrow, result_initial); + constraints.push(builder.sub_extension(output_result, computed_output)); + + // Range-check output_result to be at most 32 bits. + let mut combined_limbs = builder.zero_extension(); + let limb_base = builder.constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let mut product = builder.one_extension(); + for x in 0..max_limb { + let x_target = + builder.constant_extension(F::Extension::from_canonical_usize(x)); + let diff = builder.sub_extension(this_limb, x_target); + product = builder.mul_extension(product, diff); + } + constraints.push(product); + + combined_limbs = builder.mul_add_extension(limb_base, combined_limbs, this_limb); + } + constraints.push(builder.sub_extension(combined_limbs, output_result)); + + // Range-check output_borrow to be one bit. + let one = builder.one_extension(); + let not_borrow = builder.sub_extension(one, output_borrow); + constraints.push(builder.mul_extension(output_borrow, not_borrow)); + } + + constraints + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + (0..NUM_U32_SUBTRACTION_OPS) + .map(|i| { + let g: Box> = Box::new( + U32SubtractionGenerator { + gate_index, + i, + _phantom: PhantomData, + } + .adapter(), + ); + g + }) + .collect::>() + } + + fn num_wires(&self) -> usize { + NUM_U32_SUBTRACTION_OPS * (5 + Self::num_limbs()) + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 1 << Self::limb_bits() + } + + fn num_constraints(&self) -> usize { + NUM_U32_SUBTRACTION_OPS * (3 + Self::num_limbs()) + } +} + +#[derive(Clone, Debug)] +struct U32SubtractionGenerator, const D: usize> { + gate_index: usize, + i: usize, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for U32SubtractionGenerator +{ + fn dependencies(&self) -> Vec { + let local_target = |input| Target::wire(self.gate_index, input); + + let mut deps = Vec::with_capacity(3); + deps.push(local_target( + U32SubtractionGate::::wire_ith_input_x(self.i), + )); + deps.push(local_target( + U32SubtractionGate::::wire_ith_input_y(self.i), + )); + deps.push(local_target(U32SubtractionGate::::wire_ith_input_borrow( + self.i, + ))); + deps + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |input| Wire { + gate: self.gate_index, + input, + }; + + let get_local_wire = |input| witness.get_wire(local_wire(input)); + + let input_x = + get_local_wire(U32SubtractionGate::::wire_ith_input_x(self.i)); + let input_y = + get_local_wire(U32SubtractionGate::::wire_ith_input_y(self.i)); + let input_borrow = get_local_wire(U32SubtractionGate::::wire_ith_input_borrow(self.i)); + + let result_initial = input_x - input_y - input_borrow; + let result_initial_u64 = result_initial.to_canonical_u64(); + let output_borrow = if result_initial_u64 > 1 << 32u64 { + F::ONE + } else { + F::ZERO + }; + + let base = F::from_canonical_u64(1 << 32u64); + let output_result = result_initial + base * output_borrow; + + let output_result_wire = + local_wire(U32SubtractionGate::::wire_ith_output_result(self.i)); + let output_borrow_wire = + local_wire(U32SubtractionGate::::wire_ith_output_borrow(self.i)); + + out_buffer.set_wire(output_result_wire, output_result); + out_buffer.set_wire(output_borrow_wire, output_borrow); + + let output_result_u64 = output_result.to_canonical_u64(); + + let num_limbs = U32SubtractionGate::::num_limbs(); + let limb_base = 1 << U32SubtractionGate::::limb_bits(); + let output_limbs: Vec<_> = (0..num_limbs) + .scan(output_result_u64, |acc, _| { + let tmp = *acc % limb_base; + *acc /= limb_base; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + + for j in 0..num_limbs { + let wire = local_wire(U32SubtractionGate::::wire_ith_output_jth_limb( + self.i, j, + )); + out_buffer.set_wire(wire, output_limbs[j]); + } + } +} + +#[cfg(test)] +mod tests { + use std::marker::PhantomData; + + use anyhow::Result; + use rand::Rng; + + use crate::field::crandall_field::CrandallField; + use crate::field::extension_field::quartic::QuarticExtension; + use crate::field::field_types::{Field, PrimeField}; + use crate::gates::subtraction_u32::{U32SubtractionGate, NUM_U32_SUBTRACTION_OPS}; + use crate::gates::gate::Gate; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::hash::hash_types::HashOut; + use crate::plonk::vars::EvaluationVars; + + #[test] + fn low_degree() { + test_low_degree::(U32SubtractionGate:: { + _phantom: PhantomData, + }) + } + + #[test] + fn eval_fns() -> Result<()> { + test_eval_fns::(U32SubtractionGate:: { + _phantom: PhantomData, + }) + } + + #[test] + fn test_gate_constraint() { + type F = CrandallField; + type FF = QuarticExtension; + const D: usize = 4; + + fn get_wires( + inputs_x: Vec, + inputs_y: Vec, + borrows: Vec, + ) -> Vec { + let mut v0 = Vec::new(); + let mut v1 = Vec::new(); + + let limb_bits = U32SubtractionGate::::limb_bits(); + let num_limbs = U32SubtractionGate::::num_limbs(); + let limb_base = 1 << limb_bits; + for c in 0..NUM_U32_SUBTRACTION_OPS { + let input_x = F::from_canonical_u64(inputs_x[c]); + let input_y = F::from_canonical_u64(inputs_y[c]); + let input_borrow = F::from_canonical_u64(borrows[c]); + + let result_initial = input_x - input_y - input_borrow; + let result_initial_u64 = result_initial.to_canonical_u64(); + let output_borrow = if result_initial_u64 > 1 << 32u64 { + F::ONE + } else { + F::ZERO + }; + + let base = F::from_canonical_u64(1 << 32u64); + let output_result = result_initial + base * output_borrow; + + let output_result_u64 = output_result.to_canonical_u64(); + + let mut output_limbs: Vec<_> = (0..num_limbs) + .scan(output_result_u64, |acc, _| { + let tmp = *acc % limb_base; + *acc /= limb_base; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + + v0.push(input_x); + v0.push(input_y); + v0.push(input_borrow); + v0.push(output_result); + v0.push(output_borrow); + v1.append(&mut output_limbs); + } + + v0.iter() + .chain(v1.iter()) + .map(|&x| x.into()) + .collect::>() + } + + let mut rng = rand::thread_rng(); + let inputs_x: Vec<_> = (0..NUM_U32_SUBTRACTION_OPS) + .map(|_| rng.gen::() as u64) + .collect(); + let inputs_y: Vec<_> = (0..NUM_U32_SUBTRACTION_OPS) + .map(|_| rng.gen::() as u64) + .collect(); + let borrows: Vec<_> = (0..NUM_U32_SUBTRACTION_OPS) + .map(|_| rng.gen::() as u64) + .collect(); + + let gate = U32SubtractionGate:: { + _phantom: PhantomData, + }; + + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(inputs_x, inputs_y, borrows), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } +} From 18567e570bb3bd363993ca36f0d3a5c2c46c077b Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 09:56:21 -0800 Subject: [PATCH 029/202] merge --- src/gadgets/arithmetic_extension.rs | 27 ---- src/gadgets/arithmetic_u32.rs | 55 ++++++-- src/gadgets/permutation.rs | 21 +-- src/plonk/circuit_builder.rs | 198 +++++++++++++++++++++------- 4 files changed, 196 insertions(+), 105 deletions(-) diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 24499760..e2654dcc 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -12,33 +12,6 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::bits_u64; impl, const D: usize> CircuitBuilder { - /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. - /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index - /// `g` and the gate's `i`-th operation is available. - fn find_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) { - let (gate, i) = self - .free_arithmetic - .get(&(const_0, const_1)) - .copied() - .unwrap_or_else(|| { - let gate = self.add_gate( - ArithmeticExtensionGate::new_from_config(&self.config), - vec![const_0, const_1], - ); - (gate, 0) - }); - - // Update `free_arithmetic` with new values. - if i < ArithmeticExtensionGate::::num_ops(&self.config) - 1 { - self.free_arithmetic - .insert((const_0, const_1), (gate, i + 1)); - } else { - self.free_arithmetic.remove(&(const_0, const_1)); - } - - (gate, i) - } - pub fn arithmetic_extension( &mut self, const_0: F, diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index 2b83b03d..4f60dde7 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -36,14 +36,7 @@ impl, const D: usize> CircuitBuilder { y: U32Target, z: U32Target, ) -> (U32Target, U32Target) { - let (gate_index, copy) = match self.current_u32_arithmetic_gate { - None => { - let gate = U32ArithmeticGate::new(); - let gate_index = self.add_gate(gate, vec![]); - (gate_index, 0) - } - Some((gate_index, copy)) => (gate_index, copy), - }; + let (gate_index, copy) = self.find_u32_arithmetic_gate(); self.connect( Target::wire( @@ -73,12 +66,6 @@ impl, const D: usize> CircuitBuilder { U32ArithmeticGate::::wire_ith_output_high_half(copy), )); - if copy == NUM_U32_ARITHMETIC_OPS - 1 { - self.current_u32_arithmetic_gate = None; - } else { - self.current_u32_arithmetic_gate = Some((gate_index, copy + 1)); - } - (output_low, output_high) } @@ -103,4 +90,44 @@ impl, const D: usize> CircuitBuilder { let zero = self.zero_u32(); self.mul_add_u32(a, b, zero) } + + // Returns x * y + z. + pub fn sub_u32( + &mut self, + x: U32Target, + y: U32Target, + borrow: U32Target, + ) -> (U32Target, U32Target) { + let (gate_index, copy) = self.find_u32_subtraction_gate(); + + self.connect( + Target::wire( + gate_index, + U32ArithmeticGate::::wire_ith_multiplicand_0(copy), + ), + x.0, + ); + self.connect( + Target::wire( + gate_index, + U32ArithmeticGate::::wire_ith_multiplicand_1(copy), + ), + y.0, + ); + self.connect( + Target::wire(gate_index, U32ArithmeticGate::::wire_ith_addend(copy)), + z.0, + ); + + let output_low = U32Target(Target::wire( + gate_index, + U32ArithmeticGate::::wire_ith_output_low_half(copy), + )); + let output_high = U32Target(Target::wire( + gate_index, + U32ArithmeticGate::::wire_ith_output_high_half(copy), + )); + + (output_low, output_high) + } } diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs index fd4a897f..a0c9b087 100644 --- a/src/gadgets/permutation.rs +++ b/src/gadgets/permutation.rs @@ -73,20 +73,8 @@ impl, const D: usize> CircuitBuilder { let chunk_size = a1.len(); - if self.current_switch_gates.len() < chunk_size { - self.current_switch_gates - .extend(vec![None; chunk_size - self.current_switch_gates.len()]); - } - let (gate, gate_index, mut next_copy) = - match self.current_switch_gates[chunk_size - 1].clone() { - None => { - let gate = SwitchGate::::new_from_config(&self.config, chunk_size); - let gate_index = self.add_gate(gate.clone(), vec![]); - (gate, gate_index, 0) - } - Some((gate, idx, next_copy)) => (gate, idx, next_copy), - }; + self.find_switch_gate(chunk_size); let num_copies = gate.num_copies; @@ -113,13 +101,6 @@ impl, const D: usize> CircuitBuilder { let switch = Target::wire(gate_index, gate.wire_switch_bool(next_copy)); - next_copy += 1; - if next_copy == num_copies { - self.current_switch_gates[chunk_size - 1] = None; - } else { - self.current_switch_gates[chunk_size - 1] = Some((gate, gate_index, next_copy)); - } - (switch, c, d) } diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 019bc71e..f915493a 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -15,6 +15,8 @@ use crate::fri::{FriConfig, FriParams}; use crate::gadgets::arithmetic_extension::ArithmeticOperation; use crate::gadgets::arithmetic_u32::U32Target; use crate::gates::arithmetic::ArithmeticExtensionGate; +use crate::gates::arithmetic_u32::{NUM_U32_ARITHMETIC_OPS, U32ArithmeticGate}; +use crate::gates::subtraction_u32::{NUM_U32_SUBTRACTION_OPS, U32SubtractionGate}; use crate::gates::constant::ConstantGate; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; use crate::gates::gate_tree::Tree; @@ -75,24 +77,7 @@ pub struct CircuitBuilder, const D: usize> { /// Memoized results of `arithmetic_extension` calls. pub(crate) arithmetic_results: HashMap, ExtensionTarget>, - /// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using - /// these constants with gate index `g` and already using `i` arithmetic operations. - pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>, - - /// A map `(c0, c1) -> (g, i)` from constants `vec_size` to an available arithmetic gate using - /// these constants with gate index `g` and already using `i` random accesses. - pub(crate) free_random_access: HashMap, - - // `current_switch_gates[chunk_size - 1]` contains None if we have no switch gates with the value - // chunk_size, and contains `(g, i, c)`, if the gate `g`, at index `i`, already contains `c` copies - // of switches - pub(crate) current_switch_gates: Vec, usize, usize)>>, - - /// The `U32ArithmeticGate` currently being filled (so new u32 arithmetic operations will be added to this gate before creating a new one) - pub(crate) current_u32_arithmetic_gate: Option<(usize, usize)>, - - /// An available `ConstantGate` instance, if any. - free_constant: Option<(usize, usize)>, + batched_gates: BatchedGates } impl, const D: usize> CircuitBuilder { @@ -110,11 +95,7 @@ impl, const D: usize> CircuitBuilder { constants_to_targets: HashMap::new(), arithmetic_results: HashMap::new(), targets_to_constants: HashMap::new(), - free_arithmetic: HashMap::new(), - free_random_access: HashMap::new(), - current_switch_gates: Vec::new(), - current_u32_arithmetic_gate: None, - free_constant: None, + batched_gates: BatchedGates::new(), }; builder.check_config(); builder @@ -308,7 +289,7 @@ impl, const D: usize> CircuitBuilder { return target; } - let (gate, instance) = self.constant_gate_instance(); + let (gate, instance) = self.batched_gates.constant_gate_instance(); let target = Target::wire(gate, instance); self.gate_instances[gate].constants[instance] = c; @@ -318,26 +299,6 @@ impl, const D: usize> CircuitBuilder { target } - /// Returns the gate index and copy index of a free `ConstantGate` slot, potentially adding a - /// new `ConstantGate` if needed. - fn constant_gate_instance(&mut self) -> (usize, usize) { - if self.free_constant.is_none() { - let num_consts = self.config.constant_gate_size; - // We will fill this `ConstantGate` with zero constants initially. - // These will be overwritten by `constant` as the gate instances are filled. - let gate = self.add_gate(ConstantGate { num_consts }, vec![F::ZERO; num_consts]); - self.free_constant = Some((gate, 0)); - } - - let (gate, instance) = self.free_constant.unwrap(); - if instance + 1 < self.config.constant_gate_size { - self.free_constant = Some((gate, instance + 1)); - } else { - self.free_constant = None; - } - (gate, instance) - } - pub fn constants(&mut self, constants: &[F]) -> Vec { constants.iter().map(|&c| self.constant(c)).collect() } @@ -846,3 +807,152 @@ impl, const D: usize> CircuitBuilder { } } } + +/// +pub struct BatchedGates, const D: usize> { + /// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using + /// these constants with gate index `g` and already using `i` arithmetic operations. + pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>, + + /// `current_switch_gates[chunk_size - 1]` contains None if we have no switch gates with the value + /// chunk_size, and contains `(g, i, c)`, if the gate `g`, at index `i`, already contains `c` copies + /// of switches + pub(crate) current_switch_gates: Vec, usize, usize)>>, + + /// The `U32ArithmeticGate` currently being filled (so new u32 arithmetic operations will be added to this gate before creating a new one) + pub(crate) current_u32_arithmetic_gate: Option<(usize, usize)>, + + /// The `U32SubtractionGate` currently being filled (so new u32 subtraction operations will be added to this gate before creating a new one) + pub(crate) current_u32_subtraction_gate: Option<(usize, usize)>, + + /// An available `ConstantGate` instance, if any. + pub(crate) free_constant: Option<(usize, usize)>, +} + +impl, const D: usize> BatchedGates { + pub fn new() -> Self { + Self { + free_arithmetic: HashMap::new(), + current_switch_gates: Vec::new(), + current_u32_arithmetic_gate: None, + current_u32_subtraction_gate: None, + free_constant: None, + } + } +} + +impl, const D: usize> CircuitBuilder { + /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. + /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index + /// `g` and the gate's `i`-th operation is available. + pub fn find_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_arithmetic + .get(&(const_0, const_1)) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + ArithmeticExtensionGate::new_from_config(&self.config), + vec![const_0, const_1], + ); + (gate, 0) + }); + + // Update `free_arithmetic` with new values. + if i < ArithmeticExtensionGate::::num_ops(&self.config) - 1 { + self.batched_gates + .free_arithmetic + .insert((const_0, const_1), (gate, i + 1)); + } else { + self.batched_gates.free_arithmetic.remove(&(const_0, const_1)); + } + + (gate, i) + } + + pub fn find_switch_gate(&mut self, chunk_size: usize) -> (SwitchGate, usize, usize) { + if self.batched_gates.current_switch_gates.len() < chunk_size { + self.batched_gates.current_switch_gates + .extend(vec![None; chunk_size - self.current_switch_gates.len()]); + } + + let (gate, gate_index, mut next_copy) = + match self.current_switch_gates[chunk_size - 1].clone() { + None => { + let gate = SwitchGate::::new_from_config(self.config.clone(), chunk_size); + let gate_index = self.add_gate(gate.clone(), vec![]); + (gate, gate_index, 0) + } + Some((gate, idx, next_copy)) => (gate, idx, next_copy), + }; + + let num_copies = gate.num_copies; + + if next_copy == num_copies { + self.batched_gates.current_switch_gates[chunk_size - 1] = None; + } else { + self.batched_gates.current_switch_gates[chunk_size - 1] = Some((gate, gate_index, next_copy + 1)); + } + + (gate, gate_index, next_copy) + } + + pub fn find_u32_arithmetic_gate(&mut self) -> (usize, usize) { + let (gate_index, copy) = match self.batched_gates.current_u32_arithmetic_gate { + None => { + let gate = U32ArithmeticGate::new(); + let gate_index = self.add_gate(gate, vec![]); + (gate_index, 0) + } + Some((gate_index, copy)) => (gate_index, copy), + }; + + if copy == NUM_U32_ARITHMETIC_OPS - 1 { + self.current_u32_arithmetic_gate = None; + } else { + self.current_u32_arithmetic_gate = Some((gate_index, copy + 1)); + } + + (gate_index, copy) + } + + pub fn find_u32_subtraction_gate(&mut self) -> (usize, usize) { + let (gate_index, copy) = match self.batched_gates.current_u32_subtraction_gate { + None => { + let gate = U32SubtractionGate::new(); + let gate_index = self.add_gate(gate, vec![]); + (gate_index, 0) + } + Some((gate_index, copy)) => (gate_index, copy), + }; + + if copy == NUM_U32_SUBTRACTION_OPS - 1 { + self.current_u32_subtraction_gate = None; + } else { + self.current_u32_subtraction_gate = Some((gate_index, copy + 1)); + } + + (gate_index, copy) + } + + /// Returns the gate index and copy index of a free `ConstantGate` slot, potentially adding a + /// new `ConstantGate` if needed. + fn constant_gate_instance(&mut self) -> (usize, usize) { + if self.free_constant.is_none() { + let num_consts = self.config.constant_gate_size; + // We will fill this `ConstantGate` with zero constants initially. + // These will be overwritten by `constant` as the gate instances are filled. + let gate = self.add_gate(ConstantGate { num_consts }, vec![F::ZERO; num_consts]); + self.free_constant = Some((gate, 0)); + } + + let (gate, instance) = self.free_constant.unwrap(); + if instance + 1 < self.config.constant_gate_size { + self.free_constant = Some((gate, instance + 1)); + } else { + self.free_constant = None; + } + (gate, instance) + } +} From cc48abff94502500a36505be78c044bf783cf636 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 14 Oct 2021 13:41:51 -0700 Subject: [PATCH 030/202] sub --- src/gadgets/arithmetic_u32.rs | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index 4f60dde7..91c658b1 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -3,6 +3,7 @@ use std::marker::PhantomData; use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; +use crate::gates::subtraction_u32::U32SubtractionGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; @@ -91,7 +92,7 @@ impl, const D: usize> CircuitBuilder { self.mul_add_u32(a, b, zero) } - // Returns x * y + z. + // Returns x - y - borrow, as a pair (result, borrow), where borrow is 0 or 1 depending on whether borrowing from the next digit is required (iff y + borrow > x). pub fn sub_u32( &mut self, x: U32Target, @@ -103,31 +104,34 @@ impl, const D: usize> CircuitBuilder { self.connect( Target::wire( gate_index, - U32ArithmeticGate::::wire_ith_multiplicand_0(copy), + U32SubtractionGate::::wire_ith_input_x(copy), ), x.0, ); self.connect( Target::wire( gate_index, - U32ArithmeticGate::::wire_ith_multiplicand_1(copy), + U32SubtractionGate::::wire_ith_input_y(copy), ), y.0, ); self.connect( - Target::wire(gate_index, U32ArithmeticGate::::wire_ith_addend(copy)), - z.0, + Target::wire( + gate_index, + U32SubtractionGate::::wire_ith_input_borrow(copy), + ), + borrow.0, ); - let output_low = U32Target(Target::wire( + let output_result = U32Target(Target::wire( gate_index, - U32ArithmeticGate::::wire_ith_output_low_half(copy), + U32SubtractionGate::::wire_ith_output_result(copy), )); - let output_high = U32Target(Target::wire( + let output_borrow = U32Target(Target::wire( gate_index, - U32ArithmeticGate::::wire_ith_output_high_half(copy), + U32SubtractionGate::::wire_ith_output_borrow(copy), )); - (output_low, output_high) + (output_result, output_borrow) } } From 97f66b58f5660b76e848fd08836b1497fae85129 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 09:56:42 -0800 Subject: [PATCH 031/202] merge --- src/plonk/circuit_builder.rs | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index f915493a..898b20a5 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -15,14 +15,14 @@ use crate::fri::{FriConfig, FriParams}; use crate::gadgets::arithmetic_extension::ArithmeticOperation; use crate::gadgets::arithmetic_u32::U32Target; use crate::gates::arithmetic::ArithmeticExtensionGate; -use crate::gates::arithmetic_u32::{NUM_U32_ARITHMETIC_OPS, U32ArithmeticGate}; -use crate::gates::subtraction_u32::{NUM_U32_SUBTRACTION_OPS, U32SubtractionGate}; +use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; use crate::gates::constant::ConstantGate; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; use crate::gates::gate_tree::Tree; use crate::gates::noop::NoopGate; use crate::gates::public_input::PublicInputGate; use crate::gates::random_access::RandomAccessGate; +use crate::gates::subtraction_u32::{U32SubtractionGate, NUM_U32_SUBTRACTION_OPS}; use crate::gates::switch::SwitchGate; use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget}; use crate::hash::hashing::hash_n_to_hash; @@ -808,7 +808,7 @@ impl, const D: usize> CircuitBuilder { } } -/// +/// Various gate types can contain multiple copies in a single Gate. This helper struct lets a CircuitBuilder track such gates that are currently being "filled up." pub struct BatchedGates, const D: usize> { /// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using /// these constants with gate index `g` and already using `i` arithmetic operations. @@ -865,7 +865,9 @@ impl, const D: usize> CircuitBuilder { .free_arithmetic .insert((const_0, const_1), (gate, i + 1)); } else { - self.batched_gates.free_arithmetic.remove(&(const_0, const_1)); + self.batched_gates + .free_arithmetic + .remove(&(const_0, const_1)); } (gate, i) @@ -873,12 +875,14 @@ impl, const D: usize> CircuitBuilder { pub fn find_switch_gate(&mut self, chunk_size: usize) -> (SwitchGate, usize, usize) { if self.batched_gates.current_switch_gates.len() < chunk_size { - self.batched_gates.current_switch_gates - .extend(vec![None; chunk_size - self.current_switch_gates.len()]); + self.batched_gates.current_switch_gates.extend(vec![ + None; + chunk_size - self.batched_gates.current_switch_gates.len() + ]); } let (gate, gate_index, mut next_copy) = - match self.current_switch_gates[chunk_size - 1].clone() { + match self.batched_gates.current_switch_gates[chunk_size - 1].clone() { None => { let gate = SwitchGate::::new_from_config(self.config.clone(), chunk_size); let gate_index = self.add_gate(gate.clone(), vec![]); @@ -886,13 +890,14 @@ impl, const D: usize> CircuitBuilder { } Some((gate, idx, next_copy)) => (gate, idx, next_copy), }; - + let num_copies = gate.num_copies; - + if next_copy == num_copies { self.batched_gates.current_switch_gates[chunk_size - 1] = None; } else { - self.batched_gates.current_switch_gates[chunk_size - 1] = Some((gate, gate_index, next_copy + 1)); + self.batched_gates.current_switch_gates[chunk_size - 1] = + Some((gate, gate_index, next_copy + 1)); } (gate, gate_index, next_copy) @@ -909,9 +914,9 @@ impl, const D: usize> CircuitBuilder { }; if copy == NUM_U32_ARITHMETIC_OPS - 1 { - self.current_u32_arithmetic_gate = None; + self.batched_gates.current_u32_arithmetic_gate = None; } else { - self.current_u32_arithmetic_gate = Some((gate_index, copy + 1)); + self.batched_gates.current_u32_arithmetic_gate = Some((gate_index, copy + 1)); } (gate_index, copy) @@ -928,9 +933,9 @@ impl, const D: usize> CircuitBuilder { }; if copy == NUM_U32_SUBTRACTION_OPS - 1 { - self.current_u32_subtraction_gate = None; + self.batched_gates.current_u32_subtraction_gate = None; } else { - self.current_u32_subtraction_gate = Some((gate_index, copy + 1)); + self.batched_gates.current_u32_subtraction_gate = Some((gate_index, copy + 1)); } (gate_index, copy) From b2b7cb39317dc60c589787b1ec9e32a99d610dab Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 09:57:32 -0800 Subject: [PATCH 032/202] merge --- src/plonk/circuit_builder.rs | 194 +++++++++++++++++++++-------------- 1 file changed, 118 insertions(+), 76 deletions(-) diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 898b20a5..f21c68cb 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -537,76 +537,6 @@ impl, const D: usize> CircuitBuilder { ) } - /// Fill the remaining unused arithmetic operations with zeros, so that all - /// `ArithmeticExtensionGenerator` are run. - fn fill_arithmetic_gates(&mut self) { - let zero = self.zero_extension(); - let remaining_arithmetic_gates = self.free_arithmetic.values().copied().collect::>(); - for (gate, i) in remaining_arithmetic_gates { - for j in i..ArithmeticExtensionGate::::num_ops(&self.config) { - let wires_multiplicand_0 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_ith_multiplicand_0(j), - ); - let wires_multiplicand_1 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_ith_multiplicand_1(j), - ); - let wires_addend = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_ith_addend(j), - ); - - self.connect_extension(zero, wires_multiplicand_0); - self.connect_extension(zero, wires_multiplicand_1); - self.connect_extension(zero, wires_addend); - } - } - } - - /// Fill the remaining unused random access operations with zeros, so that all - /// `RandomAccessGenerator`s are run. - fn fill_random_access_gates(&mut self) { - let zero = self.zero(); - for (vec_size, (_, i)) in self.free_random_access.clone() { - let max_copies = RandomAccessGate::::max_num_copies( - self.config.num_routed_wires, - self.config.num_wires, - vec_size, - ); - for _ in i..max_copies { - self.random_access(zero, zero, vec![zero; vec_size]); - } - } - } - - /// Fill the remaining unused switch gates with dummy values, so that all - /// `SwitchGenerator` are run. - fn fill_switch_gates(&mut self) { - let zero = self.zero(); - - for chunk_size in 1..=self.current_switch_gates.len() { - if let Some((gate, gate_index, mut copy)) = - self.current_switch_gates[chunk_size - 1].clone() - { - while copy < gate.num_copies { - for element in 0..chunk_size { - let wire_first_input = - Target::wire(gate_index, gate.wire_first_input(copy, element)); - let wire_second_input = - Target::wire(gate_index, gate.wire_second_input(copy, element)); - let wire_switch_bool = - Target::wire(gate_index, gate.wire_switch_bool(copy)); - self.connect(zero, wire_first_input); - self.connect(zero, wire_second_input); - self.connect(zero, wire_switch_bool); - } - copy += 1; - } - } - } - } - pub fn print_gate_counts(&self, min_delta: usize) { // Print gate counts for each context. self.context_log @@ -845,7 +775,7 @@ impl, const D: usize> CircuitBuilder { /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index /// `g` and the gate's `i`-th operation is available. - pub fn find_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) { + pub(crate) fn find_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) { let (gate, i) = self .batched_gates .free_arithmetic @@ -873,11 +803,18 @@ impl, const D: usize> CircuitBuilder { (gate, i) } - pub fn find_switch_gate(&mut self, chunk_size: usize) -> (SwitchGate, usize, usize) { + pub(crate) fn find_switch_gate( + &mut self, + chunk_size: usize, + ) -> (SwitchGate, usize, usize) { if self.batched_gates.current_switch_gates.len() < chunk_size { self.batched_gates.current_switch_gates.extend(vec![ None; - chunk_size - self.batched_gates.current_switch_gates.len() + chunk_size + - self + .batched_gates + .current_switch_gates + .len() ]); } @@ -897,13 +834,13 @@ impl, const D: usize> CircuitBuilder { self.batched_gates.current_switch_gates[chunk_size - 1] = None; } else { self.batched_gates.current_switch_gates[chunk_size - 1] = - Some((gate, gate_index, next_copy + 1)); + Some((gate.clone(), gate_index, next_copy + 1)); } (gate, gate_index, next_copy) } - pub fn find_u32_arithmetic_gate(&mut self) -> (usize, usize) { + pub(crate) fn find_u32_arithmetic_gate(&mut self) -> (usize, usize) { let (gate_index, copy) = match self.batched_gates.current_u32_arithmetic_gate { None => { let gate = U32ArithmeticGate::new(); @@ -922,7 +859,7 @@ impl, const D: usize> CircuitBuilder { (gate_index, copy) } - pub fn find_u32_subtraction_gate(&mut self) -> (usize, usize) { + pub(crate) fn find_u32_subtraction_gate(&mut self) -> (usize, usize) { let (gate_index, copy) = match self.batched_gates.current_u32_subtraction_gate { None => { let gate = U32SubtractionGate::new(); @@ -960,4 +897,109 @@ impl, const D: usize> CircuitBuilder { } (gate, instance) } + + /// Fill the remaining unused arithmetic operations with zeros, so that all + /// `ArithmeticExtensionGenerator`s are run. + fn fill_arithmetic_gates(&mut self) { + let zero = self.zero_extension(); + let remaining_arithmetic_gates = self + .batched_gates + .free_arithmetic + .values() + .copied() + .collect::>(); + for (gate, i) in remaining_arithmetic_gates { + for j in i..ArithmeticExtensionGate::::num_ops(&self.config) { + let wires_multiplicand_0 = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_ith_multiplicand_0(j), + ); + let wires_multiplicand_1 = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_ith_multiplicand_1(j), + ); + let wires_addend = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_ith_addend(j), + ); + + self.connect_extension(zero, wires_multiplicand_0); + self.connect_extension(zero, wires_multiplicand_1); + self.connect_extension(zero, wires_addend); + } + } + } + + /// Fill the remaining unused switch gates with dummy values, so that all + /// `SwitchGenerator`s are run. + fn fill_switch_gates(&mut self) { + let zero = self.zero(); + + for chunk_size in 1..=self.batched_gates.current_switch_gates.len() { + if let Some((gate, gate_index, mut copy)) = + self.batched_gates.current_switch_gates[chunk_size - 1].clone() + { + while copy < gate.num_copies { + for element in 0..chunk_size { + let wire_first_input = + Target::wire(gate_index, gate.wire_first_input(copy, element)); + let wire_second_input = + Target::wire(gate_index, gate.wire_second_input(copy, element)); + let wire_switch_bool = + Target::wire(gate_index, gate.wire_switch_bool(copy)); + self.connect(zero, wire_first_input); + self.connect(zero, wire_second_input); + self.connect(zero, wire_switch_bool); + } + copy += 1; + } + } + } + } + + /// Fill the remaining unused U32 arithmetic operations with zeros, so that all + /// `U32ArithmeticGenerator`s are run. + fn fill_u32_arithmetic_gates(&mut self) { + let zero = self.zero(); + if let Some((gate_index, copy)) = self.batched_gates.current_u32_arithmetic_gate { + for i in copy..NUM_U32_ARITHMETIC_OPS { + let wire_multiplicand_0 = Target::wire( + gate_index, + U32ArithmeticGate::::wire_ith_multiplicand_0(i), + ); + let wire_multiplicand_1 = Target::wire( + gate_index, + U32ArithmeticGate::::wire_ith_multiplicand_1(i), + ); + let wire_addend = + Target::wire(gate_index, U32ArithmeticGate::::wire_ith_addend(i)); + + self.connect(zero, wire_multiplicand_0); + self.connect(zero, wire_multiplicand_1); + self.connect(zero, wire_addend); + } + } + } + + /// Fill the remaining unused U32 subtraction operations with zeros, so that all + /// `U32SubtractionGenerator`s are run. + fn fill_u32_subtraction_gates(&mut self) { + let zero = self.zero(); + if let Some((gate_index, copy)) = self.batched_gates.current_u32_subtraction_gate { + for i in copy..NUM_U32_ARITHMETIC_OPS { + let wire_input_x = + Target::wire(gate_index, U32SubtractionGate::::wire_ith_input_x(i)); + let wire_input_y = + Target::wire(gate_index, U32SubtractionGate::::wire_ith_input_y(i)); + let wire_input_borrow = Target::wire( + gate_index, + U32SubtractionGate::::wire_ith_input_borrow(i), + ); + + self.connect(zero, wire_input_x); + self.connect(zero, wire_input_y); + self.connect(zero, wire_input_borrow); + } + } + } } From a4eac25f3d98a18779df1bede85fb85a33677555 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 14 Oct 2021 15:00:29 -0700 Subject: [PATCH 033/202] nonnative add reduction, and nonnative subtraction --- src/gadgets/nonnative.rs | 57 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 54 insertions(+), 3 deletions(-) diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 9762bd34..2217af74 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -30,7 +30,7 @@ impl, const D: usize> CircuitBuilder { .collect() } - // Add two `ForeignFieldTarget`s, which we assume are both normalized. + // Add two `ForeignFieldTarget`s. pub fn add_nonnative( &mut self, a: ForeignFieldTarget, @@ -42,8 +42,9 @@ impl, const D: usize> CircuitBuilder { let mut combined_limbs = self.add_virtual_u32_targets(num_limbs + 1); let mut carry = self.zero_u32(); for i in 0..num_limbs { - let (new_limb, carry) = + let (new_limb, new_carry) = self.add_three_u32(carry.clone(), a.limbs[i].clone(), b.limbs[i].clone()); + carry = new_carry; combined_limbs[i] = new_limb; } combined_limbs[num_limbs] = carry; @@ -55,10 +56,60 @@ impl, const D: usize> CircuitBuilder { } } + /// Reduces the result of a non-native addition. pub fn reduce_add_result(&mut self, limbs: Vec) -> Vec { - todo!() + let num_limbs = limbs.len(); + + let mut modulus_limbs = self.order_u32_limbs::(); + modulus_limbs.append(self.zero_u32()); + + let needs_reduce = self.list_le(modulus, limbs); + + let mut to_subtract = vec![]; + for i in 0..num_limbs { + let (low, _high) = self.mul_u32(modulus_limbs[i], needs_reduce); + to_subtract.append(low); + } + + let mut reduced_limbs = vec![]; + + let mut borrow = self.zero_u32(); + for i in 0..num_limbs { + let (result, new_borrow) = self.sub_u32(limbs[i], to_subtract[i], borrow); + reduced_limbs[i] = result; + borrow = new_borrow; + } + // Borrow should be zero here. + + reduced_limbs } + // Subtract two `ForeignFieldTarget`s. We assume that the first is larger than the second. + pub fn sub_nonnative( + &mut self, + a: ForeignFieldTarget, + b: ForeignFieldTarget, + ) -> ForeignFieldTarget { + let num_limbs = a.limbs.len(); + debug_assert!(b.limbs.len() == num_limbs); + + let mut result_limbs = vec![]; + + let mut borrow = self.zero_u32(); + for i in 0..num_limbs { + let (result, new_borrow) = self.sub_u32(a.limbs[i], b.limbs[i], borrow); + reduced_limbs[i] = result; + borrow = new_borrow; + } + // Borrow should be zero here. + + ForeignFieldTarget { + limbs: result_limbs, + _phantom: PhantomData, + } + } + + pub fn mul_nonnative( &mut self, a: ForeignFieldTarget, From 956b34c2e960ff45c719e0e09589e053131c8236 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 15 Oct 2021 12:12:09 -0700 Subject: [PATCH 034/202] add_many_u32 --- src/gadgets/arithmetic_u32.rs | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index 91c658b1..e53c1761 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -7,7 +7,7 @@ use crate::gates::subtraction_u32::U32SubtractionGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; -#[derive(Clone)] +#[derive(Clone, Copy)] pub struct U32Target(pub Target); impl, const D: usize> CircuitBuilder { @@ -87,6 +87,25 @@ impl, const D: usize> CircuitBuilder { (final_low, combined_carry) } + pub fn add_many_u32(&mut self, to_add: Vec) -> (U32Target, U32Target) { + match to_add.len() { + 0 => (self.zero_u32(), self.zero_u32()), + 1 => (to_add[0], self.zero_u32()), + 2 => self.add_u32(to_add[0], to_add[1]), + 3 => self.add_three_u32(to_add[0], to_add[1], to_add[2]), + _ => { + let (mut low, mut carry) = self.add_u32(to_add[0], to_add[1]); + for i in 2..to_add.len() { + let (new_low, new_carry) = self.add_u32(to_add[i], low); + let (combined_carry, _zero) = self.add_u32(carry, new_carry); + low = new_low; + carry = combined_carry; + } + (low, carry) + } + } + } + pub fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { let zero = self.zero_u32(); self.mul_add_u32(a, b, zero) From 72aea53d138409375818551218ed8fb002beff51 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 15 Oct 2021 12:54:01 -0700 Subject: [PATCH 035/202] mul --- src/gadgets/nonnative.rs | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 2217af74..7faa6d4e 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -118,21 +118,32 @@ impl, const D: usize> CircuitBuilder { let num_limbs = a.limbs.len(); debug_assert!(b.limbs.len() == num_limbs); - /*let mut combined_limbs = self.add_virtual_u32_targets(2 * num_limbs - 1); + let mut combined_limbs = self.add_virtual_u32_targets(2 * num_limbs - 1); + let mut to_add = vec![vec![]; 2 * num_limbs]; for i in 0..num_limbs { for j in 0..num_limbs { - let sum = self.add_u32(a.limbs[i], b.limbs[j]); - combined_limbs[i + j] = self.add_u32(combined_limbs[i + j], sum); + let (product, carry) = self.mul_u32(a.limbs[i], b.limbs[j]); + to_add[i + j].push(product); + to_add[i + j + 1].push(carry); } } + let mut combined_limbs = vec![]; + let mut carry = self.zero_u32(); + for i in 0..2 * num_limbs { + to_add[i].push(carry); + let (new_result, new_carry) = self.add_many_u32(to_add[i]); + combined_limbs.push(new_result); + carry = new_carry; + } + combined_limbs.push(carry); + let reduced_limbs = self.reduce_mul_result::(combined_limbs); ForeignFieldTarget { limbs: reduced_limbs, _phantom: PhantomData, - }*/ - todo!() + } } pub fn reduce_mul_result(&mut self, limbs: Vec) -> Vec { From 9077c7fa3c517ffd0086e982585732b22aef3f09 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 15 Oct 2021 16:47:29 -0700 Subject: [PATCH 036/202] BigUint arithmetic, and cleanup --- src/gadgets/arithmetic_u32.rs | 4 +- src/gadgets/biguint.rs | 82 ++++++++++++++++++++++++++++++ src/gadgets/mod.rs | 3 +- src/gadgets/multiple_comparison.rs | 12 ++--- src/gadgets/permutation.rs | 3 +- src/gates/subtraction_u32.rs | 64 +++++++++++------------ 6 files changed, 119 insertions(+), 49 deletions(-) create mode 100644 src/gadgets/biguint.rs diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index e53c1761..ce35d4f0 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -1,8 +1,6 @@ -use std::marker::PhantomData; - use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; -use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; +use crate::gates::arithmetic_u32::U32ArithmeticGate; use crate::gates::subtraction_u32::U32SubtractionGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs new file mode 100644 index 00000000..4adc2382 --- /dev/null +++ b/src/gadgets/biguint.rs @@ -0,0 +1,82 @@ +use std::marker::PhantomData; + +use num::BigUint; + +use crate::field::field_types::RichField; +use crate::field::{extension_field::Extendable, field_types::Field}; +use crate::gadgets::arithmetic_u32::U32Target; +use crate::plonk::circuit_builder::CircuitBuilder; + +pub struct BigUintTarget { + limbs: Vec, +} + +impl, const D: usize> CircuitBuilder { + // Add two `BigUintTarget`s. + pub fn add_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { + let num_limbs = a.limbs.len(); + debug_assert!(b.limbs.len() == num_limbs); + + let mut combined_limbs = vec![]; + let mut carry = self.zero_u32(); + for i in 0..num_limbs { + let (new_limb, new_carry) = + self.add_three_u32(carry.clone(), a.limbs[i].clone(), b.limbs[i].clone()); + carry = new_carry; + combined_limbs.push(new_limb); + } + combined_limbs[num_limbs] = carry; + + BigUintTarget { + limbs: combined_limbs, + } + } + + // Subtract two `BigUintTarget`s. We assume that the first is larger than the second. + pub fn sub_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { + let num_limbs = a.limbs.len(); + debug_assert!(b.limbs.len() == num_limbs); + + let mut result_limbs = vec![]; + + let mut borrow = self.zero_u32(); + for i in 0..num_limbs { + let (result, new_borrow) = self.sub_u32(a.limbs[i], b.limbs[i], borrow); + result_limbs[i] = result; + borrow = new_borrow; + } + // Borrow should be zero here. + + BigUintTarget { + limbs: result_limbs, + } + } + + pub fn mul_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { + let num_limbs = a.limbs.len(); + debug_assert!(b.limbs.len() == num_limbs); + + let mut to_add = vec![vec![]; 2 * num_limbs]; + for i in 0..num_limbs { + for j in 0..num_limbs { + let (product, carry) = self.mul_u32(a.limbs[i], b.limbs[j]); + to_add[i + j].push(product); + to_add[i + j + 1].push(carry); + } + } + + let mut combined_limbs = vec![]; + let mut carry = self.zero_u32(); + for i in 0..2 * num_limbs { + to_add[i].push(carry); + let (new_result, new_carry) = self.add_many_u32(to_add[i].clone()); + combined_limbs.push(new_result); + carry = new_carry; + } + combined_limbs.push(carry); + + BigUintTarget { + limbs: combined_limbs, + } + } +} diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index 9fb572c9..cf6f6ed4 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -1,11 +1,12 @@ pub mod arithmetic; pub mod arithmetic_extension; pub mod arithmetic_u32; +pub mod biguint; pub mod hash; pub mod insert; pub mod interpolation; pub mod multiple_comparison; -pub mod nonnative; +//pub mod nonnative; pub mod permutation; pub mod polynomial; pub mod random_access; diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs index 5291323d..11225ca5 100644 --- a/src/gadgets/multiple_comparison.rs +++ b/src/gadgets/multiple_comparison.rs @@ -1,19 +1,13 @@ -use std::marker::PhantomData; - -use itertools::izip; - use crate::field::extension_field::Extendable; -use crate::field::field_types::{Field, RichField}; +use crate::field::field_types::RichField; use crate::gates::comparison::ComparisonGate; -use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; -use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::ceil_div_usize; impl, const D: usize> CircuitBuilder { /// Returns true if a is less than or equal to b, considered as limbs of a large value. - pub fn compare_lists(&mut self, a: Vec, b: Vec, num_bits: usize) -> BoolTarget { + pub fn list_le(&mut self, a: Vec, b: Vec, num_bits: usize) -> BoolTarget { assert_eq!( a.len(), b.len(), @@ -22,7 +16,7 @@ impl, const D: usize> CircuitBuilder { let n = a.len(); let chunk_size = 4; - let num_chunks = ceil_div_usize(num_bits, 4); + let num_chunks = ceil_div_usize(num_bits, chunk_size); let one = self.one(); let mut result = self.one(); diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs index a0c9b087..ae0e411b 100644 --- a/src/gadgets/permutation.rs +++ b/src/gadgets/permutation.rs @@ -73,8 +73,7 @@ impl, const D: usize> CircuitBuilder { let chunk_size = a1.len(); - let (gate, gate_index, mut next_copy) = - self.find_switch_gate(chunk_size); + let (gate, gate_index, mut next_copy) = self.find_switch_gate(chunk_size); let num_copies = gate.num_copies; diff --git a/src/gates/subtraction_u32.rs b/src/gates/subtraction_u32.rs index de79b67f..fc2009be 100644 --- a/src/gates/subtraction_u32.rs +++ b/src/gates/subtraction_u32.rs @@ -13,7 +13,7 @@ use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -/// Number of arithmetic operations performed by an arithmetic gate. +/// Maximum number of subtractions operations performed by a single gate. pub const NUM_U32_SUBTRACTION_OPS: usize = 3; /// A gate to perform a subtraction . @@ -28,7 +28,7 @@ impl, const D: usize> U32SubtractionGate { _phantom: PhantomData, } } - + pub fn wire_ith_input_x(i: usize) -> usize { debug_assert!(i < NUM_U32_SUBTRACTION_OPS); 5 * i @@ -168,7 +168,8 @@ impl, const D: usize> Gate for U32Subtraction // Range-check output_result to be at most 32 bits. let mut combined_limbs = builder.zero_extension(); - let limb_base = builder.constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); + let limb_base = builder + .constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); for j in (0..Self::num_limbs()).rev() { let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); @@ -245,15 +246,15 @@ impl, const D: usize> SimpleGenerator let local_target = |input| Target::wire(self.gate_index, input); let mut deps = Vec::with_capacity(3); - deps.push(local_target( - U32SubtractionGate::::wire_ith_input_x(self.i), - )); - deps.push(local_target( - U32SubtractionGate::::wire_ith_input_y(self.i), - )); - deps.push(local_target(U32SubtractionGate::::wire_ith_input_borrow( + deps.push(local_target(U32SubtractionGate::::wire_ith_input_x( self.i, ))); + deps.push(local_target(U32SubtractionGate::::wire_ith_input_y( + self.i, + ))); + deps.push(local_target( + U32SubtractionGate::::wire_ith_input_borrow(self.i), + )); deps } @@ -265,11 +266,10 @@ impl, const D: usize> SimpleGenerator let get_local_wire = |input| witness.get_wire(local_wire(input)); - let input_x = - get_local_wire(U32SubtractionGate::::wire_ith_input_x(self.i)); - let input_y = - get_local_wire(U32SubtractionGate::::wire_ith_input_y(self.i)); - let input_borrow = get_local_wire(U32SubtractionGate::::wire_ith_input_borrow(self.i)); + let input_x = get_local_wire(U32SubtractionGate::::wire_ith_input_x(self.i)); + let input_y = get_local_wire(U32SubtractionGate::::wire_ith_input_y(self.i)); + let input_borrow = + get_local_wire(U32SubtractionGate::::wire_ith_input_borrow(self.i)); let result_initial = input_x - input_y - input_borrow; let result_initial_u64 = result_initial.to_canonical_u64(); @@ -281,7 +281,7 @@ impl, const D: usize> SimpleGenerator let base = F::from_canonical_u64(1 << 32u64); let output_result = result_initial + base * output_borrow; - + let output_result_wire = local_wire(U32SubtractionGate::::wire_ith_output_result(self.i)); let output_borrow_wire = @@ -295,12 +295,12 @@ impl, const D: usize> SimpleGenerator let num_limbs = U32SubtractionGate::::num_limbs(); let limb_base = 1 << U32SubtractionGate::::limb_bits(); let output_limbs: Vec<_> = (0..num_limbs) - .scan(output_result_u64, |acc, _| { - let tmp = *acc % limb_base; - *acc /= limb_base; - Some(F::from_canonical_u64(tmp)) - }) - .collect(); + .scan(output_result_u64, |acc, _| { + let tmp = *acc % limb_base; + *acc /= limb_base; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); for j in 0..num_limbs { let wire = local_wire(U32SubtractionGate::::wire_ith_output_jth_limb( @@ -321,9 +321,9 @@ mod tests { use crate::field::crandall_field::CrandallField; use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::{Field, PrimeField}; - use crate::gates::subtraction_u32::{U32SubtractionGate, NUM_U32_SUBTRACTION_OPS}; use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::subtraction_u32::{U32SubtractionGate, NUM_U32_SUBTRACTION_OPS}; use crate::hash::hash_types::HashOut; use crate::plonk::vars::EvaluationVars; @@ -347,11 +347,7 @@ mod tests { type FF = QuarticExtension; const D: usize = 4; - fn get_wires( - inputs_x: Vec, - inputs_y: Vec, - borrows: Vec, - ) -> Vec { + fn get_wires(inputs_x: Vec, inputs_y: Vec, borrows: Vec) -> Vec { let mut v0 = Vec::new(); let mut v1 = Vec::new(); @@ -377,12 +373,12 @@ mod tests { let output_result_u64 = output_result.to_canonical_u64(); let mut output_limbs: Vec<_> = (0..num_limbs) - .scan(output_result_u64, |acc, _| { - let tmp = *acc % limb_base; - *acc /= limb_base; - Some(F::from_canonical_u64(tmp)) - }) - .collect(); + .scan(output_result_u64, |acc, _| { + let tmp = *acc % limb_base; + *acc /= limb_base; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); v0.push(input_x); v0.push(input_y); From b567cf9bafdc4f3983af88b69f753dedc650ed17 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 18 Oct 2021 15:04:54 -0700 Subject: [PATCH 037/202] some more BigUint arithmetic --- src/gadgets/arithmetic_u32.rs | 10 ++- src/gadgets/biguint.rs | 111 ++++++++++++++++++++++++++++- src/gadgets/multiple_comparison.rs | 2 +- src/gadgets/nonnative.rs | 24 ++++--- src/iop/witness.rs | 18 ++++- src/plonk/circuit_builder.rs | 5 ++ 6 files changed, 155 insertions(+), 15 deletions(-) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index ce35d4f0..0d88d52c 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -5,7 +5,7 @@ use crate::gates::subtraction_u32::U32SubtractionGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Debug)] pub struct U32Target(pub Target); impl, const D: usize> CircuitBuilder { @@ -28,6 +28,14 @@ impl, const D: usize> CircuitBuilder { U32Target(self.one()) } + pub fn connect_u32(&mut self, x: U32Target, y: U32Target) { + self.connect(x.0, y.0) + } + + pub fn assert_zero_u32(&self, x: U32Target) { + self.assert_zero(x.0) + } + // Returns x * y + z. pub fn mul_add_u32( &mut self, diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 4adc2382..b30b4680 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -1,17 +1,81 @@ use std::marker::PhantomData; +use std::ops::Neg; use num::BigUint; use crate::field::field_types::RichField; use crate::field::{extension_field::Extendable, field_types::Field}; use crate::gadgets::arithmetic_u32::U32Target; +use crate::iop::generator::{GeneratedValues, SimpleGenerator}; +use crate::iop::target::{BoolTarget, Target}; +use crate::iop::witness::PartitionWitness; use crate::plonk::circuit_builder::CircuitBuilder; +#[derive(Clone, Debug)] pub struct BigUintTarget { - limbs: Vec, + pub limbs: Vec, +} + +impl BigUintTarget { + pub fn num_limbs(&self) -> usize { + self.limbs.len() + } + + pub fn get_limb(&self, i: usize) -> U32Target { + self.limbs[i] + } } impl, const D: usize> CircuitBuilder { + fn connect_biguint(&self, lhs: BigUintTarget, rhs: BigUintTarget) { + let min_limbs = lhs.num_limbs().min(rhs.num_limbs()); + for i in 0..min_limbs { + self.connect_u32(lhs.get_limb(i), rhs.get_limb(i)); + } + + for i in min_limbs..lhs.num_limbs() { + self.assert_zero_u32(lhs.get_limb(i)); + } + for i in min_limbs..rhs.num_limbs() { + self.assert_zero_u32(rhs.get_limb(i)); + } + } + + fn pad_biguints(&mut self, a: BigUintTarget, b: BigUintTarget) -> (BigUintTarget, BigUintTarget) { + if a.num_limbs() > b.num_limbs() { + let mut padded_b_limbs = b.limbs.clone(); + padded_b_limbs.extend(self.add_virtual_u32_targets(a.num_limbs() - b.num_limbs())); + let padded_b = BigUintTarget { + limbs: padded_b_limbs, + }; + (a, padded_b) + } else { + let mut padded_a_limbs = a.limbs.clone(); + padded_a_limbs.extend(self.add_virtual_u32_targets(b.num_limbs() - a.num_limbs())); + let padded_a = BigUintTarget { + limbs: padded_a_limbs, + }; + (padded_a, b) + } + } + + fn cmp_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BoolTarget { + let (padded_a, padded_b) = self.pad_biguints(a, b); + + let a_vec = a.limbs.iter().map(|&x| x.0).collect(); + let b_vec = b.limbs.iter().map(|&x| x.0).collect(); + + self.list_le(a_vec, b_vec, 32) + } + + fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget { + let limbs = (0..num_limbs).map(|_| self.add_virtual_u32_target()).collect(); + + BigUintTarget { + limbs, + } + } + // Add two `BigUintTarget`s. pub fn add_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { let num_limbs = a.limbs.len(); @@ -79,4 +143,49 @@ impl, const D: usize> CircuitBuilder { limbs: combined_limbs, } } + + pub fn div_rem_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> (BigUintTarget, BigUintTarget) { + let num_limbs = a.limbs.len(); + let div = self.add_virtual_biguint_target(num_limbs); + let rem = self.add_virtual_biguint_target(num_limbs); + + self.add_simple_generator(BigUintDivRemGenerator:: { + a: a.clone(), + b: b.clone(), + div: div.clone(), + rem: rem.clone(), + _phantom: PhantomData, + }); + + let div_b = self.mul_biguint(div, b); + let div_b_plus_rem = self.add_biguint(div_b, rem); + self.connect_biguint(x, div_b_plus_rem); + + let + + self.assert_one() + + (div, rem) + } } + +#[derive(Debug)] +struct BigUintDivRemGenerator, const D: usize> { + a: BigUintTarget, + b: BigUintTarget, + div: BigUintTarget, + rem: BigUintTarget, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for BigUintDivRemGenerator +{ + fn dependencies(&self) -> Vec { + self.a.limbs.iter().map(|&l| l.0).chain(self.b.limbs.iter().map(|&l| l.0)).collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + + } +} \ No newline at end of file diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs index 11225ca5..3c2ac0f5 100644 --- a/src/gadgets/multiple_comparison.rs +++ b/src/gadgets/multiple_comparison.rs @@ -11,7 +11,7 @@ impl, const D: usize> CircuitBuilder { assert_eq!( a.len(), b.len(), - "Permutation must have same number of inputs and outputs" + "Comparison must be between same number of inputs and outputs" ); let n = a.len(); diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 7faa6d4e..e41f8fb5 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -1,21 +1,12 @@ -use std::collections::BTreeMap; use std::marker::PhantomData; -use num::bigint::BigUint; use crate::field::field_types::RichField; use crate::field::{extension_field::Extendable, field_types::Field}; use crate::gadgets::arithmetic_u32::U32Target; -use crate::gates::arithmetic_u32::U32ArithmeticGate; -use crate::gates::switch::SwitchGate; -use crate::iop::generator::{GeneratedValues, SimpleGenerator}; -use crate::iop::target::Target; -use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::util::bimap::bimap_from_lists; pub struct ForeignFieldTarget { - /// These F elements are assumed to contain 32-bit values. limbs: Vec, _phantom: PhantomData, } @@ -30,6 +21,16 @@ impl, const D: usize> CircuitBuilder { .collect() } + fn power_of_2_mod_order(&mut self, i: usize) -> Vec { + + } + + pub fn powers_of_2_mod_order(&mut self, max: usize) -> Vec> { + for i in 0..max { + + } + } + // Add two `ForeignFieldTarget`s. pub fn add_nonnative( &mut self, @@ -63,7 +64,7 @@ impl, const D: usize> CircuitBuilder { let mut modulus_limbs = self.order_u32_limbs::(); modulus_limbs.append(self.zero_u32()); - let needs_reduce = self.list_le(modulus, limbs); + let needs_reduce = self.list_le(modulus_limbs, limbs); let mut to_subtract = vec![]; for i in 0..num_limbs { @@ -98,7 +99,7 @@ impl, const D: usize> CircuitBuilder { let mut borrow = self.zero_u32(); for i in 0..num_limbs { let (result, new_borrow) = self.sub_u32(a.limbs[i], b.limbs[i], borrow); - reduced_limbs[i] = result; + result_limbs[i] = result; borrow = new_borrow; } // Borrow should be zero here. @@ -147,6 +148,7 @@ impl, const D: usize> CircuitBuilder { } pub fn reduce_mul_result(&mut self, limbs: Vec) -> Vec { + todo!() } } diff --git a/src/iop/witness.rs b/src/iop/witness.rs index 858bacd9..12186fb1 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -1,9 +1,13 @@ use std::collections::HashMap; use std::convert::TryInto; +use num::BigUint; + use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::field::field_types::Field; +use crate::field::field_types::{Field, PrimeField}; +use crate::gadgets::arithmetic_u32::U32Target; +use crate::gadgets::biguint::BigUintTarget; use crate::hash::hash_types::HashOutTarget; use crate::hash::hash_types::{HashOut, MerkleCapTarget}; use crate::hash::merkle_tree::MerkleCap; @@ -53,12 +57,24 @@ pub trait Witness { panic!("not a bool") } + fn get_u32_target(&self, target: U32Target) -> F { + let result = self.get_target(target.0); + debug_assert!(result.to_canonical_u64() < 1 << 32u64); + result + } + fn get_hash_target(&self, ht: HashOutTarget) -> HashOut { HashOut { elements: self.get_targets(&ht.elements).try_into().unwrap(), } } + fn get_biguint_target(&self, target: BigUintTarget) -> BigUint { + let mut result = BigUint::zero(); + for (i, &limb) in target + result + } + fn get_wire(&self, wire: Wire) -> F { self.get_target(Target::Wire(wire)) } diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index f21c68cb..804e7be2 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -246,6 +246,11 @@ impl, const D: usize> CircuitBuilder { self.connect(x, zero); } + pub fn assert_one(&mut self, x: Target) { + let one = self.one(); + self.connect(x, one); + } + pub fn add_generators(&mut self, generators: Vec>>) { self.generators.extend(generators); } From e8c2813cc7771d4c7f7aa94d2a86edbfab001a76 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 18 Oct 2021 15:11:37 -0700 Subject: [PATCH 038/202] fixes and fmt --- src/gadgets/arithmetic_u32.rs | 6 +++++ src/gadgets/biguint.rs | 51 ++++++++++++++++++++++++----------- src/iop/witness.rs | 14 +--------- 3 files changed, 43 insertions(+), 28 deletions(-) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index 0d88d52c..7cb5b7c5 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -36,6 +36,12 @@ impl, const D: usize> CircuitBuilder { self.assert_zero(x.0) } + fn get_u32_target(&self, target: U32Target) -> F { + let result = self.get_target(target.0); + debug_assert!(result.to_canonical_u64() < 1 << 32u64); + result + } + // Returns x * y + z. pub fn mul_add_u32( &mut self, diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index b30b4680..54589f54 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -41,7 +41,22 @@ impl, const D: usize> CircuitBuilder { } } - fn pad_biguints(&mut self, a: BigUintTarget, b: BigUintTarget) -> (BigUintTarget, BigUintTarget) { + fn get_biguint_target(&self, target: BigUintTarget) -> BigUint { + let mut result = BigUint::zero(); + let base = BigUint::from_u64(1 << 32u64); + for &limb in target.limbs.iter().rev() { + let limb_value = self.get_target(limb.0); + result += BigUint::from_u64(limb_value.to_canonical_u64()); + result *= base; + } + result + } + + fn pad_biguints( + &mut self, + a: BigUintTarget, + b: BigUintTarget, + ) -> (BigUintTarget, BigUintTarget) { if a.num_limbs() > b.num_limbs() { let mut padded_b_limbs = b.limbs.clone(); padded_b_limbs.extend(self.add_virtual_u32_targets(a.num_limbs() - b.num_limbs())); @@ -69,11 +84,11 @@ impl, const D: usize> CircuitBuilder { } fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget { - let limbs = (0..num_limbs).map(|_| self.add_virtual_u32_target()).collect(); + let limbs = (0..num_limbs) + .map(|_| self.add_virtual_u32_target()) + .collect(); - BigUintTarget { - limbs, - } + BigUintTarget { limbs } } // Add two `BigUintTarget`s. @@ -144,7 +159,11 @@ impl, const D: usize> CircuitBuilder { } } - pub fn div_rem_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> (BigUintTarget, BigUintTarget) { + pub fn div_rem_biguint( + &mut self, + a: BigUintTarget, + b: BigUintTarget, + ) -> (BigUintTarget, BigUintTarget) { let num_limbs = a.limbs.len(); let div = self.add_virtual_biguint_target(num_limbs); let rem = self.add_virtual_biguint_target(num_limbs); @@ -159,11 +178,10 @@ impl, const D: usize> CircuitBuilder { let div_b = self.mul_biguint(div, b); let div_b_plus_rem = self.add_biguint(div_b, rem); - self.connect_biguint(x, div_b_plus_rem); + self.connect_biguint(a, div_b_plus_rem); - let - - self.assert_one() + let cmp_rem_b = self.cmp_biguint(rem, b); + self.assert_one(cmp_rem_b.target); (div, rem) } @@ -182,10 +200,13 @@ impl, const D: usize> SimpleGenerator for BigUintDivRemGenerator { fn dependencies(&self) -> Vec { - self.a.limbs.iter().map(|&l| l.0).chain(self.b.limbs.iter().map(|&l| l.0)).collect() + self.a + .limbs + .iter() + .map(|&l| l.0) + .chain(self.b.limbs.iter().map(|&l| l.0)) + .collect() } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - - } -} \ No newline at end of file + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) {} +} diff --git a/src/iop/witness.rs b/src/iop/witness.rs index 12186fb1..13a374e2 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::convert::TryInto; -use num::BigUint; +use num::{BigUint, FromPrimitive, Zero}; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; @@ -57,24 +57,12 @@ pub trait Witness { panic!("not a bool") } - fn get_u32_target(&self, target: U32Target) -> F { - let result = self.get_target(target.0); - debug_assert!(result.to_canonical_u64() < 1 << 32u64); - result - } - fn get_hash_target(&self, ht: HashOutTarget) -> HashOut { HashOut { elements: self.get_targets(&ht.elements).try_into().unwrap(), } } - fn get_biguint_target(&self, target: BigUintTarget) -> BigUint { - let mut result = BigUint::zero(); - for (i, &limb) in target - result - } - fn get_wire(&self, wire: Wire) -> F { self.get_target(Target::Wire(wire)) } From 557456ddd91767ca184cd835606c2aa65a186e42 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 18 Oct 2021 16:04:23 -0700 Subject: [PATCH 039/202] fix --- src/gadgets/arithmetic_u32.rs | 8 +------- src/gadgets/biguint.rs | 27 ++++++++------------------- 2 files changed, 9 insertions(+), 26 deletions(-) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index 7cb5b7c5..db6d3669 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -32,16 +32,10 @@ impl, const D: usize> CircuitBuilder { self.connect(x.0, y.0) } - pub fn assert_zero_u32(&self, x: U32Target) { + pub fn assert_zero_u32(&mut self, x: U32Target) { self.assert_zero(x.0) } - fn get_u32_target(&self, target: U32Target) -> F { - let result = self.get_target(target.0); - debug_assert!(result.to_canonical_u64() < 1 << 32u64); - result - } - // Returns x * y + z. pub fn mul_add_u32( &mut self, diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 54589f54..82bfd91e 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use std::ops::Neg; -use num::BigUint; +use num::{BigUint, Zero}; use crate::field::field_types::RichField; use crate::field::{extension_field::Extendable, field_types::Field}; @@ -27,7 +27,7 @@ impl BigUintTarget { } impl, const D: usize> CircuitBuilder { - fn connect_biguint(&self, lhs: BigUintTarget, rhs: BigUintTarget) { + fn connect_biguint(&mut self, lhs: BigUintTarget, rhs: BigUintTarget) { let min_limbs = lhs.num_limbs().min(rhs.num_limbs()); for i in 0..min_limbs { self.connect_u32(lhs.get_limb(i), rhs.get_limb(i)); @@ -41,17 +41,6 @@ impl, const D: usize> CircuitBuilder { } } - fn get_biguint_target(&self, target: BigUintTarget) -> BigUint { - let mut result = BigUint::zero(); - let base = BigUint::from_u64(1 << 32u64); - for &limb in target.limbs.iter().rev() { - let limb_value = self.get_target(limb.0); - result += BigUint::from_u64(limb_value.to_canonical_u64()); - result *= base; - } - result - } - fn pad_biguints( &mut self, a: BigUintTarget, @@ -75,10 +64,10 @@ impl, const D: usize> CircuitBuilder { } fn cmp_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BoolTarget { - let (padded_a, padded_b) = self.pad_biguints(a, b); + let (padded_a, padded_b) = self.pad_biguints(a.clone(), b.clone()); - let a_vec = a.limbs.iter().map(|&x| x.0).collect(); - let b_vec = b.limbs.iter().map(|&x| x.0).collect(); + let a_vec = padded_a.limbs.iter().map(|&x| x.0).collect(); + let b_vec = padded_b.limbs.iter().map(|&x| x.0).collect(); self.list_le(a_vec, b_vec, 32) } @@ -176,11 +165,11 @@ impl, const D: usize> CircuitBuilder { _phantom: PhantomData, }); - let div_b = self.mul_biguint(div, b); - let div_b_plus_rem = self.add_biguint(div_b, rem); + let div_b = self.mul_biguint(div.clone(), b.clone()); + let div_b_plus_rem = self.add_biguint(div_b, rem.clone()); self.connect_biguint(a, div_b_plus_rem); - let cmp_rem_b = self.cmp_biguint(rem, b); + let cmp_rem_b = self.cmp_biguint(rem.clone(), b); self.assert_one(cmp_rem_b.target); (div, rem) From b045afbb8abe93ee9f8838f6143867c26df4bea4 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 19 Oct 2021 13:26:48 -0700 Subject: [PATCH 040/202] biguint methods in fields, and biguint gadget progress --- src/field/extension_field/quadratic.rs | 10 +++++ src/field/extension_field/quartic.rs | 21 ++++++++++ src/field/field_types.rs | 4 ++ src/field/goldilocks_field.rs | 10 ++++- src/field/secp256k1.rs | 56 ++++++++++++-------------- src/gadgets/biguint.rs | 22 +++++++--- src/iop/generator.rs | 15 +++++++ src/iop/witness.rs | 13 ++++++ 8 files changed, 115 insertions(+), 36 deletions(-) diff --git a/src/field/extension_field/quadratic.rs b/src/field/extension_field/quadratic.rs index ebad5025..e2794330 100644 --- a/src/field/extension_field/quadratic.rs +++ b/src/field/extension_field/quadratic.rs @@ -3,6 +3,7 @@ use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use num::bigint::BigUint; +use num::Integer; use rand::Rng; use serde::{Deserialize, Serialize}; @@ -99,6 +100,15 @@ impl> Field for QuadraticExtension { )) } + fn from_biguint(n: BigUint) -> Self { + let (high, low) = n.div_rem(&F::order()); + Self([F::from_biguint(low), F::from_biguint(high)]) + } + + fn to_biguint(&self) -> BigUint { + self.0[0].to_biguint() + F::order() * self.0[1].to_biguint() + } + fn from_canonical_u64(n: u64) -> Self { F::from_canonical_u64(n).into() } diff --git a/src/field/extension_field/quartic.rs b/src/field/extension_field/quartic.rs index 001da821..01918ff3 100644 --- a/src/field/extension_field/quartic.rs +++ b/src/field/extension_field/quartic.rs @@ -4,6 +4,7 @@ use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssi use num::bigint::BigUint; use num::traits::Pow; +use num::Integer; use rand::Rng; use serde::{Deserialize, Serialize}; @@ -104,6 +105,26 @@ impl> Field for QuarticExtension { )) } + fn from_biguint(n: BigUint) -> Self { + let (rest, first) = n.div_rem(&F::order()); + let (rest, second) = rest.div_rem(&F::order()); + let (rest, third) = rest.div_rem(&F::order()); + Self([ + F::from_biguint(first), + F::from_biguint(second), + F::from_biguint(third), + F::from_biguint(rest), + ]) + } + + fn to_biguint(&self) -> BigUint { + let mut result = self.0[3].to_biguint(); + result = result * F::order() + self.0[2].to_biguint(); + result = result * F::order() + self.0[1].to_biguint(); + result = result * F::order() + self.0[0].to_biguint(); + result + } + fn from_canonical_u64(n: u64) -> Self { F::from_canonical_u64(n).into() } diff --git a/src/field/field_types.rs b/src/field/field_types.rs index 4fe10b17..481d87ba 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -206,6 +206,10 @@ pub trait Field: subgroup.into_iter().map(|x| x * shift).collect() } + fn from_biguint(n: BigUint) -> Self; + + fn to_biguint(&self) -> BigUint; + fn from_canonical_u64(n: u64) -> Self; fn from_canonical_u32(n: u32) -> Self { diff --git a/src/field/goldilocks_field.rs b/src/field/goldilocks_field.rs index 45164506..cb85d56d 100644 --- a/src/field/goldilocks_field.rs +++ b/src/field/goldilocks_field.rs @@ -4,7 +4,7 @@ use std::hash::{Hash, Hasher}; use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -use num::BigUint; +use num::{BigUint, Integer}; use rand::Rng; use serde::{Deserialize, Serialize}; @@ -91,6 +91,14 @@ impl Field for GoldilocksField { try_inverse_u64(self) } + fn from_biguint(n: BigUint) -> Self { + Self(n.mod_floor(&Self::order()).to_u64_digits()[0]) + } + + fn to_biguint(&self) -> BigUint { + self.to_canonical_u64().into() + } + #[inline] fn from_canonical_u64(n: u64) -> Self { debug_assert!(n < Self::ORDER); diff --git a/src/field/secp256k1.rs b/src/field/secp256k1.rs index 56d506d6..5f8e1b4e 100644 --- a/src/field/secp256k1.rs +++ b/src/field/secp256k1.rs @@ -36,27 +36,6 @@ fn biguint_from_array(arr: [u64; 4]) -> BigUint { ]) } -impl Secp256K1Base { - fn to_canonical_biguint(&self) -> BigUint { - let mut result = biguint_from_array(self.0); - if result >= Self::order() { - result -= Self::order(); - } - result - } - - fn from_biguint(val: BigUint) -> Self { - Self( - val.to_u64_digits() - .into_iter() - .pad_using(4, |_| 0) - .collect::>()[..] - .try_into() - .expect("error converting to u64 array"), - ) - } -} - impl Default for Secp256K1Base { fn default() -> Self { Self::ZERO @@ -65,7 +44,7 @@ impl Default for Secp256K1Base { impl PartialEq for Secp256K1Base { fn eq(&self, other: &Self) -> bool { - self.to_canonical_biguint() == other.to_canonical_biguint() + self.to_biguint() == other.to_biguint() } } @@ -73,19 +52,19 @@ impl Eq for Secp256K1Base {} impl Hash for Secp256K1Base { fn hash(&self, state: &mut H) { - self.to_canonical_biguint().hash(state) + self.to_biguint().hash(state) } } impl Display for Secp256K1Base { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Display::fmt(&self.to_canonical_biguint(), f) + Display::fmt(&self.to_biguint(), f) } } impl Debug for Secp256K1Base { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Debug::fmt(&self.to_canonical_biguint(), f) + Debug::fmt(&self.to_biguint(), f) } } @@ -129,6 +108,25 @@ impl Field for Secp256K1Base { Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one()))) } + fn to_biguint(&self) -> BigUint { + let mut result = biguint_from_array(self.0); + if result >= Self::order() { + result -= Self::order(); + } + result + } + + fn from_biguint(val: BigUint) -> Self { + Self( + val.to_u64_digits() + .into_iter() + .pad_using(4, |_| 0) + .collect::>()[..] + .try_into() + .expect("error converting to u64 array"), + ) + } + #[inline] fn from_canonical_u64(n: u64) -> Self { Self([n, 0, 0, 0]) @@ -157,7 +155,7 @@ impl Neg for Secp256K1Base { if self.is_zero() { Self::ZERO } else { - Self::from_biguint(Self::order() - self.to_canonical_biguint()) + Self::from_biguint(Self::order() - self.to_biguint()) } } } @@ -167,7 +165,7 @@ impl Add for Secp256K1Base { #[inline] fn add(self, rhs: Self) -> Self { - let mut result = self.to_canonical_biguint() + rhs.to_canonical_biguint(); + let mut result = self.to_biguint() + rhs.to_biguint(); if result >= Self::order() { result -= Self::order(); } @@ -210,9 +208,7 @@ impl Mul for Secp256K1Base { #[inline] fn mul(self, rhs: Self) -> Self { - Self::from_biguint( - (self.to_canonical_biguint() * rhs.to_canonical_biguint()).mod_floor(&Self::order()), - ) + Self::from_biguint((self.to_biguint() * rhs.to_biguint()).mod_floor(&Self::order())) } } diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 82bfd91e..9e00c562 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -1,14 +1,13 @@ use std::marker::PhantomData; -use std::ops::Neg; -use num::{BigUint, Zero}; +use num::Integer; +use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; -use crate::field::{extension_field::Extendable, field_types::Field}; use crate::gadgets::arithmetic_u32::U32Target; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; -use crate::iop::witness::PartitionWitness; +use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; #[derive(Clone, Debug)] @@ -197,5 +196,18 @@ impl, const D: usize> SimpleGenerator .collect() } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) {} + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = witness.get_biguint_target(self.a.clone()); + let b = witness.get_biguint_target(self.b.clone()); + let (div, rem) = a.div_rem(&b); + + out_buffer.set_biguint_target(self.div.clone(), div); + out_buffer.set_biguint_target(self.rem.clone(), rem); + } +} + +#[cfg(test)] +mod tests { + #[test] + fn test_biguint_add() {} } diff --git a/src/iop/generator.rs b/src/iop/generator.rs index eb2c95f7..c395ad73 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -1,9 +1,13 @@ use std::fmt::Debug; use std::marker::PhantomData; +use num::BigUint; + use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::{Field, RichField}; +use crate::gadgets::arithmetic_u32::U32Target; +use crate::gadgets::biguint::BigUintTarget; use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::iop::target::Target; use crate::iop::wire::Wire; @@ -150,6 +154,17 @@ impl GeneratedValues { self.target_values.push((target, value)) } + fn set_u32_target(&mut self, target: U32Target, value: u32) { + self.set_target(target.0, F::from_canonical_u32(value)) + } + + pub fn set_biguint_target(&mut self, target: BigUintTarget, value: BigUint) { + let limbs = value.to_u32_digits(); + for i in 0..target.num_limbs() { + self.set_u32_target(target.get_limb(i), limbs[i]); + } + } + pub fn set_hash_target(&mut self, ht: HashOutTarget, value: HashOut) { ht.elements .iter() diff --git a/src/iop/witness.rs b/src/iop/witness.rs index 13a374e2..c1f877cb 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -57,6 +57,19 @@ pub trait Witness { panic!("not a bool") } + fn get_biguint_target(&self, target: BigUintTarget) -> BigUint { + let mut result = BigUint::zero(); + + let limb_base = BigUint::from_u64(1 << 32u64).unwrap(); + for i in (0..target.num_limbs()).rev() { + let limb = target.get_limb(i); + result *= &limb_base; + result += self.get_target(limb.0).to_biguint(); + } + + result + } + fn get_hash_target(&self, ht: HashOutTarget) -> HashOut { HashOut { elements: self.get_targets(&ht.elements).try_into().unwrap(), From 649c2e2b524545a6e7ea67ed3191f26b391aeaa1 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 19 Oct 2021 14:39:10 -0700 Subject: [PATCH 041/202] tests for biguint gadget --- src/gadgets/biguint.rs | 46 ++++++++++++++++++++++++++++++++++++++++-- src/iop/witness.rs | 3 +-- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 9e00c562..f064a1bd 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use num::Integer; +use num::{BigUint, Integer}; use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; @@ -26,6 +26,18 @@ impl BigUintTarget { } impl, const D: usize> CircuitBuilder { + fn constant_biguint(&mut self, value: BigUint) -> BigUintTarget { + let limb_values = value.to_u32_digits(); + let mut limbs = Vec::new(); + for i in 0..limb_values.len() { + limbs.push(U32Target( + self.constant(F::from_canonical_u32(limb_values[i])), + )); + } + + BigUintTarget { limbs } + } + fn connect_biguint(&mut self, lhs: BigUintTarget, rhs: BigUintTarget) { let min_limbs = lhs.num_limbs().min(rhs.num_limbs()); for i in 0..min_limbs { @@ -208,6 +220,36 @@ impl, const D: usize> SimpleGenerator #[cfg(test)] mod tests { + use anyhow::Result; + use num::{BigUint, FromPrimitive}; + + use crate::{ + field::crandall_field::CrandallField, + iop::witness::PartialWitness, + plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig, verifier::verify}, + }; + #[test] - fn test_biguint_add() {} + fn test_biguint_add() -> Result<()> { + let x_value = BigUint::from_u128(22222222222222222222222222222222222222).unwrap(); + let y_value = BigUint::from_u128(33333333333333333333333333333333333333).unwrap(); + let expected_z_value = &x_value + &y_value; + + type F = CrandallField; + let config = CircuitConfig::large_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_biguint(x_value); + let y = builder.constant_biguint(y_value); + let z = builder.add_biguint(x, y); + let expected_z = builder.constant_biguint(expected_z_value); + + builder.connect_biguint(z, expected_z); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } } diff --git a/src/iop/witness.rs b/src/iop/witness.rs index c1f877cb..0388a6cb 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -5,8 +5,7 @@ use num::{BigUint, FromPrimitive, Zero}; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::field::field_types::{Field, PrimeField}; -use crate::gadgets::arithmetic_u32::U32Target; +use crate::field::field_types::Field; use crate::gadgets::biguint::BigUintTarget; use crate::hash::hash_types::HashOutTarget; use crate::hash::hash_types::{HashOut, MerkleCapTarget}; From 0c182c462116cd6a075eef16d1d437d8372c7048 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 19 Oct 2021 14:43:39 -0700 Subject: [PATCH 042/202] fix --- src/gadgets/biguint.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index f064a1bd..3aa5c8c1 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -104,7 +104,7 @@ impl, const D: usize> CircuitBuilder { carry = new_carry; combined_limbs.push(new_limb); } - combined_limbs[num_limbs] = carry; + combined_limbs.push(carry); BigUintTarget { limbs: combined_limbs, From 140279113952a96cd74042feb6f420a0d2e21458 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 09:58:14 -0800 Subject: [PATCH 043/202] merge --- src/gadgets/biguint.rs | 26 +++++++++++++++++++++++++- src/plonk/circuit_builder.rs | 11 ++++++++--- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 3aa5c8c1..80c72a1d 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -121,7 +121,7 @@ impl, const D: usize> CircuitBuilder { let mut borrow = self.zero_u32(); for i in 0..num_limbs { let (result, new_borrow) = self.sub_u32(a.limbs[i], b.limbs[i], borrow); - result_limbs[i] = result; + result_limbs.push(result); borrow = new_borrow; } // Borrow should be zero here. @@ -252,4 +252,28 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + + #[test] + fn test_biguint_sub() -> Result<()> { + let x_value = BigUint::from_u128(33333333333333333333333333333333333333).unwrap(); + let y_value = BigUint::from_u128(22222222222222222222222222222222222222).unwrap(); + let expected_z_value = &x_value - &y_value; + + type F = CrandallField; + let config = CircuitConfig::large_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_biguint(x_value); + let y = builder.constant_biguint(y_value); + let z = builder.sub_biguint(x, y); + let expected_z = builder.constant_biguint(expected_z_value); + + builder.connect_biguint(z, expected_z); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } } diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 804e7be2..dbc8f3df 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -565,9 +565,7 @@ impl, const D: usize> CircuitBuilder { let mut timing = TimingTree::new("preprocess", Level::Trace); let start = Instant::now(); - self.fill_arithmetic_gates(); - self.fill_random_access_gates(); - self.fill_switch_gates(); + self.fill_batched_gates(); // Hash the public inputs, and route them to a `PublicInputGate` which will enforce that // those hash wires match the claimed public inputs. @@ -1007,4 +1005,11 @@ impl, const D: usize> CircuitBuilder { } } } + + fn fill_batched_gates(&mut self) { + self.fill_arithmetic_gates(); + self.fill_switch_gates(); + self.fill_u32_arithmetic_gates(); + self.fill_u32_subtraction_gates(); + } } From 62519eeb124f08ca624f3211d54bc40ef0f46163 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 19 Oct 2021 16:01:41 -0700 Subject: [PATCH 044/202] biguint mul test --- src/gadgets/biguint.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 80c72a1d..e5075f5b 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -276,4 +276,28 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + + #[test] + fn test_biguint_mul() -> Result<()> { + let x_value = BigUint::from_u128(123123123123123123123123123123123123).unwrap(); + let y_value = BigUint::from_u128(456456456456456456456456456456456456).unwrap(); + let expected_z_value = &x_value * &y_value; + + type F = CrandallField; + let config = CircuitConfig::large_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_biguint(x_value); + let y = builder.constant_biguint(y_value); + let z = builder.mul_biguint(x, y); + let expected_z = builder.constant_biguint(expected_z_value); + + builder.connect_biguint(z, expected_z); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } } From 166ab77ee3a9f4eea59e6a82a3c738556a4662cb Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 20 Oct 2021 14:14:33 -0700 Subject: [PATCH 045/202] biguint_cmp test --- src/gadgets/biguint.rs | 69 +++++++++++++++++++++++++++--- src/gadgets/multiple_comparison.rs | 4 +- src/iop/generator.rs | 3 +- 3 files changed, 68 insertions(+), 8 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index e5075f5b..9fe47bf9 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -93,14 +93,23 @@ impl, const D: usize> CircuitBuilder { // Add two `BigUintTarget`s. pub fn add_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { - let num_limbs = a.limbs.len(); - debug_assert!(b.limbs.len() == num_limbs); + let num_limbs = a.num_limbs().max(b.num_limbs()); let mut combined_limbs = vec![]; let mut carry = self.zero_u32(); for i in 0..num_limbs { - let (new_limb, new_carry) = - self.add_three_u32(carry.clone(), a.limbs[i].clone(), b.limbs[i].clone()); + let a_limb = if i < a.num_limbs() { + a.limbs[i].clone() + } else { + self.zero_u32() + }; + let b_limb = if i < b.num_limbs() { + b.limbs[i].clone() + } else { + self.zero_u32() + }; + + let (new_limb, new_carry) = self.add_three_u32(carry.clone(), a_limb, b_limb); carry = new_carry; combined_limbs.push(new_limb); } @@ -221,7 +230,7 @@ impl, const D: usize> SimpleGenerator #[cfg(test)] mod tests { use anyhow::Result; - use num::{BigUint, FromPrimitive}; + use num::{BigUint, FromPrimitive, Integer}; use crate::{ field::crandall_field::CrandallField, @@ -300,4 +309,54 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + + #[test] + fn test_biguint_cmp() -> Result<()> { + let x_value = BigUint::from_u128(33333333333333333333333333333333333333).unwrap(); + let y_value = BigUint::from_u128(22222222222222222222222222222222222222).unwrap(); + + type F = CrandallField; + let config = CircuitConfig::large_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_biguint(x_value); + let y = builder.constant_biguint(y_value); + let cmp = builder.cmp_biguint(x, y); + let expected_cmp = builder.constant_bool(true); + + builder.connect(cmp.target, expected_cmp.target); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_biguint_div_rem() -> Result<()> { + let x_value = BigUint::from_u128(456456456456456456456456456456456456).unwrap(); + let y_value = BigUint::from_u128(123123123123123123123123123123123123).unwrap(); + let (expected_div_value, expected_rem_value) = x_value.div_rem(&y_value); + + type F = CrandallField; + let config = CircuitConfig::large_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_biguint(x_value); + let y = builder.constant_biguint(y_value); + let (div, rem) = builder.div_rem_biguint(x, y); + + let expected_div = builder.constant_biguint(expected_div_value); + let expected_rem = builder.constant_biguint(expected_rem_value); + + //builder.connect_biguint(div, expected_div); + //builder.connect_biguint(rem, expected_rem); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } } diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs index 3c2ac0f5..bfeaa098 100644 --- a/src/gadgets/multiple_comparison.rs +++ b/src/gadgets/multiple_comparison.rs @@ -15,8 +15,8 @@ impl, const D: usize> CircuitBuilder { ); let n = a.len(); - let chunk_size = 4; - let num_chunks = ceil_div_usize(num_bits, chunk_size); + let chunk_bits = 2; + let num_chunks = ceil_div_usize(num_bits, chunk_bits); let one = self.one(); let mut result = self.one(); diff --git a/src/iop/generator.rs b/src/iop/generator.rs index c395ad73..c5c67bcb 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -159,7 +159,8 @@ impl GeneratedValues { } pub fn set_biguint_target(&mut self, target: BigUintTarget, value: BigUint) { - let limbs = value.to_u32_digits(); + let mut limbs = value.to_u32_digits(); + limbs.resize(target.num_limbs(), 0); for i in 0..target.num_limbs() { self.set_u32_target(target.get_limb(i), limbs[i]); } From 048048cea249d437cef1e2345400d4fec2020270 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 21 Oct 2021 09:51:40 -0700 Subject: [PATCH 046/202] test for list_le --- src/gadgets/biguint.rs | 5 ---- src/gadgets/multiple_comparison.rs | 45 ++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 9fe47bf9..0f7ac993 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -258,7 +258,6 @@ mod tests { let data = builder.build(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) } @@ -282,7 +281,6 @@ mod tests { let data = builder.build(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) } @@ -306,7 +304,6 @@ mod tests { let data = builder.build(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) } @@ -329,7 +326,6 @@ mod tests { let data = builder.build(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) } @@ -356,7 +352,6 @@ mod tests { let data = builder.build(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) } } diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs index bfeaa098..d9578e46 100644 --- a/src/gadgets/multiple_comparison.rs +++ b/src/gadgets/multiple_comparison.rs @@ -53,3 +53,48 @@ impl, const D: usize> CircuitBuilder { BoolTarget::new_unsafe(result) } } + +#[cfg(test)] +mod tests { + use anyhow::Result; + use rand::Rng; + + use crate::field::crandall_field::CrandallField; + use crate::field::field_types::Field; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::verifier::verify; + + fn test_list_le(size: usize) -> Result<()> { + type F = CrandallField; + let config = CircuitConfig::large_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let mut rng = rand::thread_rng(); + + let lst1: Vec = (0..size) + .map(|_| F::from_canonical_u32(rng.gen())) + .collect(); + let lst2: Vec = (0..size) + .map(|_| F::from_canonical_u32(rng.gen())) + .collect(); + let a = lst1.iter().map(|&x| builder.constant(x)).collect(); + let b = lst2.iter().map(|&x| builder.constant(x)).collect(); + + let result = builder.list_le(a, b, 32); + + let expected_result = builder.constant_bool(true); + builder.connect(result.target, expected_result.target); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_multiple_comparison_trivial() -> Result<()> { + test_list_le(1) + } +} From 9e49c3f2b4aaaa455c23b1dd74d8b11fb6072887 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 21 Oct 2021 12:12:09 -0700 Subject: [PATCH 047/202] fix to test --- src/gadgets/multiple_comparison.rs | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs index d9578e46..fab64b9e 100644 --- a/src/gadgets/multiple_comparison.rs +++ b/src/gadgets/multiple_comparison.rs @@ -74,14 +74,20 @@ mod tests { let mut rng = rand::thread_rng(); - let lst1: Vec = (0..size) - .map(|_| F::from_canonical_u32(rng.gen())) + let lst1: Vec = (0..size) + .map(|_| rng.gen()) .collect(); - let lst2: Vec = (0..size) - .map(|_| F::from_canonical_u32(rng.gen())) + let lst2: Vec = (0..size) + .map(|i| { + let mut res = rng.gen(); + while res < lst1[i] { + res = rng.gen(); + } + res + }) .collect(); - let a = lst1.iter().map(|&x| builder.constant(x)).collect(); - let b = lst2.iter().map(|&x| builder.constant(x)).collect(); + let a = lst1.iter().map(|&x| builder.constant(F::from_canonical_u32(x))).collect(); + let b = lst2.iter().map(|&x| builder.constant(F::from_canonical_u32(x))).collect(); let result = builder.list_le(a, b, 32); From 90178b2b0a5064ef55dda3ce7766a2684d7c3eb6 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 25 Oct 2021 15:49:44 -0700 Subject: [PATCH 048/202] many fixes --- src/gadgets/biguint.rs | 14 +++++--------- src/gadgets/multiple_comparison.rs | 17 ++++++++++++++++- src/gates/comparison.rs | 18 ++++++++++-------- src/plonk/circuit_builder.rs | 1 + 4 files changed, 32 insertions(+), 18 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 0f7ac993..d3e92cc7 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -190,7 +190,7 @@ impl, const D: usize> CircuitBuilder { self.connect_biguint(a, div_b_plus_rem); let cmp_rem_b = self.cmp_biguint(rem.clone(), b); - self.assert_one(cmp_rem_b.target); + self.assert_zero(cmp_rem_b.target); (div, rem) } @@ -232,11 +232,7 @@ mod tests { use anyhow::Result; use num::{BigUint, FromPrimitive, Integer}; - use crate::{ - field::crandall_field::CrandallField, - iop::witness::PartialWitness, - plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig, verifier::verify}, - }; + use crate::{field::{crandall_field::CrandallField, field_types::PrimeField}, iop::witness::PartialWitness, plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig, verifier::verify}}; #[test] fn test_biguint_add() -> Result<()> { @@ -320,7 +316,7 @@ mod tests { let x = builder.constant_biguint(x_value); let y = builder.constant_biguint(y_value); let cmp = builder.cmp_biguint(x, y); - let expected_cmp = builder.constant_bool(true); + let expected_cmp = builder.constant_bool(false); builder.connect(cmp.target, expected_cmp.target); @@ -347,8 +343,8 @@ mod tests { let expected_div = builder.constant_biguint(expected_div_value); let expected_rem = builder.constant_biguint(expected_rem_value); - //builder.connect_biguint(div, expected_div); - //builder.connect_biguint(rem, expected_rem); + builder.connect_biguint(div, expected_div); + builder.connect_biguint(rem, expected_rem); let data = builder.build(); let proof = data.prove(pw).unwrap(); diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs index fab64b9e..7b036a33 100644 --- a/src/gadgets/multiple_comparison.rs +++ b/src/gadgets/multiple_comparison.rs @@ -80,7 +80,7 @@ mod tests { let lst2: Vec = (0..size) .map(|i| { let mut res = rng.gen(); - while res < lst1[i] { + while res <= lst1[i] { res = rng.gen(); } res @@ -103,4 +103,19 @@ mod tests { fn test_multiple_comparison_trivial() -> Result<()> { test_list_le(1) } + + #[test] + fn test_multiple_comparison_small() -> Result<()> { + test_list_le(3) + } + + #[test] + fn test_multiple_comparison_medium() -> Result<()> { + test_list_le(6) + } + + #[test] + fn test_multiple_comparison_large() -> Result<()> { + test_list_le(10) + } } diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index 9783a42c..2920c072 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -119,10 +119,10 @@ impl, const D: usize> Gate for ComparisonGate for i in 0..self.num_chunks { // Range-check the chunks to be less than `chunk_size`. - let first_product = (0..chunk_size) + let first_product: F::Extension = (0..chunk_size) .map(|x| first_chunks[i] - F::Extension::from_canonical_usize(x)) .product(); - let second_product = (0..chunk_size) + let second_product: F::Extension = (0..chunk_size) .map(|x| second_chunks[i] - F::Extension::from_canonical_usize(x)) .product(); constraints.push(first_product); @@ -200,10 +200,10 @@ impl, const D: usize> Gate for ComparisonGate for i in 0..self.num_chunks { // Range-check the chunks to be less than `chunk_size`. - let first_product = (0..chunk_size) + let first_product: F = (0..chunk_size) .map(|x| first_chunks[i] - F::from_canonical_usize(x)) .product(); - let second_product = (0..chunk_size) + let second_product: F = (0..chunk_size) .map(|x| second_chunks[i] - F::from_canonical_usize(x)) .product(); constraints.push(first_product); @@ -447,15 +447,17 @@ impl, const D: usize> SimpleGenerator } let most_significant_diff = most_significant_diff_so_far; - let two_n_plus_msd = - (1 << self.gate.chunk_bits()) as u64 + most_significant_diff.to_canonical_u64(); - let msd_bits: Vec = (0..self.gate.chunk_bits() + 1) + let two_n = F::from_canonical_usize(1 << self.gate.chunk_bits()); + let two_n_plus_msd = (two_n + most_significant_diff).to_canonical_u64(); + + let msd_bits_u64: Vec = (0..self.gate.chunk_bits() + 1) .scan(two_n_plus_msd, |acc, _| { let tmp = *acc % 2; *acc /= 2; - Some(F::from_canonical_u64(tmp)) + Some(tmp) }) .collect(); + let msd_bits: Vec = msd_bits_u64.iter().map(|x| F::from_canonical_u64(*x)).collect(); out_buffer.set_wire(local_wire(self.gate.wire_result_bool()), result); out_buffer.set_wire( diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index dbc8f3df..9a90cac2 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -202,6 +202,7 @@ impl, const D: usize> CircuitBuilder { gate_ref, constants, }); + index } From 7e81f297f2a5c7faf17780b3365ebb2b6fdd2719 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 25 Oct 2021 15:51:03 -0700 Subject: [PATCH 049/202] another fix --- src/gadgets/biguint.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index d3e92cc7..adf0245c 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -190,7 +190,7 @@ impl, const D: usize> CircuitBuilder { self.connect_biguint(a, div_b_plus_rem); let cmp_rem_b = self.cmp_biguint(rem.clone(), b); - self.assert_zero(cmp_rem_b.target); + self.assert_one(cmp_rem_b.target); (div, rem) } From f41c8ee16f7a87f5ff5d0e6a89b383f93b89fe14 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 25 Oct 2021 17:28:59 -0700 Subject: [PATCH 050/202] fmt --- src/gadgets/biguint.rs | 6 +++++- src/gadgets/multiple_comparison.rs | 14 +++++++++----- src/gates/comparison.rs | 5 ++++- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index adf0245c..83ca5796 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -232,7 +232,11 @@ mod tests { use anyhow::Result; use num::{BigUint, FromPrimitive, Integer}; - use crate::{field::{crandall_field::CrandallField, field_types::PrimeField}, iop::witness::PartialWitness, plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig, verifier::verify}}; + use crate::{ + field::{crandall_field::CrandallField, field_types::PrimeField}, + iop::witness::PartialWitness, + plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig, verifier::verify}, + }; #[test] fn test_biguint_add() -> Result<()> { diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs index 7b036a33..e5f70a9e 100644 --- a/src/gadgets/multiple_comparison.rs +++ b/src/gadgets/multiple_comparison.rs @@ -74,9 +74,7 @@ mod tests { let mut rng = rand::thread_rng(); - let lst1: Vec = (0..size) - .map(|_| rng.gen()) - .collect(); + let lst1: Vec = (0..size).map(|_| rng.gen()).collect(); let lst2: Vec = (0..size) .map(|i| { let mut res = rng.gen(); @@ -86,8 +84,14 @@ mod tests { res }) .collect(); - let a = lst1.iter().map(|&x| builder.constant(F::from_canonical_u32(x))).collect(); - let b = lst2.iter().map(|&x| builder.constant(F::from_canonical_u32(x))).collect(); + let a = lst1 + .iter() + .map(|&x| builder.constant(F::from_canonical_u32(x))) + .collect(); + let b = lst2 + .iter() + .map(|&x| builder.constant(F::from_canonical_u32(x))) + .collect(); let result = builder.list_le(a, b, 32); diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index 2920c072..a610c5e2 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -457,7 +457,10 @@ impl, const D: usize> SimpleGenerator Some(tmp) }) .collect(); - let msd_bits: Vec = msd_bits_u64.iter().map(|x| F::from_canonical_u64(*x)).collect(); + let msd_bits: Vec = msd_bits_u64 + .iter() + .map(|x| F::from_canonical_u64(*x)) + .collect(); out_buffer.set_wire(local_wire(self.gate.wire_result_bool()), result); out_buffer.set_wire( From f639dd3359550048f92eafbea5004519e8e8da2b Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 26 Oct 2021 11:04:31 -0700 Subject: [PATCH 051/202] fixes to nonnative --- src/gadgets/mod.rs | 2 +- src/gadgets/multiple_comparison.rs | 8 ++++++++ src/gadgets/nonnative.rs | 21 ++++----------------- src/gadgets/permutation.rs | 1 - 4 files changed, 13 insertions(+), 19 deletions(-) diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index cf6f6ed4..8b6e60f6 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -6,7 +6,7 @@ pub mod hash; pub mod insert; pub mod interpolation; pub mod multiple_comparison; -//pub mod nonnative; +pub mod nonnative; pub mod permutation; pub mod polynomial; pub mod random_access; diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs index e5f70a9e..579708c8 100644 --- a/src/gadgets/multiple_comparison.rs +++ b/src/gadgets/multiple_comparison.rs @@ -1,3 +1,4 @@ +use super::arithmetic_u32::U32Target; use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; use crate::gates::comparison::ComparisonGate; @@ -52,6 +53,13 @@ impl, const D: usize> CircuitBuilder { BoolTarget::new_unsafe(result) } + + /// Helper function for comparing, specifically, lists of `U32Target`s. + pub fn list_le_u32(&mut self, a: Vec, b: Vec) -> BoolTarget { + let a_targets = a.iter().map(|&t| t.0).collect(); + let b_targets = b.iter().map(|&t| t.0).collect(); + self.list_le(a_targets, b_targets, 32) + } } #[cfg(test)] diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index e41f8fb5..406d6852 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -1,6 +1,5 @@ use std::marker::PhantomData; - use crate::field::field_types::RichField; use crate::field::{extension_field::Extendable, field_types::Field}; use crate::gadgets::arithmetic_u32::U32Target; @@ -21,16 +20,6 @@ impl, const D: usize> CircuitBuilder { .collect() } - fn power_of_2_mod_order(&mut self, i: usize) -> Vec { - - } - - pub fn powers_of_2_mod_order(&mut self, max: usize) -> Vec> { - for i in 0..max { - - } - } - // Add two `ForeignFieldTarget`s. pub fn add_nonnative( &mut self, @@ -62,14 +51,14 @@ impl, const D: usize> CircuitBuilder { let num_limbs = limbs.len(); let mut modulus_limbs = self.order_u32_limbs::(); - modulus_limbs.append(self.zero_u32()); + modulus_limbs.push(self.zero_u32()); - let needs_reduce = self.list_le(modulus_limbs, limbs); + let needs_reduce = self.list_le_u32(modulus_limbs, limbs); let mut to_subtract = vec![]; for i in 0..num_limbs { - let (low, _high) = self.mul_u32(modulus_limbs[i], needs_reduce); - to_subtract.append(low); + let (low, _high) = self.mul_u32(modulus_limbs[i], U32Target(needs_reduce.target)); + to_subtract.push(low); } let mut reduced_limbs = vec![]; @@ -110,7 +99,6 @@ impl, const D: usize> CircuitBuilder { } } - pub fn mul_nonnative( &mut self, a: ForeignFieldTarget, @@ -148,7 +136,6 @@ impl, const D: usize> CircuitBuilder { } pub fn reduce_mul_result(&mut self, limbs: Vec) -> Vec { - todo!() } } diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs index ae0e411b..aa06294a 100644 --- a/src/gadgets/permutation.rs +++ b/src/gadgets/permutation.rs @@ -35,7 +35,6 @@ impl, const D: usize> CircuitBuilder { self.assert_permutation_2x2(a[0].clone(), a[1].clone(), b[0].clone(), b[1].clone()) } // For larger lists, we recursively use two smaller permutation networks. - //_ => self.assert_permutation_recursive(a, b) _ => self.assert_permutation_recursive(a, b), } } From 6232aa68fb2c36a18b6ad5ceb36b1d739d0407e3 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 26 Oct 2021 12:16:25 -0700 Subject: [PATCH 052/202] fix --- src/gadgets/nonnative.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 406d6852..ff344697 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -53,7 +53,7 @@ impl, const D: usize> CircuitBuilder { let mut modulus_limbs = self.order_u32_limbs::(); modulus_limbs.push(self.zero_u32()); - let needs_reduce = self.list_le_u32(modulus_limbs, limbs); + let needs_reduce = self.list_le_u32(modulus_limbs.clone(), limbs.clone()); let mut to_subtract = vec![]; for i in 0..num_limbs { @@ -121,7 +121,7 @@ impl, const D: usize> CircuitBuilder { let mut carry = self.zero_u32(); for i in 0..2 * num_limbs { to_add[i].push(carry); - let (new_result, new_carry) = self.add_many_u32(to_add[i]); + let (new_result, new_carry) = self.add_many_u32(to_add[i].clone()); combined_limbs.push(new_result); carry = new_carry; } From 87d81290341e7350e26d530bc11b587fc9211440 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 26 Oct 2021 14:38:10 -0700 Subject: [PATCH 053/202] reduce --- src/gadgets/biguint.rs | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 83ca5796..1f4ccbb5 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -26,7 +26,7 @@ impl BigUintTarget { } impl, const D: usize> CircuitBuilder { - fn constant_biguint(&mut self, value: BigUint) -> BigUintTarget { + fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget { let limb_values = value.to_u32_digits(); let mut limbs = Vec::new(); for i in 0..limb_values.len() { @@ -194,6 +194,24 @@ impl, const D: usize> CircuitBuilder { (div, rem) } + + pub fn div_biguint( + &mut self, + a: BigUintTarget, + b: BigUintTarget, + ) -> BigUintTarget { + let (div, _rem) = self.div_rem_biguint(a, b); + div + } + + pub fn rem_biguint( + &mut self, + a: BigUintTarget, + b: BigUintTarget, + ) -> BigUintTarget { + let (_div, rem) = self.div_rem_biguint(a, b); + rem + } } #[derive(Debug)] From bfe201d95155618c14a5ac935a18da46e4b5eacc Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 26 Oct 2021 14:38:18 -0700 Subject: [PATCH 054/202] fmt --- src/gadgets/biguint.rs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 1f4ccbb5..9b3895a7 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -195,20 +195,12 @@ impl, const D: usize> CircuitBuilder { (div, rem) } - pub fn div_biguint( - &mut self, - a: BigUintTarget, - b: BigUintTarget, - ) -> BigUintTarget { + pub fn div_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { let (div, _rem) = self.div_rem_biguint(a, b); div } - pub fn rem_biguint( - &mut self, - a: BigUintTarget, - b: BigUintTarget, - ) -> BigUintTarget { + pub fn rem_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { let (_div, rem) = self.div_rem_biguint(a, b); rem } From f7ce33b7aef73a51ab7db64fffb7b0bc2ed0710b Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 26 Oct 2021 15:56:08 -0700 Subject: [PATCH 055/202] using refs in right places; and lots of fixes --- src/gadgets/biguint.rs | 82 ++++++++++++------------ src/gadgets/nonnative.rs | 132 ++++++++++++--------------------------- 2 files changed, 80 insertions(+), 134 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 9b3895a7..1ccf9c3a 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -26,7 +26,7 @@ impl BigUintTarget { } impl, const D: usize> CircuitBuilder { - fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget { + pub fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget { let limb_values = value.to_u32_digits(); let mut limbs = Vec::new(); for i in 0..limb_values.len() { @@ -38,7 +38,7 @@ impl, const D: usize> CircuitBuilder { BigUintTarget { limbs } } - fn connect_biguint(&mut self, lhs: BigUintTarget, rhs: BigUintTarget) { + pub fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) { let min_limbs = lhs.num_limbs().min(rhs.num_limbs()); for i in 0..min_limbs { self.connect_u32(lhs.get_limb(i), rhs.get_limb(i)); @@ -52,7 +52,7 @@ impl, const D: usize> CircuitBuilder { } } - fn pad_biguints( + pub fn pad_biguints<'a>( &mut self, a: BigUintTarget, b: BigUintTarget, @@ -74,7 +74,7 @@ impl, const D: usize> CircuitBuilder { } } - fn cmp_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BoolTarget { + pub fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget { let (padded_a, padded_b) = self.pad_biguints(a.clone(), b.clone()); let a_vec = padded_a.limbs.iter().map(|&x| x.0).collect(); @@ -83,7 +83,7 @@ impl, const D: usize> CircuitBuilder { self.list_le(a_vec, b_vec, 32) } - fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget { + pub fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget { let limbs = (0..num_limbs) .map(|_| self.add_virtual_u32_target()) .collect(); @@ -92,7 +92,7 @@ impl, const D: usize> CircuitBuilder { } // Add two `BigUintTarget`s. - pub fn add_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { + pub fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let num_limbs = a.num_limbs().max(b.num_limbs()); let mut combined_limbs = vec![]; @@ -121,7 +121,7 @@ impl, const D: usize> CircuitBuilder { } // Subtract two `BigUintTarget`s. We assume that the first is larger than the second. - pub fn sub_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { + pub fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let num_limbs = a.limbs.len(); debug_assert!(b.limbs.len() == num_limbs); @@ -140,7 +140,7 @@ impl, const D: usize> CircuitBuilder { } } - pub fn mul_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { + pub fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let num_limbs = a.limbs.len(); debug_assert!(b.limbs.len() == num_limbs); @@ -170,8 +170,8 @@ impl, const D: usize> CircuitBuilder { pub fn div_rem_biguint( &mut self, - a: BigUintTarget, - b: BigUintTarget, + a: &BigUintTarget, + b: &BigUintTarget, ) -> (BigUintTarget, BigUintTarget) { let num_limbs = a.limbs.len(); let div = self.add_virtual_biguint_target(num_limbs); @@ -185,22 +185,22 @@ impl, const D: usize> CircuitBuilder { _phantom: PhantomData, }); - let div_b = self.mul_biguint(div.clone(), b.clone()); - let div_b_plus_rem = self.add_biguint(div_b, rem.clone()); - self.connect_biguint(a, div_b_plus_rem); + let div_b = self.mul_biguint(&div, &b); + let div_b_plus_rem = self.add_biguint(&div_b, &rem); + self.connect_biguint(&a, &div_b_plus_rem); - let cmp_rem_b = self.cmp_biguint(rem.clone(), b); + let cmp_rem_b = self.cmp_biguint(&rem, b); self.assert_one(cmp_rem_b.target); (div, rem) } - pub fn div_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { + pub fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let (div, _rem) = self.div_rem_biguint(a, b); div } - pub fn rem_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { + pub fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let (_div, rem) = self.div_rem_biguint(a, b); rem } @@ -259,12 +259,12 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let x = builder.constant_biguint(x_value); - let y = builder.constant_biguint(y_value); - let z = builder.add_biguint(x, y); - let expected_z = builder.constant_biguint(expected_z_value); + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let z = builder.add_biguint(&x, &y); + let expected_z = builder.constant_biguint(&expected_z_value); - builder.connect_biguint(z, expected_z); + builder.connect_biguint(&z, &expected_z); let data = builder.build(); let proof = data.prove(pw).unwrap(); @@ -282,12 +282,12 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let x = builder.constant_biguint(x_value); - let y = builder.constant_biguint(y_value); - let z = builder.sub_biguint(x, y); - let expected_z = builder.constant_biguint(expected_z_value); + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let z = builder.sub_biguint(&x, &y); + let expected_z = builder.constant_biguint(&expected_z_value); - builder.connect_biguint(z, expected_z); + builder.connect_biguint(&z, &expected_z); let data = builder.build(); let proof = data.prove(pw).unwrap(); @@ -305,12 +305,12 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let x = builder.constant_biguint(x_value); - let y = builder.constant_biguint(y_value); - let z = builder.mul_biguint(x, y); - let expected_z = builder.constant_biguint(expected_z_value); + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let z = builder.mul_biguint(&x, &y); + let expected_z = builder.constant_biguint(&expected_z_value); - builder.connect_biguint(z, expected_z); + builder.connect_biguint(&z, &expected_z); let data = builder.build(); let proof = data.prove(pw).unwrap(); @@ -327,9 +327,9 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let x = builder.constant_biguint(x_value); - let y = builder.constant_biguint(y_value); - let cmp = builder.cmp_biguint(x, y); + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let cmp = builder.cmp_biguint(&x, &y); let expected_cmp = builder.constant_bool(false); builder.connect(cmp.target, expected_cmp.target); @@ -350,15 +350,15 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let x = builder.constant_biguint(x_value); - let y = builder.constant_biguint(y_value); - let (div, rem) = builder.div_rem_biguint(x, y); + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let (div, rem) = builder.div_rem_biguint(&x, &y); - let expected_div = builder.constant_biguint(expected_div_value); - let expected_rem = builder.constant_biguint(expected_rem_value); + let expected_div = builder.constant_biguint(&expected_div_value); + let expected_rem = builder.constant_biguint(&expected_rem_value); - builder.connect_biguint(div, expected_div); - builder.connect_biguint(rem, expected_rem); + builder.connect_biguint(&div, &expected_div); + builder.connect_biguint(&rem, &expected_rem); let data = builder.build(); let proof = data.prove(pw).unwrap(); diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index ff344697..61d0ac5c 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -1,5 +1,6 @@ use std::marker::PhantomData; +use crate::gadgets::biguint::BigUintTarget; use crate::field::field_types::RichField; use crate::field::{extension_field::Extendable, field_types::Field}; use crate::gadgets::arithmetic_u32::U32Target; @@ -20,122 +21,67 @@ impl, const D: usize> CircuitBuilder { .collect() } - // Add two `ForeignFieldTarget`s. - pub fn add_nonnative( - &mut self, - a: ForeignFieldTarget, - b: ForeignFieldTarget, - ) -> ForeignFieldTarget { - let num_limbs = a.limbs.len(); - debug_assert!(b.limbs.len() == num_limbs); - - let mut combined_limbs = self.add_virtual_u32_targets(num_limbs + 1); - let mut carry = self.zero_u32(); - for i in 0..num_limbs { - let (new_limb, new_carry) = - self.add_three_u32(carry.clone(), a.limbs[i].clone(), b.limbs[i].clone()); - carry = new_carry; - combined_limbs[i] = new_limb; - } - combined_limbs[num_limbs] = carry; - - let reduced_limbs = self.reduce_add_result::(combined_limbs); - ForeignFieldTarget { - limbs: reduced_limbs, - _phantom: PhantomData, + pub fn ff_to_biguint(&mut self, x: &ForeignFieldTarget) -> BigUintTarget { + BigUintTarget { + limbs: x.limbs.clone(), } } - /// Reduces the result of a non-native addition. - pub fn reduce_add_result(&mut self, limbs: Vec) -> Vec { - let num_limbs = limbs.len(); + // Add two `ForeignFieldTarget`s. + pub fn add_nonnative( + &mut self, + a: &ForeignFieldTarget, + b: &ForeignFieldTarget, + ) -> ForeignFieldTarget { + let a_biguint = self.ff_to_biguint(a); + let b_biguint = self.ff_to_biguint(b); + let result = self.add_biguint(&a_biguint, &b_biguint); - let mut modulus_limbs = self.order_u32_limbs::(); - modulus_limbs.push(self.zero_u32()); - - let needs_reduce = self.list_le_u32(modulus_limbs.clone(), limbs.clone()); - - let mut to_subtract = vec![]; - for i in 0..num_limbs { - let (low, _high) = self.mul_u32(modulus_limbs[i], U32Target(needs_reduce.target)); - to_subtract.push(low); - } - - let mut reduced_limbs = vec![]; - - let mut borrow = self.zero_u32(); - for i in 0..num_limbs { - let (result, new_borrow) = self.sub_u32(limbs[i], to_subtract[i], borrow); - reduced_limbs[i] = result; - borrow = new_borrow; - } - // Borrow should be zero here. - - reduced_limbs + self.reduce(&result) } // Subtract two `ForeignFieldTarget`s. We assume that the first is larger than the second. pub fn sub_nonnative( &mut self, - a: ForeignFieldTarget, - b: ForeignFieldTarget, + a: &ForeignFieldTarget, + b: &ForeignFieldTarget, ) -> ForeignFieldTarget { - let num_limbs = a.limbs.len(); - debug_assert!(b.limbs.len() == num_limbs); + let a_biguint = self.ff_to_biguint(a); + let b_biguint = self.ff_to_biguint(b); + let result = self.sub_biguint(&a_biguint, &b_biguint); - let mut result_limbs = vec![]; - - let mut borrow = self.zero_u32(); - for i in 0..num_limbs { - let (result, new_borrow) = self.sub_u32(a.limbs[i], b.limbs[i], borrow); - result_limbs[i] = result; - borrow = new_borrow; - } - // Borrow should be zero here. - - ForeignFieldTarget { - limbs: result_limbs, - _phantom: PhantomData, - } + self.reduce(&result) } pub fn mul_nonnative( &mut self, - a: ForeignFieldTarget, - b: ForeignFieldTarget, + a: &ForeignFieldTarget, + b: &ForeignFieldTarget, ) -> ForeignFieldTarget { - let num_limbs = a.limbs.len(); - debug_assert!(b.limbs.len() == num_limbs); + let a_biguint = self.ff_to_biguint(a); + let b_biguint = self.ff_to_biguint(b); + let result = self.mul_biguint(&a_biguint, &b_biguint); - let mut combined_limbs = self.add_virtual_u32_targets(2 * num_limbs - 1); - let mut to_add = vec![vec![]; 2 * num_limbs]; - for i in 0..num_limbs { - for j in 0..num_limbs { - let (product, carry) = self.mul_u32(a.limbs[i], b.limbs[j]); - to_add[i + j].push(product); - to_add[i + j + 1].push(carry); - } - } + self.reduce(&result) + } - let mut combined_limbs = vec![]; - let mut carry = self.zero_u32(); - for i in 0..2 * num_limbs { - to_add[i].push(carry); - let (new_result, new_carry) = self.add_many_u32(to_add[i].clone()); - combined_limbs.push(new_result); - carry = new_carry; - } - combined_limbs.push(carry); - - let reduced_limbs = self.reduce_mul_result::(combined_limbs); + /// Returns `x % |FF|` as a `ForeignFieldTarget`. + fn reduce( + &mut self, + x: &BigUintTarget, + ) -> ForeignFieldTarget { + let modulus = FF::order(); + let order_target = self.constant_biguint(&modulus); + let value = self.rem_biguint(x, &order_target); ForeignFieldTarget { - limbs: reduced_limbs, + limbs: value.limbs, _phantom: PhantomData, } } - pub fn reduce_mul_result(&mut self, limbs: Vec) -> Vec { - todo!() + fn reduce_ff(&mut self, x: &ForeignFieldTarget) -> ForeignFieldTarget { + let x_biguint = self.ff_to_biguint(x); + self.reduce(&x_biguint) } } From ee5619b847b6c48f6ba3372d99ce458134791c5d Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 26 Oct 2021 15:56:15 -0700 Subject: [PATCH 056/202] fmt --- src/gadgets/nonnative.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 61d0ac5c..16a62776 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -1,9 +1,9 @@ use std::marker::PhantomData; -use crate::gadgets::biguint::BigUintTarget; use crate::field::field_types::RichField; use crate::field::{extension_field::Extendable, field_types::Field}; use crate::gadgets::arithmetic_u32::U32Target; +use crate::gadgets::biguint::BigUintTarget; use crate::plonk::circuit_builder::CircuitBuilder; pub struct ForeignFieldTarget { @@ -66,10 +66,7 @@ impl, const D: usize> CircuitBuilder { } /// Returns `x % |FF|` as a `ForeignFieldTarget`. - fn reduce( - &mut self, - x: &BigUintTarget, - ) -> ForeignFieldTarget { + fn reduce(&mut self, x: &BigUintTarget) -> ForeignFieldTarget { let modulus = FF::order(); let order_target = self.constant_biguint(&modulus); let value = self.rem_biguint(x, &order_target); From bbcda969e5beddb0bae1baaeafff941cde792c02 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 27 Oct 2021 11:46:16 -0700 Subject: [PATCH 057/202] nonnative tests --- src/gadgets/biguint.rs | 15 ++++----- src/gadgets/nonnative.rs | 69 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 8 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 1ccf9c3a..2c34d044 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -52,7 +52,7 @@ impl, const D: usize> CircuitBuilder { } } - pub fn pad_biguints<'a>( + pub fn pad_biguints( &mut self, a: BigUintTarget, b: BigUintTarget, @@ -141,12 +141,11 @@ impl, const D: usize> CircuitBuilder { } pub fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { - let num_limbs = a.limbs.len(); - debug_assert!(b.limbs.len() == num_limbs); + let total_limbs = a.limbs.len() + b.limbs.len(); - let mut to_add = vec![vec![]; 2 * num_limbs]; - for i in 0..num_limbs { - for j in 0..num_limbs { + let mut to_add = vec![vec![]; total_limbs]; + for i in 0..a.limbs.len() { + for j in 0..b.limbs.len() { let (product, carry) = self.mul_u32(a.limbs[i], b.limbs[j]); to_add[i + j].push(product); to_add[i + j + 1].push(carry); @@ -155,7 +154,7 @@ impl, const D: usize> CircuitBuilder { let mut combined_limbs = vec![]; let mut carry = self.zero_u32(); - for i in 0..2 * num_limbs { + for i in 0..total_limbs { to_add[i].push(carry); let (new_result, new_carry) = self.add_many_u32(to_add[i].clone()); combined_limbs.push(new_result); @@ -243,7 +242,7 @@ mod tests { use num::{BigUint, FromPrimitive, Integer}; use crate::{ - field::{crandall_field::CrandallField, field_types::PrimeField}, + field::crandall_field::CrandallField, iop::witness::PartialWitness, plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig, verifier::verify}, }; diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 16a62776..c8539629 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -21,12 +21,44 @@ impl, const D: usize> CircuitBuilder { .collect() } + pub fn biguint_to_ff(&mut self, x: &BigUintTarget) -> ForeignFieldTarget { + ForeignFieldTarget { + limbs: x.limbs.clone(), + _phantom: PhantomData, + } + } + pub fn ff_to_biguint(&mut self, x: &ForeignFieldTarget) -> BigUintTarget { BigUintTarget { limbs: x.limbs.clone(), } } + pub fn constant_ff(&mut self, x: FF) -> ForeignFieldTarget { + let x_biguint = self.constant_biguint(&x.to_biguint()); + self.biguint_to_ff(&x_biguint) + } + + // Assert that two ForeignFieldTarget's, both assumed to be in reduced form, are equal. + pub fn connect_ff_reduced( + &mut self, + lhs: &ForeignFieldTarget, + rhs: &ForeignFieldTarget, + ) { + let min_limbs = lhs.limbs.len().min(rhs.limbs.len()); + + for i in 0..min_limbs { + self.connect_u32(lhs.limbs[i], rhs.limbs[i]); + } + + for i in min_limbs..lhs.limbs.len() { + self.assert_zero_u32(lhs.limbs[i]); + } + for i in min_limbs..rhs.limbs.len() { + self.assert_zero_u32(rhs.limbs[i]); + } + } + // Add two `ForeignFieldTarget`s. pub fn add_nonnative( &mut self, @@ -82,3 +114,40 @@ impl, const D: usize> CircuitBuilder { self.reduce(&x_biguint) } } + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::field::crandall_field::CrandallField; + use crate::field::field_types::Field; + use crate::field::secp256k1::Secp256K1Base; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::verifier::verify; + + #[test] + fn test_nonnative_add() -> Result<()> { + type FF = Secp256K1Base; + let x_ff = FF::rand(); + let y_ff = FF::rand(); + let sum_ff = x_ff + y_ff; + + type F = CrandallField; + let config = CircuitConfig::large_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_ff(x_ff); + let y = builder.constant_ff(y_ff); + let sum = builder.add_nonnative(&x, &y); + + let sum_expected = builder.constant_ff(sum_ff); + builder.connect_ff_reduced(&sum, &sum_expected); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } +} From 4c5f2383fea0e96ff2eba7cc13eccb282a031872 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 27 Oct 2021 12:03:17 -0700 Subject: [PATCH 058/202] fixes to tests --- src/gadgets/biguint.rs | 12 ++++++++++-- src/gadgets/nonnative.rs | 4 ++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 2c34d044..81880eef 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -59,14 +59,22 @@ impl, const D: usize> CircuitBuilder { ) -> (BigUintTarget, BigUintTarget) { if a.num_limbs() > b.num_limbs() { let mut padded_b_limbs = b.limbs.clone(); - padded_b_limbs.extend(self.add_virtual_u32_targets(a.num_limbs() - b.num_limbs())); + let to_extend = a.num_limbs() - b.num_limbs(); + for i in 0..to_extend { + padded_b_limbs.push(self.zero_u32()); + } + let padded_b = BigUintTarget { limbs: padded_b_limbs, }; (a, padded_b) } else { let mut padded_a_limbs = a.limbs.clone(); - padded_a_limbs.extend(self.add_virtual_u32_targets(b.num_limbs() - a.num_limbs())); + let to_extend = b.num_limbs() - a.num_limbs(); + for i in 0..to_extend { + padded_a_limbs.push(self.zero_u32()); + } + let padded_a = BigUintTarget { limbs: padded_a_limbs, }; diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index c8539629..6e57dc7d 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -143,8 +143,8 @@ mod tests { let y = builder.constant_ff(y_ff); let sum = builder.add_nonnative(&x, &y); - let sum_expected = builder.constant_ff(sum_ff); - builder.connect_ff_reduced(&sum, &sum_expected); + //let sum_expected = builder.constant_ff(sum_ff); + //builder.connect_ff_reduced(&sum, &sum_expected); let data = builder.build(); let proof = data.prove(pw).unwrap(); From 8f8d03951b33b9614560dc6dddb3ea171476fd0f Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 27 Oct 2021 12:03:51 -0700 Subject: [PATCH 059/202] uncomment --- src/gadgets/nonnative.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 6e57dc7d..c8539629 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -143,8 +143,8 @@ mod tests { let y = builder.constant_ff(y_ff); let sum = builder.add_nonnative(&x, &y); - //let sum_expected = builder.constant_ff(sum_ff); - //builder.connect_ff_reduced(&sum, &sum_expected); + let sum_expected = builder.constant_ff(sum_ff); + builder.connect_ff_reduced(&sum, &sum_expected); let data = builder.build(); let proof = data.prove(pw).unwrap(); From 72134a3eb0b0114667c9d5a1aac5757011e9ba1c Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 27 Oct 2021 12:09:41 -0700 Subject: [PATCH 060/202] mul test --- src/gadgets/nonnative.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index c8539629..f081d382 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -150,4 +150,28 @@ mod tests { let proof = data.prove(pw).unwrap(); verify(proof, &data.verifier_only, &data.common) } + + #[test] + fn test_nonnative_mul() -> Result<()> { + type FF = Secp256K1Base; + let x_ff = FF::rand(); + let y_ff = FF::rand(); + let product_ff = x_ff * y_ff; + + type F = CrandallField; + let config = CircuitConfig::large_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_ff(x_ff); + let y = builder.constant_ff(y_ff); + let product = builder.mul_nonnative(&x, &y); + + let product_expected = builder.constant_ff(product_ff); + builder.connect_ff_reduced(&product, &product_expected); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } } From c664eba3e6eb0d0418d648406fd22bcf007bfce2 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 27 Oct 2021 14:58:41 -0700 Subject: [PATCH 061/202] sub test --- src/gadgets/nonnative.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index f081d382..161a0e85 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -151,6 +151,30 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + #[test] + fn test_nonnative_sub() -> Result<()> { + type FF = Secp256K1Base; + let x_ff = FF::rand(); + let y_ff = FF::rand(); + let diff_ff = x_ff - y_ff; + + type F = CrandallField; + let config = CircuitConfig::large_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_ff(x_ff); + let y = builder.constant_ff(y_ff); + let diff = builder.sub_nonnative(&x, &y); + + let diff_expected = builder.constant_ff(diff_ff); + builder.connect_ff_reduced(&diff, &diff_expected); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + #[test] fn test_nonnative_mul() -> Result<()> { type FF = Secp256K1Base; From 2d9f8d97199fe02de3f0f117af08af7071360b87 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 29 Oct 2021 16:03:58 -0700 Subject: [PATCH 062/202] fix --- src/gadgets/permutation.rs | 4 +--- src/plonk/circuit_builder.rs | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs index aa06294a..c60eda7d 100644 --- a/src/gadgets/permutation.rs +++ b/src/gadgets/permutation.rs @@ -72,9 +72,7 @@ impl, const D: usize> CircuitBuilder { let chunk_size = a1.len(); - let (gate, gate_index, mut next_copy) = self.find_switch_gate(chunk_size); - - let num_copies = gate.num_copies; + let (gate, gate_index, next_copy) = self.find_switch_gate(chunk_size); let mut c = Vec::new(); let mut d = Vec::new(); diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 9a90cac2..3296061c 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -822,7 +822,7 @@ impl, const D: usize> CircuitBuilder { ]); } - let (gate, gate_index, mut next_copy) = + let (gate, gate_index, next_copy) = match self.batched_gates.current_switch_gates[chunk_size - 1].clone() { None => { let gate = SwitchGate::::new_from_config(self.config.clone(), chunk_size); @@ -834,7 +834,7 @@ impl, const D: usize> CircuitBuilder { let num_copies = gate.num_copies; - if next_copy == num_copies { + if next_copy == num_copies - 1 { self.batched_gates.current_switch_gates[chunk_size - 1] = None; } else { self.batched_gates.current_switch_gates[chunk_size - 1] = From 244543578bf70aab1cb2fdd3f81d81b810f16f4f Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 1 Nov 2021 09:44:20 -0700 Subject: [PATCH 063/202] fixes to subtraction tests, and documentation --- src/gates/subtraction_u32.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/gates/subtraction_u32.rs b/src/gates/subtraction_u32.rs index fc2009be..c11c6f0f 100644 --- a/src/gates/subtraction_u32.rs +++ b/src/gates/subtraction_u32.rs @@ -1,7 +1,5 @@ use std::marker::PhantomData; -use itertools::unfold; - use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; @@ -16,7 +14,8 @@ use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; /// Maximum number of subtractions operations performed by a single gate. pub const NUM_U32_SUBTRACTION_OPS: usize = 3; -/// A gate to perform a subtraction . +/// A gate to perform a subtraction on 32-bit limbs: given `x`, `y`, and `borrow`, it returns +/// the result `x - y - borrow` and, if this underflows, a new `borrow`. #[derive(Clone, Debug)] pub struct U32SubtractionGate, const D: usize> { _phantom: PhantomData, @@ -395,14 +394,14 @@ mod tests { } let mut rng = rand::thread_rng(); - let inputs_x: Vec<_> = (0..NUM_U32_SUBTRACTION_OPS) + let inputs_x = (0..NUM_U32_SUBTRACTION_OPS) .map(|_| rng.gen::() as u64) .collect(); - let inputs_y: Vec<_> = (0..NUM_U32_SUBTRACTION_OPS) + let inputs_y = (0..NUM_U32_SUBTRACTION_OPS) .map(|_| rng.gen::() as u64) .collect(); - let borrows: Vec<_> = (0..NUM_U32_SUBTRACTION_OPS) - .map(|_| rng.gen::() as u64) + let borrows = (0..NUM_U32_SUBTRACTION_OPS) + .map(|_| (rng.gen::() % 2) as u64) .collect(); let gate = U32SubtractionGate:: { From 5dd4ed3e1c5c63137a7994942c09cc31132d310f Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 1 Nov 2021 11:12:21 -0700 Subject: [PATCH 064/202] addressed comments --- src/gadgets/arithmetic_u32.rs | 15 +-------------- src/gadgets/biguint.rs | 4 ++-- src/gates/subtraction_u32.rs | 2 +- 3 files changed, 4 insertions(+), 17 deletions(-) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index db6d3669..ba076a8f 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -81,24 +81,11 @@ impl, const D: usize> CircuitBuilder { self.mul_add_u32(a, one, b) } - pub fn add_three_u32( - &mut self, - a: U32Target, - b: U32Target, - c: U32Target, - ) -> (U32Target, U32Target) { - let (init_low, carry1) = self.add_u32(a, b); - let (final_low, carry2) = self.add_u32(c, init_low); - let (combined_carry, _zero) = self.add_u32(carry1, carry2); - (final_low, combined_carry) - } - - pub fn add_many_u32(&mut self, to_add: Vec) -> (U32Target, U32Target) { + pub fn add_many_u32(&mut self, to_add: &[U32Target]) -> (U32Target, U32Target) { match to_add.len() { 0 => (self.zero_u32(), self.zero_u32()), 1 => (to_add[0], self.zero_u32()), 2 => self.add_u32(to_add[0], to_add[1]), - 3 => self.add_three_u32(to_add[0], to_add[1], to_add[2]), _ => { let (mut low, mut carry) = self.add_u32(to_add[0], to_add[1]); for i in 2..to_add.len() { diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 81880eef..a524a79b 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -117,7 +117,7 @@ impl, const D: usize> CircuitBuilder { self.zero_u32() }; - let (new_limb, new_carry) = self.add_three_u32(carry.clone(), a_limb, b_limb); + let (new_limb, new_carry) = self.add_many_u32(&[carry.clone(), a_limb, b_limb]); carry = new_carry; combined_limbs.push(new_limb); } @@ -164,7 +164,7 @@ impl, const D: usize> CircuitBuilder { let mut carry = self.zero_u32(); for i in 0..total_limbs { to_add[i].push(carry); - let (new_result, new_carry) = self.add_many_u32(to_add[i].clone()); + let (new_result, new_carry) = self.add_many_u32(&to_add[i].clone()); combined_limbs.push(new_result); carry = new_carry; } diff --git a/src/gates/subtraction_u32.rs b/src/gates/subtraction_u32.rs index c11c6f0f..afac85be 100644 --- a/src/gates/subtraction_u32.rs +++ b/src/gates/subtraction_u32.rs @@ -15,7 +15,7 @@ use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; pub const NUM_U32_SUBTRACTION_OPS: usize = 3; /// A gate to perform a subtraction on 32-bit limbs: given `x`, `y`, and `borrow`, it returns -/// the result `x - y - borrow` and, if this underflows, a new `borrow`. +/// the result `x - y - borrow` and, if this underflows, a new `borrow`. Inputs are not range-checked. #[derive(Clone, Debug)] pub struct U32SubtractionGate, const D: usize> { _phantom: PhantomData, From 6705d81fbdba21ef9745115865af0cb412894bbe Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 1 Nov 2021 11:18:26 -0700 Subject: [PATCH 065/202] nit --- src/gates/subtraction_u32.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gates/subtraction_u32.rs b/src/gates/subtraction_u32.rs index afac85be..b348b6f2 100644 --- a/src/gates/subtraction_u32.rs +++ b/src/gates/subtraction_u32.rs @@ -211,7 +211,7 @@ impl, const D: usize> Gate for U32Subtraction ); g }) - .collect::>() + .collect() } fn num_wires(&self) -> usize { From a3d957fa42d51c50d703691b2c44ba96b06605bf Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 1 Nov 2021 14:40:14 -0700 Subject: [PATCH 066/202] addressed comment: more tests for multiple_comparison --- src/gadgets/multiple_comparison.rs | 39 ++++++++++++------------------ src/gates/subtraction_u32.rs | 2 +- 2 files changed, 16 insertions(+), 25 deletions(-) diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs index 579708c8..4c928c11 100644 --- a/src/gadgets/multiple_comparison.rs +++ b/src/gadgets/multiple_comparison.rs @@ -74,7 +74,7 @@ mod tests { use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::verifier::verify; - fn test_list_le(size: usize) -> Result<()> { + fn test_list_le(size: usize, num_bits: usize) -> Result<()> { type F = CrandallField; let config = CircuitConfig::large_config(); let pw = PartialWitness::new(); @@ -82,26 +82,26 @@ mod tests { let mut rng = rand::thread_rng(); - let lst1: Vec = (0..size).map(|_| rng.gen()).collect(); - let lst2: Vec = (0..size) + let lst1: Vec = (0..size).map(|_| rng.gen_range(0..(1 << num_bits))).collect(); + let lst2: Vec = (0..size) .map(|i| { - let mut res = rng.gen(); + let mut res = rng.gen_range(0..(1 << num_bits)); while res <= lst1[i] { - res = rng.gen(); + res = rng.gen_range(0..(1 << num_bits)); } res }) .collect(); let a = lst1 .iter() - .map(|&x| builder.constant(F::from_canonical_u32(x))) + .map(|&x| builder.constant(F::from_canonical_u64(x))) .collect(); let b = lst2 .iter() - .map(|&x| builder.constant(F::from_canonical_u32(x))) + .map(|&x| builder.constant(F::from_canonical_u64(x))) .collect(); - let result = builder.list_le(a, b, 32); + let result = builder.list_le(a, b, num_bits); let expected_result = builder.constant_bool(true); builder.connect(result.target, expected_result.target); @@ -112,22 +112,13 @@ mod tests { } #[test] - fn test_multiple_comparison_trivial() -> Result<()> { - test_list_le(1) - } + fn test_multiple_comparison() -> Result<()> { + for size in [1, 3, 6, 10] { + for num_bits in [20, 32, 40, 50] { + test_list_le(size, num_bits).unwrap(); + } + } - #[test] - fn test_multiple_comparison_small() -> Result<()> { - test_list_le(3) - } - - #[test] - fn test_multiple_comparison_medium() -> Result<()> { - test_list_le(6) - } - - #[test] - fn test_multiple_comparison_large() -> Result<()> { - test_list_le(10) + Ok(()) } } diff --git a/src/gates/subtraction_u32.rs b/src/gates/subtraction_u32.rs index b348b6f2..14447727 100644 --- a/src/gates/subtraction_u32.rs +++ b/src/gates/subtraction_u32.rs @@ -50,10 +50,10 @@ impl, const D: usize> U32SubtractionGate { 5 * i + 4 } - // We have limbs ony for the first half of the output. pub fn limb_bits() -> usize { 2 } + // We have limbs for the 32 bits of `output_result`. pub fn num_limbs() -> usize { 32 / Self::limb_bits() } From bd0164c7efd9d0f55dc70fb996e15d66b624a031 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 1 Nov 2021 14:40:21 -0700 Subject: [PATCH 067/202] fmt --- src/gadgets/multiple_comparison.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs index 4c928c11..e7b9ae8e 100644 --- a/src/gadgets/multiple_comparison.rs +++ b/src/gadgets/multiple_comparison.rs @@ -82,7 +82,9 @@ mod tests { let mut rng = rand::thread_rng(); - let lst1: Vec = (0..size).map(|_| rng.gen_range(0..(1 << num_bits))).collect(); + let lst1: Vec = (0..size) + .map(|_| rng.gen_range(0..(1 << num_bits))) + .collect(); let lst2: Vec = (0..size) .map(|i| { let mut res = rng.gen_range(0..(1 << num_bits)); From 237a1fad1d02fd65d0e4a079afbd136b90616fa0 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 1 Nov 2021 15:22:21 -0700 Subject: [PATCH 068/202] addressed comments --- src/gadgets/nonnative.rs | 5 ++++- src/iop/generator.rs | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 161a0e85..086e2c83 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -155,7 +155,10 @@ mod tests { fn test_nonnative_sub() -> Result<()> { type FF = Secp256K1Base; let x_ff = FF::rand(); - let y_ff = FF::rand(); + let mut y_ff = FF::rand(); + while y_ff.to_biguint() > x_ff.to_biguint() { + y_ff = FF::rand(); + } let diff_ff = x_ff - y_ff; type F = CrandallField; diff --git a/src/iop/generator.rs b/src/iop/generator.rs index c5c67bcb..8c6cb294 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -160,6 +160,8 @@ impl GeneratedValues { pub fn set_biguint_target(&mut self, target: BigUintTarget, value: BigUint) { let mut limbs = value.to_u32_digits(); + assert!(target.num_limbs() >= limbs.len()); + limbs.resize(target.num_limbs(), 0); for i in 0..target.num_limbs() { self.set_u32_target(target.get_limb(i), limbs[i]); From 6ab01e51f3da79d8bcc524905674123d5f5c91a5 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 1 Nov 2021 16:02:46 -0700 Subject: [PATCH 069/202] u32 arithmetic check for special cases --- src/gadgets/arithmetic_u32.rs | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index ba076a8f..22957075 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -36,6 +36,37 @@ impl, const D: usize> CircuitBuilder { self.assert_zero(x.0) } + /// Checks for special cases where the value of + /// `x * y + z` + /// can be determined without adding a `U32ArithmeticGate`. + pub fn arithmetic_u32_special_cases( + &mut self, + x: U32Target, + y: U32Target, + z: U32Target, + ) -> Option<(U32Target, U32Target)> { + let x_const = self.target_as_constant(x.0); + let y_const = self.target_as_constant(y.0); + let z_const = self.target_as_constant(z.0); + + // If both terms are constant, return their (constant) sum. + let first_term_const = if let (Some(xx), Some(yy)) = (x_const, y_const) { + Some(xx * yy) + } else { + None + }; + + if let (Some(a), Some(b)) = (first_term_const, z_const) { + let sum_u64 = (a + b).to_canonical_u64(); + let (low_u64, high_u64) = (sum_u64 % (1u64 << 32), sum_u64 >> 32); + let low = F::from_canonical_u64(low_u64); + let high = F::from_canonical_u64(high_u64); + return Some((self.constant_u32(low), self.constant_u32(high))); + } + + None + } + // Returns x * y + z. pub fn mul_add_u32( &mut self, @@ -43,6 +74,10 @@ impl, const D: usize> CircuitBuilder { y: U32Target, z: U32Target, ) -> (U32Target, U32Target) { + if let Some(result) = self.arithmetic_u32_special_cases(x, y, z) { + return result; + } + let (gate_index, copy) = self.find_u32_arithmetic_gate(); self.connect( From 1d4bb3950da7a81d1211e16c9b35ef4888414976 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 1 Nov 2021 16:12:43 -0700 Subject: [PATCH 070/202] FFTarget uses BigUintTarget --- src/gadgets/nonnative.rs | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 086e2c83..86d886fc 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -7,7 +7,7 @@ use crate::gadgets::biguint::BigUintTarget; use crate::plonk::circuit_builder::CircuitBuilder; pub struct ForeignFieldTarget { - limbs: Vec, + value: BigUintTarget, _phantom: PhantomData, } @@ -23,15 +23,13 @@ impl, const D: usize> CircuitBuilder { pub fn biguint_to_ff(&mut self, x: &BigUintTarget) -> ForeignFieldTarget { ForeignFieldTarget { - limbs: x.limbs.clone(), + value: x.clone(), _phantom: PhantomData, } } pub fn ff_to_biguint(&mut self, x: &ForeignFieldTarget) -> BigUintTarget { - BigUintTarget { - limbs: x.limbs.clone(), - } + x.value.clone() } pub fn constant_ff(&mut self, x: FF) -> ForeignFieldTarget { @@ -45,18 +43,7 @@ impl, const D: usize> CircuitBuilder { lhs: &ForeignFieldTarget, rhs: &ForeignFieldTarget, ) { - let min_limbs = lhs.limbs.len().min(rhs.limbs.len()); - - for i in 0..min_limbs { - self.connect_u32(lhs.limbs[i], rhs.limbs[i]); - } - - for i in min_limbs..lhs.limbs.len() { - self.assert_zero_u32(lhs.limbs[i]); - } - for i in min_limbs..rhs.limbs.len() { - self.assert_zero_u32(rhs.limbs[i]); - } + self.connect_biguint(&lhs.value, &rhs.value); } // Add two `ForeignFieldTarget`s. @@ -104,7 +91,7 @@ impl, const D: usize> CircuitBuilder { let value = self.rem_biguint(x, &order_target); ForeignFieldTarget { - limbs: value.limbs, + value, _phantom: PhantomData, } } From e838096940492dbeeedc1bf172bae2bf2cf1c9b0 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 2 Nov 2021 10:54:34 -0700 Subject: [PATCH 071/202] use map; and TODOs --- src/field/field_types.rs | 1 + src/gadgets/biguint.rs | 10 ++++------ src/gates/assert_le.rs | 2 ++ 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/field/field_types.rs b/src/field/field_types.rs index 481d87ba..f5d06fdb 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -206,6 +206,7 @@ pub trait Field: subgroup.into_iter().map(|x| x * shift).collect() } + // TODO: move these to a new `PrimeField` trait (for all prime fields, not just 64-bit ones) fn from_biguint(n: BigUint) -> Self; fn to_biguint(&self) -> BigUint; diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index a524a79b..8369bfc4 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -28,12 +28,10 @@ impl BigUintTarget { impl, const D: usize> CircuitBuilder { pub fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget { let limb_values = value.to_u32_digits(); - let mut limbs = Vec::new(); - for i in 0..limb_values.len() { - limbs.push(U32Target( - self.constant(F::from_canonical_u32(limb_values[i])), - )); - } + let limbs = limb_values + .iter() + .map(|l| self.constant(F::from_canonical_u32(l))) + .collect(); BigUintTarget { limbs } } diff --git a/src/gates/assert_le.rs b/src/gates/assert_le.rs index 46432c03..98411ef2 100644 --- a/src/gates/assert_le.rs +++ b/src/gates/assert_le.rs @@ -13,6 +13,8 @@ use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recu use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use crate::util::{bits_u64, ceil_div_usize}; +// TODO: replace/merge this gate with `ComparisonGate`. + /// A gate for checking that one value is less than or equal to another. #[derive(Clone, Debug)] pub struct AssertLessThanGate, const D: usize> { From c861c10a5be55d1a814bbb598fa99710e5ad0e63 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 2 Nov 2021 12:42:42 -0700 Subject: [PATCH 072/202] nonnative neg --- src/gadgets/biguint.rs | 2 +- src/gadgets/nonnative.rs | 45 +++++++++++++++++++++++++++++++--------- 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 8369bfc4..a4b9a2d0 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -30,7 +30,7 @@ impl, const D: usize> CircuitBuilder { let limb_values = value.to_u32_digits(); let limbs = limb_values .iter() - .map(|l| self.constant(F::from_canonical_u32(l))) + .map(|&l| U32Target(self.constant(F::from_canonical_u32(l)))) .collect(); BigUintTarget { limbs } diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 86d886fc..0b02b6a8 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -1,8 +1,9 @@ use std::marker::PhantomData; +use num::{BigUint, One}; + use crate::field::field_types::RichField; use crate::field::{extension_field::Extendable, field_types::Field}; -use crate::gadgets::arithmetic_u32::U32Target; use crate::gadgets::biguint::BigUintTarget; use crate::plonk::circuit_builder::CircuitBuilder; @@ -12,15 +13,6 @@ pub struct ForeignFieldTarget { } impl, const D: usize> CircuitBuilder { - pub fn order_u32_limbs(&mut self) -> Vec { - let modulus = FF::order(); - let limbs = modulus.to_u32_digits(); - limbs - .iter() - .map(|&limb| self.constant_u32(F::from_canonical_u32(limb))) - .collect() - } - pub fn biguint_to_ff(&mut self, x: &BigUintTarget) -> ForeignFieldTarget { ForeignFieldTarget { value: x.clone(), @@ -84,6 +76,17 @@ impl, const D: usize> CircuitBuilder { self.reduce(&result) } + pub fn neg_nonnative( + &mut self, + x: &ForeignFieldTarget, + ) -> ForeignFieldTarget { + let neg_one = FF::order() - BigUint::one(); + let neg_one_target = self.constant_biguint(&neg_one); + let neg_one_ff = self.biguint_to_ff(&neg_one_target); + + self.mul_nonnative(&neg_one_ff, x) + } + /// Returns `x % |FF|` as a `ForeignFieldTarget`. fn reduce(&mut self, x: &BigUintTarget) -> ForeignFieldTarget { let modulus = FF::order(); @@ -188,4 +191,26 @@ mod tests { let proof = data.prove(pw).unwrap(); verify(proof, &data.verifier_only, &data.common) } + + #[test] + fn test_nonnative_neg() -> Result<()> { + type FF = Secp256K1Base; + let x_ff = FF::rand(); + let neg_x_ff = -x_ff; + + type F = CrandallField; + let config = CircuitConfig::large_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_ff(x_ff); + let neg_x = builder.neg_nonnative(&x); + + let neg_x_expected = builder.constant_ff(neg_x_ff); + builder.connect_ff_reduced(&neg_x, &neg_x_expected); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } } From cf3b6df0e4683b780eb3f96c3237f2926246c293 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 9 Nov 2021 16:36:29 -0800 Subject: [PATCH 073/202] addressed nits --- src/gadgets/arithmetic_u32.rs | 6 ++---- src/plonk/circuit_builder.rs | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index 22957075..ce7aa121 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -57,10 +57,8 @@ impl, const D: usize> CircuitBuilder { }; if let (Some(a), Some(b)) = (first_term_const, z_const) { - let sum_u64 = (a + b).to_canonical_u64(); - let (low_u64, high_u64) = (sum_u64 % (1u64 << 32), sum_u64 >> 32); - let low = F::from_canonical_u64(low_u64); - let high = F::from_canonical_u64(high_u64); + let sum = (a + b).to_canonical_u64(); + let (low, high) = (sum as u32, (sum >> 32) as u32); return Some((self.constant_u32(low), self.constant_u32(high))); } diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 3296061c..33709e32 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -318,8 +318,8 @@ impl, const D: usize> CircuitBuilder { } /// Returns a U32Target for the value `c`, which is assumed to be at most 32 bits. - pub fn constant_u32(&mut self, c: F) -> U32Target { - U32Target(self.constant(c)) + pub fn constant_u32(&mut self, c: u32) -> U32Target { + U32Target(self.constant(F::from_canonical_u32(c))) } /// If the given target is a constant (i.e. it was created by the `constant(F)` method), returns From 656f052b796b0a61b5974bbaa8ce02123d726f7b Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 9 Nov 2021 17:18:04 -0800 Subject: [PATCH 074/202] addressed nits --- src/gadgets/biguint.rs | 42 ++++++++++++++---------------------------- 1 file changed, 14 insertions(+), 28 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index a4b9a2d0..66448e7c 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -28,10 +28,7 @@ impl BigUintTarget { impl, const D: usize> CircuitBuilder { pub fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget { let limb_values = value.to_u32_digits(); - let limbs = limb_values - .iter() - .map(|&l| U32Target(self.constant(F::from_canonical_u32(l)))) - .collect(); + let limbs = limb_values.iter().map(|&l| self.constant_u32(l)).collect(); BigUintTarget { limbs } } @@ -56,26 +53,19 @@ impl, const D: usize> CircuitBuilder { b: BigUintTarget, ) -> (BigUintTarget, BigUintTarget) { if a.num_limbs() > b.num_limbs() { - let mut padded_b_limbs = b.limbs.clone(); - let to_extend = a.num_limbs() - b.num_limbs(); - for i in 0..to_extend { - padded_b_limbs.push(self.zero_u32()); + let mut padded_b = b.clone(); + for _ in b.num_limbs()..a.num_limbs() { + padded_b.limbs.push(self.zero_u32()); } - let padded_b = BigUintTarget { - limbs: padded_b_limbs, - }; (a, padded_b) } else { - let mut padded_a_limbs = a.limbs.clone(); + let mut padded_a = a.clone(); let to_extend = b.num_limbs() - a.num_limbs(); - for i in 0..to_extend { - padded_a_limbs.push(self.zero_u32()); + for _ in a.num_limbs()..b.num_limbs() { + padded_a.limbs.push(self.zero_u32()); } - let padded_a = BigUintTarget { - limbs: padded_a_limbs, - }; (padded_a, b) } } @@ -104,18 +94,14 @@ impl, const D: usize> CircuitBuilder { let mut combined_limbs = vec![]; let mut carry = self.zero_u32(); for i in 0..num_limbs { - let a_limb = if i < a.num_limbs() { - a.limbs[i].clone() - } else { - self.zero_u32() - }; - let b_limb = if i < b.num_limbs() { - b.limbs[i].clone() - } else { - self.zero_u32() - }; + let a_limb = (i < a.num_limbs()) + .then(|| a.limbs[i]) + .unwrap_or_else(|| self.zero_u32()); + let b_limb = (i < b.num_limbs()) + .then(|| b.limbs[i]) + .unwrap_or_else(|| self.zero_u32()); - let (new_limb, new_carry) = self.add_many_u32(&[carry.clone(), a_limb, b_limb]); + let (new_limb, new_carry) = self.add_many_u32(&[carry, a_limb, b_limb]); carry = new_carry; combined_limbs.push(new_limb); } From db31b9f6621dc95ae13b4dda771ef7265644dca4 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 9 Nov 2021 17:21:16 -0800 Subject: [PATCH 075/202] sub_nonnative fix --- src/gadgets/nonnative.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 0b02b6a8..31f06b81 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -51,16 +51,19 @@ impl, const D: usize> CircuitBuilder { self.reduce(&result) } - // Subtract two `ForeignFieldTarget`s. We assume that the first is larger than the second. + // Subtract two `ForeignFieldTarget`s. pub fn sub_nonnative( &mut self, a: &ForeignFieldTarget, b: &ForeignFieldTarget, ) -> ForeignFieldTarget { - let a_biguint = self.ff_to_biguint(a); - let b_biguint = self.ff_to_biguint(b); - let result = self.sub_biguint(&a_biguint, &b_biguint); + let order = self.constant_biguint(&FF::order()); + let a_biguint = self.nonnative_to_biguint(a); + let a_plus_order = self.add_biguint(&order, &a_biguint); + let b_biguint = self.nonnative_to_biguint(b); + let result = self.sub_biguint(&a_plus_order, &b_biguint); + // TODO: reduce sub result with only one conditional addition? self.reduce(&result) } From 616479689818ea52b346dda00a959a4caaf059c9 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 9 Nov 2021 17:25:28 -0800 Subject: [PATCH 076/202] rename --- src/gadgets/biguint.rs | 16 ++++++------ src/gadgets/nonnative.rs | 54 ++++++++++++++++++++-------------------- 2 files changed, 35 insertions(+), 35 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 66448e7c..f5283718 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -49,8 +49,8 @@ impl, const D: usize> CircuitBuilder { pub fn pad_biguints( &mut self, - a: BigUintTarget, - b: BigUintTarget, + a: &BigUintTarget, + b: &BigUintTarget, ) -> (BigUintTarget, BigUintTarget) { if a.num_limbs() > b.num_limbs() { let mut padded_b = b.clone(); @@ -58,7 +58,7 @@ impl, const D: usize> CircuitBuilder { padded_b.limbs.push(self.zero_u32()); } - (a, padded_b) + (a.clone(), padded_b) } else { let mut padded_a = a.clone(); let to_extend = b.num_limbs() - a.num_limbs(); @@ -66,15 +66,15 @@ impl, const D: usize> CircuitBuilder { padded_a.limbs.push(self.zero_u32()); } - (padded_a, b) + (padded_a, b.clone()) } } pub fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget { - let (padded_a, padded_b) = self.pad_biguints(a.clone(), b.clone()); + let (a, b) = self.pad_biguints(a, b); - let a_vec = padded_a.limbs.iter().map(|&x| x.0).collect(); - let b_vec = padded_b.limbs.iter().map(|&x| x.0).collect(); + let a_vec = a.limbs.iter().map(|&x| x.0).collect(); + let b_vec = b.limbs.iter().map(|&x| x.0).collect(); self.list_le(a_vec, b_vec, 32) } @@ -115,7 +115,7 @@ impl, const D: usize> CircuitBuilder { // Subtract two `BigUintTarget`s. We assume that the first is larger than the second. pub fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let num_limbs = a.limbs.len(); - debug_assert!(b.limbs.len() == num_limbs); + let (a, b) = self.pad_biguints(a, b); let mut result_limbs = vec![]; diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 31f06b81..9e7a9585 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -13,24 +13,24 @@ pub struct ForeignFieldTarget { } impl, const D: usize> CircuitBuilder { - pub fn biguint_to_ff(&mut self, x: &BigUintTarget) -> ForeignFieldTarget { + pub fn biguint_to_nonnative(&mut self, x: &BigUintTarget) -> ForeignFieldTarget { ForeignFieldTarget { value: x.clone(), _phantom: PhantomData, } } - pub fn ff_to_biguint(&mut self, x: &ForeignFieldTarget) -> BigUintTarget { + pub fn nonnative_to_biguint(&mut self, x: &ForeignFieldTarget) -> BigUintTarget { x.value.clone() } - pub fn constant_ff(&mut self, x: FF) -> ForeignFieldTarget { + pub fn constant_nonnative(&mut self, x: FF) -> ForeignFieldTarget { let x_biguint = self.constant_biguint(&x.to_biguint()); - self.biguint_to_ff(&x_biguint) + self.biguint_to_nonnative(&x_biguint) } // Assert that two ForeignFieldTarget's, both assumed to be in reduced form, are equal. - pub fn connect_ff_reduced( + pub fn connect_nonnative( &mut self, lhs: &ForeignFieldTarget, rhs: &ForeignFieldTarget, @@ -44,8 +44,8 @@ impl, const D: usize> CircuitBuilder { a: &ForeignFieldTarget, b: &ForeignFieldTarget, ) -> ForeignFieldTarget { - let a_biguint = self.ff_to_biguint(a); - let b_biguint = self.ff_to_biguint(b); + let a_biguint = self.nonnative_to_biguint(a); + let b_biguint = self.nonnative_to_biguint(b); let result = self.add_biguint(&a_biguint, &b_biguint); self.reduce(&result) @@ -72,8 +72,8 @@ impl, const D: usize> CircuitBuilder { a: &ForeignFieldTarget, b: &ForeignFieldTarget, ) -> ForeignFieldTarget { - let a_biguint = self.ff_to_biguint(a); - let b_biguint = self.ff_to_biguint(b); + let a_biguint = self.nonnative_to_biguint(a); + let b_biguint = self.nonnative_to_biguint(b); let result = self.mul_biguint(&a_biguint, &b_biguint); self.reduce(&result) @@ -85,7 +85,7 @@ impl, const D: usize> CircuitBuilder { ) -> ForeignFieldTarget { let neg_one = FF::order() - BigUint::one(); let neg_one_target = self.constant_biguint(&neg_one); - let neg_one_ff = self.biguint_to_ff(&neg_one_target); + let neg_one_ff = self.biguint_to_nonnative(&neg_one_target); self.mul_nonnative(&neg_one_ff, x) } @@ -102,8 +102,8 @@ impl, const D: usize> CircuitBuilder { } } - fn reduce_ff(&mut self, x: &ForeignFieldTarget) -> ForeignFieldTarget { - let x_biguint = self.ff_to_biguint(x); + fn reduce_nonnative(&mut self, x: &ForeignFieldTarget) -> ForeignFieldTarget { + let x_biguint = self.nonnative_to_biguint(x); self.reduce(&x_biguint) } } @@ -132,12 +132,12 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let x = builder.constant_ff(x_ff); - let y = builder.constant_ff(y_ff); + let x = builder.constant_nonnative(x_ff); + let y = builder.constant_nonnative(y_ff); let sum = builder.add_nonnative(&x, &y); - let sum_expected = builder.constant_ff(sum_ff); - builder.connect_ff_reduced(&sum, &sum_expected); + let sum_expected = builder.constant_nonnative(sum_ff); + builder.connect_nonnative(&sum, &sum_expected); let data = builder.build(); let proof = data.prove(pw).unwrap(); @@ -159,12 +159,12 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let x = builder.constant_ff(x_ff); - let y = builder.constant_ff(y_ff); + let x = builder.constant_nonnative(x_ff); + let y = builder.constant_nonnative(y_ff); let diff = builder.sub_nonnative(&x, &y); - let diff_expected = builder.constant_ff(diff_ff); - builder.connect_ff_reduced(&diff, &diff_expected); + let diff_expected = builder.constant_nonnative(diff_ff); + builder.connect_nonnative(&diff, &diff_expected); let data = builder.build(); let proof = data.prove(pw).unwrap(); @@ -183,12 +183,12 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let x = builder.constant_ff(x_ff); - let y = builder.constant_ff(y_ff); + let x = builder.constant_nonnative(x_ff); + let y = builder.constant_nonnative(y_ff); let product = builder.mul_nonnative(&x, &y); - let product_expected = builder.constant_ff(product_ff); - builder.connect_ff_reduced(&product, &product_expected); + let product_expected = builder.constant_nonnative(product_ff); + builder.connect_nonnative(&product, &product_expected); let data = builder.build(); let proof = data.prove(pw).unwrap(); @@ -206,11 +206,11 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let x = builder.constant_ff(x_ff); + let x = builder.constant_nonnative(x_ff); let neg_x = builder.neg_nonnative(&x); - let neg_x_expected = builder.constant_ff(neg_x_ff); - builder.connect_ff_reduced(&neg_x, &neg_x_expected); + let neg_x_expected = builder.constant_nonnative(neg_x_ff); + builder.connect_nonnative(&neg_x, &neg_x_expected); let data = builder.build(); let proof = data.prove(pw).unwrap(); From 3f619c704cab70737cab77d52d0f344afc4621fa Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 9 Nov 2021 17:49:34 -0800 Subject: [PATCH 077/202] made test_list_le random --- src/gadgets/biguint.rs | 7 ++----- src/gadgets/multiple_comparison.rs | 22 ++++++++++++---------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index f5283718..bf2be941 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -73,10 +73,7 @@ impl, const D: usize> CircuitBuilder { pub fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget { let (a, b) = self.pad_biguints(a, b); - let a_vec = a.limbs.iter().map(|&x| x.0).collect(); - let b_vec = b.limbs.iter().map(|&x| x.0).collect(); - - self.list_le(a_vec, b_vec, 32) + self.list_le_u32(a.limbs, b.limbs) } pub fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget { @@ -213,8 +210,8 @@ impl, const D: usize> SimpleGenerator self.a .limbs .iter() + .chain(&self.b.limbs) .map(|&l| l.0) - .chain(self.b.limbs.iter().map(|&l| l.0)) .collect() } diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs index e7b9ae8e..6871c1b8 100644 --- a/src/gadgets/multiple_comparison.rs +++ b/src/gadgets/multiple_comparison.rs @@ -7,7 +7,8 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::ceil_div_usize; impl, const D: usize> CircuitBuilder { - /// Returns true if a is less than or equal to b, considered as limbs of a large value. + /// Returns true if a is less than or equal to b, considered as base-`2^num_bits` limbs of a large value. + /// This range-checks its inputs. pub fn list_le(&mut self, a: Vec, b: Vec, num_bits: usize) -> BoolTarget { assert_eq!( a.len(), @@ -20,7 +21,7 @@ impl, const D: usize> CircuitBuilder { let num_chunks = ceil_div_usize(num_bits, chunk_bits); let one = self.one(); - let mut result = self.one(); + let mut result = one; for i in 0..n { let a_le_b_gate = ComparisonGate::new(num_bits, num_chunks); let a_le_b_gate_index = self.add_gate(a_le_b_gate.clone(), vec![]); @@ -51,6 +52,8 @@ impl, const D: usize> CircuitBuilder { result = self.mul_add(these_limbs_equal, result, these_limbs_less_than); } + // `result` being boolean is an invariant, maintained because its new value is always + // `x * result + y`, where `x` and `y` are booleans that are not simultaneously true. BoolTarget::new_unsafe(result) } @@ -65,6 +68,7 @@ impl, const D: usize> CircuitBuilder { #[cfg(test)] mod tests { use anyhow::Result; + use num::BigUint; use rand::Rng; use crate::field::crandall_field::CrandallField; @@ -86,14 +90,12 @@ mod tests { .map(|_| rng.gen_range(0..(1 << num_bits))) .collect(); let lst2: Vec = (0..size) - .map(|i| { - let mut res = rng.gen_range(0..(1 << num_bits)); - while res <= lst1[i] { - res = rng.gen_range(0..(1 << num_bits)); - } - res - }) + .map(|_| rng.gen_range(0..(1 << num_bits))) .collect(); + + let a_biguint = BigUint::from_slice(&lst1.iter().flat_map(|&x| [x as u32, (x >> 32) as u32]).collect::>()); + let b_biguint = BigUint::from_slice(&lst2.iter().flat_map(|&x| [x as u32, (x >> 32) as u32]).collect::>()); + let a = lst1 .iter() .map(|&x| builder.constant(F::from_canonical_u64(x))) @@ -105,7 +107,7 @@ mod tests { let result = builder.list_le(a, b, num_bits); - let expected_result = builder.constant_bool(true); + let expected_result = builder.constant_bool(a_biguint <= b_biguint); builder.connect(result.target, expected_result.target); let data = builder.build(); From 7336aa091734a21394746a406d57fc83c3fdc5d6 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 9 Nov 2021 17:49:41 -0800 Subject: [PATCH 078/202] fmt --- src/gadgets/multiple_comparison.rs | 14 ++++++++++++-- src/gadgets/nonnative.rs | 5 ++++- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs index 6871c1b8..553fb49e 100644 --- a/src/gadgets/multiple_comparison.rs +++ b/src/gadgets/multiple_comparison.rs @@ -93,8 +93,18 @@ mod tests { .map(|_| rng.gen_range(0..(1 << num_bits))) .collect(); - let a_biguint = BigUint::from_slice(&lst1.iter().flat_map(|&x| [x as u32, (x >> 32) as u32]).collect::>()); - let b_biguint = BigUint::from_slice(&lst2.iter().flat_map(|&x| [x as u32, (x >> 32) as u32]).collect::>()); + let a_biguint = BigUint::from_slice( + &lst1 + .iter() + .flat_map(|&x| [x as u32, (x >> 32) as u32]) + .collect::>(), + ); + let b_biguint = BigUint::from_slice( + &lst2 + .iter() + .flat_map(|&x| [x as u32, (x >> 32) as u32]) + .collect::>(), + ); let a = lst1 .iter() diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 9e7a9585..1d14f708 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -102,7 +102,10 @@ impl, const D: usize> CircuitBuilder { } } - fn reduce_nonnative(&mut self, x: &ForeignFieldTarget) -> ForeignFieldTarget { + fn reduce_nonnative( + &mut self, + x: &ForeignFieldTarget, + ) -> ForeignFieldTarget { let x_biguint = self.nonnative_to_biguint(x); self.reduce(&x_biguint) } From 270521a17d762fd89e2eda9061c389fdc0be2006 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 9 Nov 2021 18:08:19 -0800 Subject: [PATCH 079/202] addressed comments --- src/gadgets/biguint.rs | 39 ++++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index bf2be941..ad61c8a0 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -229,6 +229,7 @@ impl, const D: usize> SimpleGenerator mod tests { use anyhow::Result; use num::{BigUint, FromPrimitive, Integer}; + use rand::Rng; use crate::{ field::crandall_field::CrandallField, @@ -238,8 +239,10 @@ mod tests { #[test] fn test_biguint_add() -> Result<()> { - let x_value = BigUint::from_u128(22222222222222222222222222222222222222).unwrap(); - let y_value = BigUint::from_u128(33333333333333333333333333333333333333).unwrap(); + let mut rng = rand::thread_rng(); + + let x_value = BigUint::from_u128(rng.gen()).unwrap(); + let y_value = BigUint::from_u128(rng.gen()).unwrap(); let expected_z_value = &x_value + &y_value; type F = CrandallField; @@ -261,8 +264,13 @@ mod tests { #[test] fn test_biguint_sub() -> Result<()> { - let x_value = BigUint::from_u128(33333333333333333333333333333333333333).unwrap(); - let y_value = BigUint::from_u128(22222222222222222222222222222222222222).unwrap(); + let mut rng = rand::thread_rng(); + + let x_value = BigUint::from_u128(rng.gen()).unwrap(); + let mut y_value = BigUint::from_u128(rng.gen()).unwrap(); + while y_value > x_value { + y_value = BigUint::from_u128(rng.gen()).unwrap(); + } let expected_z_value = &x_value - &y_value; type F = CrandallField; @@ -284,8 +292,10 @@ mod tests { #[test] fn test_biguint_mul() -> Result<()> { - let x_value = BigUint::from_u128(123123123123123123123123123123123123).unwrap(); - let y_value = BigUint::from_u128(456456456456456456456456456456456456).unwrap(); + let mut rng = rand::thread_rng(); + + let x_value = BigUint::from_u128(rng.gen()).unwrap(); + let y_value = BigUint::from_u128(rng.gen()).unwrap(); let expected_z_value = &x_value * &y_value; type F = CrandallField; @@ -307,8 +317,10 @@ mod tests { #[test] fn test_biguint_cmp() -> Result<()> { - let x_value = BigUint::from_u128(33333333333333333333333333333333333333).unwrap(); - let y_value = BigUint::from_u128(22222222222222222222222222222222222222).unwrap(); + let mut rng = rand::thread_rng(); + + let x_value = BigUint::from_u128(rng.gen()).unwrap(); + let y_value = BigUint::from_u128(rng.gen()).unwrap(); type F = CrandallField; let config = CircuitConfig::large_config(); @@ -318,7 +330,7 @@ mod tests { let x = builder.constant_biguint(&x_value); let y = builder.constant_biguint(&y_value); let cmp = builder.cmp_biguint(&x, &y); - let expected_cmp = builder.constant_bool(false); + let expected_cmp = builder.constant_bool(x_value <= y_value); builder.connect(cmp.target, expected_cmp.target); @@ -329,8 +341,13 @@ mod tests { #[test] fn test_biguint_div_rem() -> Result<()> { - let x_value = BigUint::from_u128(456456456456456456456456456456456456).unwrap(); - let y_value = BigUint::from_u128(123123123123123123123123123123123123).unwrap(); + let mut rng = rand::thread_rng(); + + let x_value = BigUint::from_u128(rng.gen()).unwrap(); + let mut y_value = BigUint::from_u128(rng.gen()).unwrap(); + while y_value > x_value { + y_value = BigUint::from_u128(rng.gen()).unwrap(); + } let (expected_div_value, expected_rem_value) = x_value.div_rem(&y_value); type F = CrandallField; From ea4f950d6eeb386b9b5768a9666f3681c13e614f Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 10:54:35 -0800 Subject: [PATCH 080/202] fixes and fmt --- benches/hashing.rs | 2 +- src/field/packed_avx2/mod.rs | 14 ++++----- src/gadgets/hash.rs | 6 ++-- src/gadgets/random_access.rs | 31 -------------------- src/gates/poseidon.rs | 10 +++---- src/gates/poseidon_mds.rs | 10 +++---- src/hash/poseidon.rs | 6 ++-- src/plonk/circuit_builder.rs | 57 ++++++++++++++++++++++++++++++------ 8 files changed, 72 insertions(+), 64 deletions(-) diff --git a/benches/hashing.rs b/benches/hashing.rs index c229972e..5669e50b 100644 --- a/benches/hashing.rs +++ b/benches/hashing.rs @@ -19,7 +19,7 @@ pub(crate) fn bench_gmimc, const WIDTH: usize>(c: &mut Criterion pub(crate) fn bench_poseidon, const WIDTH: usize>(c: &mut Criterion) where - [(); WIDTH - 1]:, + [(); WIDTH - 1]: , { c.bench_function(&format!("poseidon<{}, {}>", type_name::(), WIDTH), |b| { b.iter_batched( diff --git a/src/field/packed_avx2/mod.rs b/src/field/packed_avx2/mod.rs index 20eecba7..eddbb5c9 100644 --- a/src/field/packed_avx2/mod.rs +++ b/src/field/packed_avx2/mod.rs @@ -34,7 +34,7 @@ mod tests { fn test_add() where - [(); PackedPrimeField::::WIDTH]:, + [(); PackedPrimeField::::WIDTH]: , { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); @@ -52,7 +52,7 @@ mod tests { fn test_mul() where - [(); PackedPrimeField::::WIDTH]:, + [(); PackedPrimeField::::WIDTH]: , { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); @@ -70,7 +70,7 @@ mod tests { fn test_square() where - [(); PackedPrimeField::::WIDTH]:, + [(); PackedPrimeField::::WIDTH]: , { let a_arr = test_vals_a::(); @@ -86,7 +86,7 @@ mod tests { fn test_neg() where - [(); PackedPrimeField::::WIDTH]:, + [(); PackedPrimeField::::WIDTH]: , { let a_arr = test_vals_a::(); @@ -102,7 +102,7 @@ mod tests { fn test_sub() where - [(); PackedPrimeField::::WIDTH]:, + [(); PackedPrimeField::::WIDTH]: , { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); @@ -120,7 +120,7 @@ mod tests { fn test_interleave_is_involution() where - [(); PackedPrimeField::::WIDTH]:, + [(); PackedPrimeField::::WIDTH]: , { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); @@ -144,7 +144,7 @@ mod tests { fn test_interleave() where - [(); PackedPrimeField::::WIDTH]:, + [(); PackedPrimeField::::WIDTH]: , { let in_a: [F; 4] = [ F::from_noncanonical_u64(00), diff --git a/src/gadgets/hash.rs b/src/gadgets/hash.rs index db4cb1e8..99da9e1e 100644 --- a/src/gadgets/hash.rs +++ b/src/gadgets/hash.rs @@ -15,7 +15,7 @@ impl, const D: usize> CircuitBuilder { pub fn permute(&mut self, inputs: [Target; W]) -> [Target; W] where F: GMiMC + Poseidon, - [(); W - 1]:, + [(); W - 1]: , { // We don't want to swap any inputs, so set that wire to 0. let _false = self._false(); @@ -31,7 +31,7 @@ impl, const D: usize> CircuitBuilder { ) -> [Target; W] where F: GMiMC + Poseidon, - [(); W - 1]:, + [(); W - 1]: , { match HASH_FAMILY { HashFamily::GMiMC => self.gmimc_permute_swapped(inputs, swap), @@ -88,7 +88,7 @@ impl, const D: usize> CircuitBuilder { ) -> [Target; W] where F: Poseidon, - [(); W - 1]:, + [(); W - 1]: , { let gate_type = PoseidonGate::::new(); let gate = self.add_gate(gate_type, vec![]); diff --git a/src/gadgets/random_access.rs b/src/gadgets/random_access.rs index 398c516f..58c827c1 100644 --- a/src/gadgets/random_access.rs +++ b/src/gadgets/random_access.rs @@ -6,37 +6,6 @@ use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; impl, const D: usize> CircuitBuilder { - /// Finds the last available random access gate with the given `vec_size` or add one if there aren't any. - /// Returns `(g,i)` such that there is a random access gate with the given `vec_size` at index - /// `g` and the gate's `i`-th random access is available. - fn find_random_access_gate(&mut self, vec_size: usize) -> (usize, usize) { - let (gate, i) = self - .free_random_access - .get(&vec_size) - .copied() - .unwrap_or_else(|| { - let gate = self.add_gate( - RandomAccessGate::new_from_config(&self.config, vec_size), - vec![], - ); - (gate, 0) - }); - - // Update `free_random_access` with new values. - if i < RandomAccessGate::::max_num_copies( - self.config.num_routed_wires, - self.config.num_wires, - vec_size, - ) - 1 - { - self.free_random_access.insert(vec_size, (gate, i + 1)); - } else { - self.free_random_access.remove(&vec_size); - } - - (gate, i) - } - /// Checks that a `Target` matches a vector at a non-deterministic index. /// Note: `access_index` is not range-checked. pub fn random_access(&mut self, access_index: Target, claimed_element: Target, v: Vec) { diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 6e1eb69a..1f5f746d 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -26,7 +26,7 @@ pub struct PoseidonGate< const D: usize, const WIDTH: usize, > where - [(); WIDTH - 1]:, + [(); WIDTH - 1]: , { _phantom: PhantomData, } @@ -34,7 +34,7 @@ pub struct PoseidonGate< impl + Poseidon, const D: usize, const WIDTH: usize> PoseidonGate where - [(); WIDTH - 1]:, + [(); WIDTH - 1]: , { pub fn new() -> Self { PoseidonGate { @@ -91,7 +91,7 @@ where impl + Poseidon, const D: usize, const WIDTH: usize> Gate for PoseidonGate where - [(); WIDTH - 1]:, + [(); WIDTH - 1]: , { fn id(&self) -> String { format!("{:?}", self, WIDTH) @@ -396,7 +396,7 @@ struct PoseidonGenerator< const D: usize, const WIDTH: usize, > where - [(); WIDTH - 1]:, + [(); WIDTH - 1]: , { gate_index: usize, _phantom: PhantomData, @@ -405,7 +405,7 @@ struct PoseidonGenerator< impl + Poseidon, const D: usize, const WIDTH: usize> SimpleGenerator for PoseidonGenerator where - [(); WIDTH - 1]:, + [(); WIDTH - 1]: , { fn dependencies(&self) -> Vec { (0..WIDTH) diff --git a/src/gates/poseidon_mds.rs b/src/gates/poseidon_mds.rs index a127df68..8a42b588 100644 --- a/src/gates/poseidon_mds.rs +++ b/src/gates/poseidon_mds.rs @@ -21,7 +21,7 @@ pub struct PoseidonMdsGate< const D: usize, const WIDTH: usize, > where - [(); WIDTH - 1]:, + [(); WIDTH - 1]: , { _phantom: PhantomData, } @@ -29,7 +29,7 @@ pub struct PoseidonMdsGate< impl + Poseidon, const D: usize, const WIDTH: usize> PoseidonMdsGate where - [(); WIDTH - 1]:, + [(); WIDTH - 1]: , { pub fn new() -> Self { PoseidonMdsGate { @@ -116,7 +116,7 @@ where impl + Poseidon, const D: usize, const WIDTH: usize> Gate for PoseidonMdsGate where - [(); WIDTH - 1]:, + [(); WIDTH - 1]: , { fn id(&self) -> String { format!("{:?}", self, WIDTH) @@ -207,7 +207,7 @@ where #[derive(Clone, Debug)] struct PoseidonMdsGenerator where - [(); WIDTH - 1]:, + [(); WIDTH - 1]: , { gate_index: usize, } @@ -215,7 +215,7 @@ where impl + Poseidon, const D: usize, const WIDTH: usize> SimpleGenerator for PoseidonMdsGenerator where - [(); WIDTH - 1]:, + [(); WIDTH - 1]: , { fn dependencies(&self) -> Vec { (0..WIDTH) diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index 9e4dd7f4..9a52060c 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -147,7 +147,7 @@ pub const ALL_ROUND_CONSTANTS: [u64; MAX_WIDTH * N_ROUNDS] = [ pub trait Poseidon: PrimeField where // magic to get const generic expressions to work - [(); WIDTH - 1]:, + [(); WIDTH - 1]: , { // Total number of round constants required: width of the input // times number of rounds. @@ -634,7 +634,7 @@ pub(crate) mod test_helpers { test_vectors: Vec<([u64; WIDTH], [u64; WIDTH])>, ) where F: Poseidon, - [(); WIDTH - 1]:, + [(); WIDTH - 1]: , { for (input_, expected_output_) in test_vectors.into_iter() { let mut input = [F::ZERO; WIDTH]; @@ -652,7 +652,7 @@ pub(crate) mod test_helpers { pub(crate) fn check_consistency() where F: Poseidon, - [(); WIDTH - 1]:, + [(); WIDTH - 1]: , { let mut input = [F::ZERO; WIDTH]; for i in 0..WIDTH { diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 33709e32..c063aa85 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -77,7 +77,7 @@ pub struct CircuitBuilder, const D: usize> { /// Memoized results of `arithmetic_extension` calls. pub(crate) arithmetic_results: HashMap, ExtensionTarget>, - batched_gates: BatchedGates + batched_gates: BatchedGates, } impl, const D: usize> CircuitBuilder { @@ -295,7 +295,7 @@ impl, const D: usize> CircuitBuilder { return target; } - let (gate, instance) = self.batched_gates.constant_gate_instance(); + let (gate, instance) = self.constant_gate_instance(); let target = Target::wire(gate, instance); self.gate_instances[gate].constants[instance] = c; @@ -748,6 +748,10 @@ pub struct BatchedGates, const D: usize> { /// these constants with gate index `g` and already using `i` arithmetic operations. pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>, + /// A map `(c0, c1) -> (g, i)` from constants `vec_size` to an available arithmetic gate using + /// these constants with gate index `g` and already using `i` random accesses. + pub(crate) free_random_access: HashMap, + /// `current_switch_gates[chunk_size - 1]` contains None if we have no switch gates with the value /// chunk_size, and contains `(g, i, c)`, if the gate `g`, at index `i`, already contains `c` copies /// of switches @@ -767,6 +771,7 @@ impl, const D: usize> BatchedGates { pub fn new() -> Self { Self { free_arithmetic: HashMap::new(), + free_random_access: HashMap::new(), current_switch_gates: Vec::new(), current_u32_arithmetic_gate: None, current_u32_subtraction_gate: None, @@ -807,6 +812,40 @@ impl, const D: usize> CircuitBuilder { (gate, i) } + /// Finds the last available random access gate with the given `vec_size` or add one if there aren't any. + /// Returns `(g,i)` such that there is a random access gate with the given `vec_size` at index + /// `g` and the gate's `i`-th random access is available. + pub(crate) fn find_random_access_gate(&mut self, vec_size: usize) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_random_access + .get(&vec_size) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + RandomAccessGate::new_from_config(&self.config, vec_size), + vec![], + ); + (gate, 0) + }); + + // Update `free_random_access` with new values. + if i < RandomAccessGate::::max_num_copies( + self.config.num_routed_wires, + self.config.num_wires, + vec_size, + ) - 1 + { + self.batched_gates + .free_random_access + .insert(vec_size, (gate, i + 1)); + } else { + self.batched_gates.free_random_access.remove(&vec_size); + } + + (gate, i) + } + pub(crate) fn find_switch_gate( &mut self, chunk_size: usize, @@ -825,7 +864,7 @@ impl, const D: usize> CircuitBuilder { let (gate, gate_index, next_copy) = match self.batched_gates.current_switch_gates[chunk_size - 1].clone() { None => { - let gate = SwitchGate::::new_from_config(self.config.clone(), chunk_size); + let gate = SwitchGate::::new_from_config(&self.config, chunk_size); let gate_index = self.add_gate(gate.clone(), vec![]); (gate, gate_index, 0) } @@ -885,23 +924,23 @@ impl, const D: usize> CircuitBuilder { /// Returns the gate index and copy index of a free `ConstantGate` slot, potentially adding a /// new `ConstantGate` if needed. fn constant_gate_instance(&mut self) -> (usize, usize) { - if self.free_constant.is_none() { + if self.batched_gates.free_constant.is_none() { let num_consts = self.config.constant_gate_size; // We will fill this `ConstantGate` with zero constants initially. // These will be overwritten by `constant` as the gate instances are filled. let gate = self.add_gate(ConstantGate { num_consts }, vec![F::ZERO; num_consts]); - self.free_constant = Some((gate, 0)); + self.batched_gates.free_constant = Some((gate, 0)); } - let (gate, instance) = self.free_constant.unwrap(); + let (gate, instance) = self.batched_gates.free_constant.unwrap(); if instance + 1 < self.config.constant_gate_size { - self.free_constant = Some((gate, instance + 1)); + self.batched_gates.free_constant = Some((gate, instance + 1)); } else { - self.free_constant = None; + self.batched_gates.free_constant = None; } (gate, instance) } - + /// Fill the remaining unused arithmetic operations with zeros, so that all /// `ArithmeticExtensionGenerator`s are run. fn fill_arithmetic_gates(&mut self) { From 9043a47e1be21e5859d496293411d5d893c2302b Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 11:15:00 -0800 Subject: [PATCH 081/202] more fixes --- src/gadgets/biguint.rs | 22 +++++++++++----------- src/gadgets/multiple_comparison.rs | 6 +++--- src/gadgets/nonnative.rs | 18 +++++++++--------- src/gates/assert_le.rs | 14 ++++++++------ src/gates/subtraction_u32.rs | 10 +++++----- 5 files changed, 36 insertions(+), 34 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index ad61c8a0..f9878c50 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -232,7 +232,7 @@ mod tests { use rand::Rng; use crate::{ - field::crandall_field::CrandallField, + field::goldilocks_field::GoldilocksField, iop::witness::PartialWitness, plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig, verifier::verify}, }; @@ -245,8 +245,8 @@ mod tests { let y_value = BigUint::from_u128(rng.gen()).unwrap(); let expected_z_value = &x_value + &y_value; - type F = CrandallField; - let config = CircuitConfig::large_config(); + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -273,8 +273,8 @@ mod tests { } let expected_z_value = &x_value - &y_value; - type F = CrandallField; - let config = CircuitConfig::large_config(); + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -298,8 +298,8 @@ mod tests { let y_value = BigUint::from_u128(rng.gen()).unwrap(); let expected_z_value = &x_value * &y_value; - type F = CrandallField; - let config = CircuitConfig::large_config(); + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -322,8 +322,8 @@ mod tests { let x_value = BigUint::from_u128(rng.gen()).unwrap(); let y_value = BigUint::from_u128(rng.gen()).unwrap(); - type F = CrandallField; - let config = CircuitConfig::large_config(); + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -350,8 +350,8 @@ mod tests { } let (expected_div_value, expected_rem_value) = x_value.div_rem(&y_value); - type F = CrandallField; - let config = CircuitConfig::large_config(); + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs index 553fb49e..77e660e6 100644 --- a/src/gadgets/multiple_comparison.rs +++ b/src/gadgets/multiple_comparison.rs @@ -71,16 +71,16 @@ mod tests { use num::BigUint; use rand::Rng; - use crate::field::crandall_field::CrandallField; use crate::field::field_types::Field; + use crate::field::goldilocks_field::GoldilocksField; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::verifier::verify; fn test_list_le(size: usize, num_bits: usize) -> Result<()> { - type F = CrandallField; - let config = CircuitConfig::large_config(); + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 1d14f708..fbe46fe2 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -115,8 +115,8 @@ impl, const D: usize> CircuitBuilder { mod tests { use anyhow::Result; - use crate::field::crandall_field::CrandallField; use crate::field::field_types::Field; + use crate::field::goldilocks_field::GoldilocksField; use crate::field::secp256k1::Secp256K1Base; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; @@ -130,8 +130,8 @@ mod tests { let y_ff = FF::rand(); let sum_ff = x_ff + y_ff; - type F = CrandallField; - let config = CircuitConfig::large_config(); + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -157,8 +157,8 @@ mod tests { } let diff_ff = x_ff - y_ff; - type F = CrandallField; - let config = CircuitConfig::large_config(); + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -181,8 +181,8 @@ mod tests { let y_ff = FF::rand(); let product_ff = x_ff * y_ff; - type F = CrandallField; - let config = CircuitConfig::large_config(); + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -204,8 +204,8 @@ mod tests { let x_ff = FF::rand(); let neg_x_ff = -x_ff; - type F = CrandallField; - let config = CircuitConfig::large_config(); + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); diff --git a/src/gates/assert_le.rs b/src/gates/assert_le.rs index 98411ef2..4d33a867 100644 --- a/src/gates/assert_le.rs +++ b/src/gates/assert_le.rs @@ -432,9 +432,9 @@ mod tests { use anyhow::Result; use rand::Rng; - use crate::field::crandall_field::CrandallField; use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::{Field, PrimeField}; + use crate::field::goldilocks_field::GoldilocksField; use crate::gates::assert_le::AssertLessThanGate; use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; @@ -443,7 +443,7 @@ mod tests { #[test] fn wire_indices() { - type AG = AssertLessThanGate; + type AG = AssertLessThanGate; let num_bits = 40; let num_chunks = 5; @@ -473,7 +473,7 @@ mod tests { let num_bits = 40; let num_chunks = 5; - test_low_degree::(AssertLessThanGate::<_, 4>::new( + test_low_degree::(AssertLessThanGate::<_, 4>::new( num_bits, num_chunks, )) } @@ -483,13 +483,15 @@ mod tests { let num_bits = 40; let num_chunks = 5; - test_eval_fns::(AssertLessThanGate::<_, 4>::new(num_bits, num_chunks)) + test_eval_fns::(AssertLessThanGate::<_, 4>::new( + num_bits, num_chunks, + )) } #[test] fn test_gate_constraint() { - type F = CrandallField; - type FF = QuarticExtension; + type F = GoldilocksField; + type FF = QuarticExtension; const D: usize = 4; let num_bits = 40; diff --git a/src/gates/subtraction_u32.rs b/src/gates/subtraction_u32.rs index 14447727..225c09e4 100644 --- a/src/gates/subtraction_u32.rs +++ b/src/gates/subtraction_u32.rs @@ -317,9 +317,9 @@ mod tests { use anyhow::Result; use rand::Rng; - use crate::field::crandall_field::CrandallField; use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::{Field, PrimeField}; + use crate::field::goldilocks_field::GoldilocksField; use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::gates::subtraction_u32::{U32SubtractionGate, NUM_U32_SUBTRACTION_OPS}; @@ -328,22 +328,22 @@ mod tests { #[test] fn low_degree() { - test_low_degree::(U32SubtractionGate:: { + test_low_degree::(U32SubtractionGate:: { _phantom: PhantomData, }) } #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(U32SubtractionGate:: { + test_eval_fns::(U32SubtractionGate:: { _phantom: PhantomData, }) } #[test] fn test_gate_constraint() { - type F = CrandallField; - type FF = QuarticExtension; + type F = GoldilocksField; + type FF = QuarticExtension; const D: usize = 4; fn get_wires(inputs_x: Vec, inputs_y: Vec, borrows: Vec) -> Vec { From dd945ef5b73bac72929677f598ac98cfc806f166 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 11:19:06 -0800 Subject: [PATCH 082/202] addressed comments --- src/gadgets/biguint.rs | 13 ++++++------- src/gadgets/nonnative.rs | 14 ++++---------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index f9878c50..fff97a6e 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -61,7 +61,6 @@ impl, const D: usize> CircuitBuilder { (a.clone(), padded_b) } else { let mut padded_a = a.clone(); - let to_extend = b.num_limbs() - a.num_limbs(); for _ in a.num_limbs()..b.num_limbs() { padded_a.limbs.push(self.zero_u32()); } @@ -266,10 +265,10 @@ mod tests { fn test_biguint_sub() -> Result<()> { let mut rng = rand::thread_rng(); - let x_value = BigUint::from_u128(rng.gen()).unwrap(); + let mut x_value = BigUint::from_u128(rng.gen()).unwrap(); let mut y_value = BigUint::from_u128(rng.gen()).unwrap(); - while y_value > x_value { - y_value = BigUint::from_u128(rng.gen()).unwrap(); + if y_value > x_value { + (x_value, y_value) = (y_value, x_value); } let expected_z_value = &x_value - &y_value; @@ -343,10 +342,10 @@ mod tests { fn test_biguint_div_rem() -> Result<()> { let mut rng = rand::thread_rng(); - let x_value = BigUint::from_u128(rng.gen()).unwrap(); + let mut x_value = BigUint::from_u128(rng.gen()).unwrap(); let mut y_value = BigUint::from_u128(rng.gen()).unwrap(); - while y_value > x_value { - y_value = BigUint::from_u128(rng.gen()).unwrap(); + if y_value > x_value { + (x_value, y_value) = (y_value, x_value); } let (expected_div_value, expected_rem_value) = x_value.div_rem(&y_value); diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index fbe46fe2..fd883e5d 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -44,9 +44,7 @@ impl, const D: usize> CircuitBuilder { a: &ForeignFieldTarget, b: &ForeignFieldTarget, ) -> ForeignFieldTarget { - let a_biguint = self.nonnative_to_biguint(a); - let b_biguint = self.nonnative_to_biguint(b); - let result = self.add_biguint(&a_biguint, &b_biguint); + let result = self.add_biguint(&a.value, &b.value); self.reduce(&result) } @@ -58,10 +56,8 @@ impl, const D: usize> CircuitBuilder { b: &ForeignFieldTarget, ) -> ForeignFieldTarget { let order = self.constant_biguint(&FF::order()); - let a_biguint = self.nonnative_to_biguint(a); - let a_plus_order = self.add_biguint(&order, &a_biguint); - let b_biguint = self.nonnative_to_biguint(b); - let result = self.sub_biguint(&a_plus_order, &b_biguint); + let a_plus_order = self.add_biguint(&order, &a.value); + let result = self.sub_biguint(&a_plus_order, &b.value); // TODO: reduce sub result with only one conditional addition? self.reduce(&result) @@ -72,9 +68,7 @@ impl, const D: usize> CircuitBuilder { a: &ForeignFieldTarget, b: &ForeignFieldTarget, ) -> ForeignFieldTarget { - let a_biguint = self.nonnative_to_biguint(a); - let b_biguint = self.nonnative_to_biguint(b); - let result = self.mul_biguint(&a_biguint, &b_biguint); + let result = self.mul_biguint(&a.value, &b.value); self.reduce(&result) } From bd427cd62944fd0c6355cd653aabee38af02144c Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 12:10:32 -0800 Subject: [PATCH 083/202] fixed failing tests --- src/plonk/circuit_builder.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index c063aa85..c6c49826 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -973,6 +973,22 @@ impl, const D: usize> CircuitBuilder { } } + /// Fill the remaining unused random access operations with zeros, so that all + /// `RandomAccessGenerator`s are run. + fn fill_random_access_gates(&mut self) { + let zero = self.zero(); + for (vec_size, (_, i)) in self.batched_gates.free_random_access.clone() { + let max_copies = RandomAccessGate::::max_num_copies( + self.config.num_routed_wires, + self.config.num_wires, + vec_size, + ); + for _ in i..max_copies { + self.random_access(zero, zero, vec![zero; vec_size]); + } + } + } + /// Fill the remaining unused switch gates with dummy values, so that all /// `SwitchGenerator`s are run. fn fill_switch_gates(&mut self) { @@ -1048,6 +1064,7 @@ impl, const D: usize> CircuitBuilder { fn fill_batched_gates(&mut self) { self.fill_arithmetic_gates(); + self.fill_random_access_gates(); self.fill_switch_gates(); self.fill_u32_arithmetic_gates(); self.fill_u32_subtraction_gates(); From f2ec2cadf4d46d01ed03989f09d60d4b97d557db Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 12:14:23 -0800 Subject: [PATCH 084/202] new fmt --- benches/hashing.rs | 2 +- src/field/packed_avx2/mod.rs | 14 +++++++------- src/gadgets/hash.rs | 6 +++--- src/gates/poseidon.rs | 10 +++++----- src/gates/poseidon_mds.rs | 10 +++++----- src/hash/poseidon.rs | 6 +++--- 6 files changed, 24 insertions(+), 24 deletions(-) diff --git a/benches/hashing.rs b/benches/hashing.rs index 5669e50b..c229972e 100644 --- a/benches/hashing.rs +++ b/benches/hashing.rs @@ -19,7 +19,7 @@ pub(crate) fn bench_gmimc, const WIDTH: usize>(c: &mut Criterion pub(crate) fn bench_poseidon, const WIDTH: usize>(c: &mut Criterion) where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { c.bench_function(&format!("poseidon<{}, {}>", type_name::(), WIDTH), |b| { b.iter_batched( diff --git a/src/field/packed_avx2/mod.rs b/src/field/packed_avx2/mod.rs index eddbb5c9..20eecba7 100644 --- a/src/field/packed_avx2/mod.rs +++ b/src/field/packed_avx2/mod.rs @@ -34,7 +34,7 @@ mod tests { fn test_add() where - [(); PackedPrimeField::::WIDTH]: , + [(); PackedPrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); @@ -52,7 +52,7 @@ mod tests { fn test_mul() where - [(); PackedPrimeField::::WIDTH]: , + [(); PackedPrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); @@ -70,7 +70,7 @@ mod tests { fn test_square() where - [(); PackedPrimeField::::WIDTH]: , + [(); PackedPrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); @@ -86,7 +86,7 @@ mod tests { fn test_neg() where - [(); PackedPrimeField::::WIDTH]: , + [(); PackedPrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); @@ -102,7 +102,7 @@ mod tests { fn test_sub() where - [(); PackedPrimeField::::WIDTH]: , + [(); PackedPrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); @@ -120,7 +120,7 @@ mod tests { fn test_interleave_is_involution() where - [(); PackedPrimeField::::WIDTH]: , + [(); PackedPrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); @@ -144,7 +144,7 @@ mod tests { fn test_interleave() where - [(); PackedPrimeField::::WIDTH]: , + [(); PackedPrimeField::::WIDTH]:, { let in_a: [F; 4] = [ F::from_noncanonical_u64(00), diff --git a/src/gadgets/hash.rs b/src/gadgets/hash.rs index 99da9e1e..db4cb1e8 100644 --- a/src/gadgets/hash.rs +++ b/src/gadgets/hash.rs @@ -15,7 +15,7 @@ impl, const D: usize> CircuitBuilder { pub fn permute(&mut self, inputs: [Target; W]) -> [Target; W] where F: GMiMC + Poseidon, - [(); W - 1]: , + [(); W - 1]:, { // We don't want to swap any inputs, so set that wire to 0. let _false = self._false(); @@ -31,7 +31,7 @@ impl, const D: usize> CircuitBuilder { ) -> [Target; W] where F: GMiMC + Poseidon, - [(); W - 1]: , + [(); W - 1]:, { match HASH_FAMILY { HashFamily::GMiMC => self.gmimc_permute_swapped(inputs, swap), @@ -88,7 +88,7 @@ impl, const D: usize> CircuitBuilder { ) -> [Target; W] where F: Poseidon, - [(); W - 1]: , + [(); W - 1]:, { let gate_type = PoseidonGate::::new(); let gate = self.add_gate(gate_type, vec![]); diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 1f5f746d..6e1eb69a 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -26,7 +26,7 @@ pub struct PoseidonGate< const D: usize, const WIDTH: usize, > where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { _phantom: PhantomData, } @@ -34,7 +34,7 @@ pub struct PoseidonGate< impl + Poseidon, const D: usize, const WIDTH: usize> PoseidonGate where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { pub fn new() -> Self { PoseidonGate { @@ -91,7 +91,7 @@ where impl + Poseidon, const D: usize, const WIDTH: usize> Gate for PoseidonGate where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { fn id(&self) -> String { format!("{:?}", self, WIDTH) @@ -396,7 +396,7 @@ struct PoseidonGenerator< const D: usize, const WIDTH: usize, > where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { gate_index: usize, _phantom: PhantomData, @@ -405,7 +405,7 @@ struct PoseidonGenerator< impl + Poseidon, const D: usize, const WIDTH: usize> SimpleGenerator for PoseidonGenerator where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { fn dependencies(&self) -> Vec { (0..WIDTH) diff --git a/src/gates/poseidon_mds.rs b/src/gates/poseidon_mds.rs index 8a42b588..a127df68 100644 --- a/src/gates/poseidon_mds.rs +++ b/src/gates/poseidon_mds.rs @@ -21,7 +21,7 @@ pub struct PoseidonMdsGate< const D: usize, const WIDTH: usize, > where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { _phantom: PhantomData, } @@ -29,7 +29,7 @@ pub struct PoseidonMdsGate< impl + Poseidon, const D: usize, const WIDTH: usize> PoseidonMdsGate where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { pub fn new() -> Self { PoseidonMdsGate { @@ -116,7 +116,7 @@ where impl + Poseidon, const D: usize, const WIDTH: usize> Gate for PoseidonMdsGate where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { fn id(&self) -> String { format!("{:?}", self, WIDTH) @@ -207,7 +207,7 @@ where #[derive(Clone, Debug)] struct PoseidonMdsGenerator where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { gate_index: usize, } @@ -215,7 +215,7 @@ where impl + Poseidon, const D: usize, const WIDTH: usize> SimpleGenerator for PoseidonMdsGenerator where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { fn dependencies(&self) -> Vec { (0..WIDTH) diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index 9a52060c..9e4dd7f4 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -147,7 +147,7 @@ pub const ALL_ROUND_CONSTANTS: [u64; MAX_WIDTH * N_ROUNDS] = [ pub trait Poseidon: PrimeField where // magic to get const generic expressions to work - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { // Total number of round constants required: width of the input // times number of rounds. @@ -634,7 +634,7 @@ pub(crate) mod test_helpers { test_vectors: Vec<([u64; WIDTH], [u64; WIDTH])>, ) where F: Poseidon, - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { for (input_, expected_output_) in test_vectors.into_iter() { let mut input = [F::ZERO; WIDTH]; @@ -652,7 +652,7 @@ pub(crate) mod test_helpers { pub(crate) fn check_consistency() where F: Poseidon, - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { let mut input = [F::ZERO; WIDTH]; for i in 0..WIDTH { From 9139d1350a0fb464af3ad729660ddfbafc4409ba Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Thu, 11 Nov 2021 07:16:16 -0800 Subject: [PATCH 085/202] Minor refactor of partial product code (#351) --- src/plonk/prover.rs | 89 ++++++++++++------------------------ src/plonk/vanishing_poly.rs | 38 ++++++--------- src/util/partial_products.rs | 79 +++++++++++++++++--------------- 3 files changed, 83 insertions(+), 123 deletions(-) diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index be356d9f..1c247998 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -17,9 +17,10 @@ use crate::plonk::vanishing_poly::eval_vanishing_poly_base_batch; use crate::plonk::vars::EvaluationVarsBase; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::timed; -use crate::util::partial_products::partial_products; +use crate::util::partial_products::{partial_products_and_z_gx, quotient_chunk_products}; use crate::util::timing::TimingTree; use crate::util::{log2_ceil, transpose}; +use std::mem::swap; pub(crate) fn prove, const D: usize>( prover_data: &ProverOnlyCircuitData, @@ -91,28 +92,21 @@ pub(crate) fn prove, const D: usize>( common_data.quotient_degree_factor < common_data.config.num_routed_wires, "When the number of routed wires is smaller that the degree, we should change the logic to avoid computing partial products." ); - let mut partial_products = timed!( + let mut partial_products_and_zs = timed!( timing, "compute partial products", all_wires_permutation_partial_products(&witness, &betas, &gammas, prover_data, common_data) ); - let plonk_z_vecs = timed!( - timing, - "compute Z's", - compute_zs(&mut partial_products, common_data) - ); + // Z is expected at the front of our batch; see `zs_range` and `partial_products_range`. + let plonk_z_vecs = partial_products_and_zs.iter_mut() + .map(|partial_products_and_z| partial_products_and_z.pop().unwrap()) + .collect(); + let zs_partial_products = [plonk_z_vecs, partial_products_and_zs.concat()].concat(); - // The first polynomial in `partial_products` represent the final product used in the - // computation of `Z`. It isn't needed anymore so we discard it. - partial_products.iter_mut().for_each(|part| { - part.remove(0); - }); - - let zs_partial_products = [plonk_z_vecs, partial_products.concat()].concat(); - let zs_partial_products_commitment = timed!( + let partial_products_and_zs_commitment = timed!( timing, - "commit to Z's", + "commit to partial products and Z's", PolynomialBatchCommitment::from_values( zs_partial_products, config.rate_bits, @@ -123,7 +117,7 @@ pub(crate) fn prove, const D: usize>( ) ); - challenger.observe_cap(&zs_partial_products_commitment.merkle_tree.cap); + challenger.observe_cap(&partial_products_and_zs_commitment.merkle_tree.cap); let alphas = challenger.get_n_challenges(num_challenges); @@ -135,7 +129,7 @@ pub(crate) fn prove, const D: usize>( prover_data, &public_inputs_hash, &wires_commitment, - &zs_partial_products_commitment, + &partial_products_and_zs_commitment, &betas, &gammas, &alphas, @@ -184,7 +178,7 @@ pub(crate) fn prove, const D: usize>( &[ &prover_data.constants_sigmas_commitment, &wires_commitment, - &zs_partial_products_commitment, + &partial_products_and_zs_commitment, "ient_polys_commitment, ], zeta, @@ -196,7 +190,7 @@ pub(crate) fn prove, const D: usize>( let proof = Proof { wires_cap: wires_commitment.merkle_tree.cap, - plonk_zs_partial_products_cap: zs_partial_products_commitment.merkle_tree.cap, + plonk_zs_partial_products_cap: partial_products_and_zs_commitment.merkle_tree.cap, quotient_polys_cap: quotient_polys_commitment.merkle_tree.cap, openings, opening_proof, @@ -217,7 +211,7 @@ fn all_wires_permutation_partial_products, const D: ) -> Vec>> { (0..common_data.config.num_challenges) .map(|i| { - wires_permutation_partial_products( + wires_permutation_partial_products_and_zs( witness, betas[i], gammas[i], @@ -231,7 +225,7 @@ fn all_wires_permutation_partial_products, const D: /// Compute the partial products used in the `Z` polynomial. /// Returns the polynomials interpolating `partial_products(f / g)` /// where `f, g` are the products in the definition of `Z`: `Z(g^i) = f / g`. -fn wires_permutation_partial_products, const D: usize>( +fn wires_permutation_partial_products_and_zs, const D: usize>( witness: &MatrixWitness, beta: F, gamma: F, @@ -241,7 +235,8 @@ fn wires_permutation_partial_products, const D: usi let degree = common_data.quotient_degree_factor; let subgroup = &prover_data.subgroup; let k_is = &common_data.k_is; - let values = subgroup + let (num_prods, final_num_prod) = common_data.num_partial_products; + let all_quotient_chunk_products = subgroup .par_iter() .enumerate() .map(|(i, &x)| { @@ -265,51 +260,25 @@ fn wires_permutation_partial_products, const D: usi .map(|(num, den_inv)| num * den_inv) .collect::>(); - let quotient_partials = partial_products("ient_values, degree); - - // This is the final product for the quotient. - let quotient = *quotient_partials.last().unwrap() - * quotient_values[common_data.num_partial_products.1..] - .iter() - .copied() - .product(); - - // We add the quotient at the beginning of the vector to reuse them later in the computation of `Z`. - [vec![quotient], quotient_partials].concat() + quotient_chunk_products("ient_values, degree) }) .collect::>(); - transpose(&values) + let mut z_x = F::ONE; + let mut all_partial_products_and_zs = Vec::new(); + for quotient_chunk_products in all_quotient_chunk_products { + let mut partial_products_and_z_gx = partial_products_and_z_gx(z_x, "ient_chunk_products); + // The last term is Z(gx), but we replace it with Z(x), otherwise Z would end up shifted. + swap(&mut z_x, &mut partial_products_and_z_gx[num_prods]); + all_partial_products_and_zs.push(partial_products_and_z_gx); + } + + transpose(&all_partial_products_and_zs) .into_par_iter() .map(PolynomialValues::new) .collect() } -fn compute_zs, const D: usize>( - partial_products: &mut [Vec>], - common_data: &CommonCircuitData, -) -> Vec> { - (0..common_data.config.num_challenges) - .map(|i| compute_z(&mut partial_products[i], common_data)) - .collect() -} - -/// Compute the `Z` polynomial by reusing the computations done in `wires_permutation_partial_products`. -fn compute_z, const D: usize>( - partial_products: &mut [PolynomialValues], - common_data: &CommonCircuitData, -) -> PolynomialValues { - let mut plonk_z_points = vec![F::ONE]; - for i in 1..common_data.degree() { - let last = *plonk_z_points.last().unwrap(); - for q in partial_products.iter_mut() { - q.values[i - 1] *= last; - } - plonk_z_points.push(partial_products[0].values[i - 1]); - } - plonk_z_points.into() -} - const BATCH_SIZE: usize = 32; fn compute_quotient_polys<'a, F: RichField + Extendable, const D: usize>( diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 899d69a6..f91e027b 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -75,13 +75,10 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( ); vanishing_partial_products_terms.extend(partial_product_checks); - let v_shift_term = *current_partial_products.last().unwrap() - * numerator_values[final_num_prod..].iter().copied().product() - - z_gz - * denominator_values[final_num_prod..] - .iter() - .copied() - .product(); + let final_nume_product = numerator_values[final_num_prod..].iter().copied().product(); + let final_deno_product = denominator_values[final_num_prod..].iter().copied().product(); + let last_partial = *current_partial_products.last().unwrap(); + let v_shift_term = last_partial * final_nume_product - z_gz * final_deno_product; vanishing_v_shift_terms.push(v_shift_term); } @@ -185,13 +182,10 @@ pub(crate) fn eval_vanishing_poly_base_batch, const ); vanishing_partial_products_terms.extend(partial_product_checks); - let v_shift_term = *current_partial_products.last().unwrap() - * numerator_values[final_num_prod..].iter().copied().product() - - z_gz - * denominator_values[final_num_prod..] - .iter() - .copied() - .product(); + let final_nume_product = numerator_values[final_num_prod..].iter().copied().product(); + let final_deno_product = denominator_values[final_num_prod..].iter().copied().product(); + let last_partial = *current_partial_products.last().unwrap(); + let v_shift_term = last_partial * final_nume_product - z_gz * final_deno_product; vanishing_v_shift_terms.push(v_shift_term); numerator_values.clear(); @@ -381,17 +375,11 @@ pub(crate) fn eval_vanishing_poly_recursively, cons ); vanishing_partial_products_terms.extend(partial_product_checks); - let nume_acc = builder.mul_many_extension(&{ - let mut v = numerator_values[final_num_prod..].to_vec(); - v.push(*current_partial_products.last().unwrap()); - v - }); - let z_gz_denominators = builder.mul_many_extension(&{ - let mut v = denominator_values[final_num_prod..].to_vec(); - v.push(z_gz); - v - }); - let v_shift_term = builder.sub_extension(nume_acc, z_gz_denominators); + let final_nume_product = builder.mul_many_extension(&numerator_values[final_num_prod..]); + let final_deno_product = builder.mul_many_extension(&denominator_values[final_num_prod..]); + let z_gz_denominators = builder.mul_extension(z_gz, final_deno_product); + let last_partial = *current_partial_products.last().unwrap(); + let v_shift_term = builder.mul_sub_extension(last_partial, final_nume_product, z_gz_denominators); vanishing_v_shift_terms.push(v_shift_term); } diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index 92419a56..1e7a4f4b 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -2,19 +2,30 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::plonk::circuit_builder::CircuitBuilder; +use itertools::Itertools; + +pub(crate) fn quotient_chunk_products( + quotient_values: &[F], + max_degree: usize, +) -> Vec { + debug_assert!(max_degree > 1); + assert!(quotient_values.len() > 0); + let chunk_size = max_degree; + quotient_values.chunks(chunk_size) + .map(|chunk| chunk.iter().copied().product()) + .collect() +} /// Compute partial products of the original vector `v` such that all products consist of `max_degree` /// or less elements. This is done until we've computed the product `P` of all elements in the vector. -pub fn partial_products(v: &[F], max_degree: usize) -> Vec { - debug_assert!(max_degree > 1); +pub(crate) fn partial_products_and_z_gx(z_x: F, quotient_chunk_products: &[F]) -> Vec { + assert!(quotient_chunk_products.len() > 0); let mut res = Vec::new(); - let mut acc = F::ONE; - let chunk_size = max_degree; - for chunk in v.chunks_exact(chunk_size) { - acc *= chunk.iter().copied().product(); + let mut acc = z_x; + for "ient_chunk_product in quotient_chunk_products { + acc *= quotient_chunk_product; res.push(acc); } - res } @@ -30,24 +41,26 @@ pub fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { /// Checks that the partial products of `numerators/denominators` are coherent with those in `partials` by only computing /// products of size `max_degree` or less. -pub fn check_partial_products( +pub(crate) fn check_partial_products( numerators: &[F], denominators: &[F], partials: &[F], - mut acc: F, + z_x: F, max_degree: usize, ) -> Vec { debug_assert!(max_degree > 1); + let mut acc = z_x; let mut partials = partials.iter(); let mut res = Vec::new(); let chunk_size = max_degree; for (nume_chunk, deno_chunk) in numerators .chunks_exact(chunk_size) - .zip(denominators.chunks_exact(chunk_size)) + .zip_eq(denominators.chunks_exact(chunk_size)) { - acc *= nume_chunk.iter().copied().product(); - let mut new_acc = *partials.next().unwrap(); - res.push(acc - new_acc * deno_chunk.iter().copied().product()); + let num_chunk_product = nume_chunk.iter().copied().product(); + let den_chunk_product = deno_chunk.iter().copied().product(); + let new_acc = *partials.next().unwrap(); + res.push(acc * num_chunk_product - new_acc * den_chunk_product); acc = new_acc; } debug_assert!(partials.next().is_none()); @@ -55,7 +68,7 @@ pub fn check_partial_products( res } -pub fn check_partial_products_recursively, const D: usize>( +pub(crate) fn check_partial_products_recursively, const D: usize>( builder: &mut CircuitBuilder, numerators: &[ExtensionTarget], denominators: &[ExtensionTarget], @@ -93,18 +106,11 @@ mod tests { fn test_partial_products() { type F = GoldilocksField; let denominators = vec![F::ONE; 6]; - let v = [1, 2, 3, 4, 5, 6] - .into_iter() - .map(|&i| F::from_canonical_u64(i)) - .collect::>(); - let p = partial_products(&v, 2); - assert_eq!( - p, - [2, 24, 720] - .into_iter() - .map(|&i| F::from_canonical_u64(i)) - .collect::>() - ); + let v = field_vec(&[1, 2, 3, 4, 5, 6]); + let quotient_chunks_prods = quotient_chunk_products(&v, 2); + assert_eq!(quotient_chunks_prods, field_vec(&[2, 12, 30])); + let p = partial_products_and_z_gx(F::ONE, "ient_chunks_prods); + assert_eq!(p, field_vec(&[2, 24, 720])); let nums = num_partial_products(v.len(), 2); assert_eq!(p.len(), nums.0); @@ -116,18 +122,11 @@ mod tests { v.into_iter().product::(), ); - let v = [1, 2, 3, 4, 5, 6] - .into_iter() - .map(|&i| F::from_canonical_u64(i)) - .collect::>(); - let p = partial_products(&v, 3); - assert_eq!( - p, - [6, 720] - .into_iter() - .map(|&i| F::from_canonical_u64(i)) - .collect::>() - ); + let v = field_vec(&[1, 2, 3, 4, 5, 6]); + let quotient_chunks_prods = quotient_chunk_products(&v, 3); + assert_eq!(quotient_chunks_prods, field_vec(&[6, 120])); + let p = partial_products_and_z_gx(F::ONE, "ient_chunks_prods); + assert_eq!(p, field_vec(&[6, 720])); let nums = num_partial_products(v.len(), 3); assert_eq!(p.len(), nums.0); assert!(check_partial_products(&v, &denominators, &p, F::ONE, 3) @@ -138,4 +137,8 @@ mod tests { v.into_iter().product::(), ); } + + fn field_vec(xs: &[usize]) -> Vec { + xs.iter().map(|&x| F::from_canonical_usize(x)).collect() + } } From 21d3b127e3eb288019bc2047af4e30ce15631e62 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 12 Nov 2021 09:15:37 +0100 Subject: [PATCH 086/202] Cargo fmt --- src/plonk/prover.rs | 9 ++++++--- src/plonk/vanishing_poly.rs | 13 ++++++++++--- src/util/partial_products.rs | 6 ++++-- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 1c247998..1dd17cb8 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -1,3 +1,5 @@ +use std::mem::swap; + use anyhow::Result; use rayon::prelude::*; @@ -20,7 +22,6 @@ use crate::timed; use crate::util::partial_products::{partial_products_and_z_gx, quotient_chunk_products}; use crate::util::timing::TimingTree; use crate::util::{log2_ceil, transpose}; -use std::mem::swap; pub(crate) fn prove, const D: usize>( prover_data: &ProverOnlyCircuitData, @@ -99,7 +100,8 @@ pub(crate) fn prove, const D: usize>( ); // Z is expected at the front of our batch; see `zs_range` and `partial_products_range`. - let plonk_z_vecs = partial_products_and_zs.iter_mut() + let plonk_z_vecs = partial_products_and_zs + .iter_mut() .map(|partial_products_and_z| partial_products_and_z.pop().unwrap()) .collect(); let zs_partial_products = [plonk_z_vecs, partial_products_and_zs.concat()].concat(); @@ -267,7 +269,8 @@ fn wires_permutation_partial_products_and_zs, const let mut z_x = F::ONE; let mut all_partial_products_and_zs = Vec::new(); for quotient_chunk_products in all_quotient_chunk_products { - let mut partial_products_and_z_gx = partial_products_and_z_gx(z_x, "ient_chunk_products); + let mut partial_products_and_z_gx = + partial_products_and_z_gx(z_x, "ient_chunk_products); // The last term is Z(gx), but we replace it with Z(x), otherwise Z would end up shifted. swap(&mut z_x, &mut partial_products_and_z_gx[num_prods]); all_partial_products_and_zs.push(partial_products_and_z_gx); diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index f91e027b..2be91b40 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -76,7 +76,10 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( vanishing_partial_products_terms.extend(partial_product_checks); let final_nume_product = numerator_values[final_num_prod..].iter().copied().product(); - let final_deno_product = denominator_values[final_num_prod..].iter().copied().product(); + let final_deno_product = denominator_values[final_num_prod..] + .iter() + .copied() + .product(); let last_partial = *current_partial_products.last().unwrap(); let v_shift_term = last_partial * final_nume_product - z_gz * final_deno_product; vanishing_v_shift_terms.push(v_shift_term); @@ -183,7 +186,10 @@ pub(crate) fn eval_vanishing_poly_base_batch, const vanishing_partial_products_terms.extend(partial_product_checks); let final_nume_product = numerator_values[final_num_prod..].iter().copied().product(); - let final_deno_product = denominator_values[final_num_prod..].iter().copied().product(); + let final_deno_product = denominator_values[final_num_prod..] + .iter() + .copied() + .product(); let last_partial = *current_partial_products.last().unwrap(); let v_shift_term = last_partial * final_nume_product - z_gz * final_deno_product; vanishing_v_shift_terms.push(v_shift_term); @@ -379,7 +385,8 @@ pub(crate) fn eval_vanishing_poly_recursively, cons let final_deno_product = builder.mul_many_extension(&denominator_values[final_num_prod..]); let z_gz_denominators = builder.mul_extension(z_gz, final_deno_product); let last_partial = *current_partial_products.last().unwrap(); - let v_shift_term = builder.mul_sub_extension(last_partial, final_nume_product, z_gz_denominators); + let v_shift_term = + builder.mul_sub_extension(last_partial, final_nume_product, z_gz_denominators); vanishing_v_shift_terms.push(v_shift_term); } diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index 1e7a4f4b..c4133b4d 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -1,8 +1,9 @@ +use itertools::Itertools; + use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::plonk::circuit_builder::CircuitBuilder; -use itertools::Itertools; pub(crate) fn quotient_chunk_products( quotient_values: &[F], @@ -11,7 +12,8 @@ pub(crate) fn quotient_chunk_products( debug_assert!(max_degree > 1); assert!(quotient_values.len() > 0); let chunk_size = max_degree; - quotient_values.chunks(chunk_size) + quotient_values + .chunks(chunk_size) .map(|chunk| chunk.iter().copied().product()) .collect() } From 72ef58c19d4b2c537fa9ff23008b5e16b1ea62b9 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 12 Nov 2021 18:24:08 +0100 Subject: [PATCH 087/202] Add ReducingExtGate --- src/gates/mod.rs | 1 + src/gates/reducing_extension.rs | 222 ++++++++++++++++++++++++++++++++ src/util/reducing.rs | 55 ++++++-- 3 files changed, 269 insertions(+), 9 deletions(-) create mode 100644 src/gates/reducing_extension.rs diff --git a/src/gates/mod.rs b/src/gates/mod.rs index b1a6028e..da301c62 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -19,6 +19,7 @@ pub(crate) mod poseidon_mds; pub(crate) mod public_input; pub mod random_access; pub mod reducing; +pub mod reducing_extension; pub mod subtraction_u32; pub mod switch; diff --git a/src/gates/reducing_extension.rs b/src/gates/reducing_extension.rs new file mode 100644 index 00000000..93c981a6 --- /dev/null +++ b/src/gates/reducing_extension.rs @@ -0,0 +1,222 @@ +use std::ops::Range; + +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::extension_field::FieldExtension; +use crate::field::field_types::RichField; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::{PartialWitness, PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// Computes `sum alpha^i c_i` for a vector `c_i` of `num_coeffs` elements of the extension field. +#[derive(Debug, Clone)] +pub struct ReducingExtGate { + pub num_coeffs: usize, +} + +impl ReducingExtGate { + pub fn new(num_coeffs: usize) -> Self { + Self { num_coeffs } + } + + pub fn max_coeffs_len(num_wires: usize, num_routed_wires: usize) -> usize { + ((num_routed_wires - 3 * D) / D).min((num_wires - 2 * D) / (D * 2)) + } + + pub fn wires_output() -> Range { + 0..D + } + pub fn wires_alpha() -> Range { + D..2 * D + } + pub fn wires_old_acc() -> Range { + 2 * D..3 * D + } + const START_COEFFS: usize = 3 * D; + pub fn wires_coeff(i: usize) -> Range { + Self::START_COEFFS + i * D..Self::START_COEFFS + (i + 1) * D + } + fn start_accs(&self) -> usize { + Self::START_COEFFS + self.num_coeffs * D + } + fn wires_accs(&self, i: usize) -> Range { + if i == self.num_coeffs - 1 { + // The last accumulator is the output. + return Self::wires_output(); + } + self.start_accs() + D * i..self.start_accs() + D * (i + 1) + } +} + +impl, const D: usize> Gate for ReducingExtGate { + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let alpha = vars.get_local_ext_algebra(Self::wires_alpha()); + let old_acc = vars.get_local_ext_algebra(Self::wires_old_acc()); + let coeffs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext_algebra(Self::wires_coeff(i))) + .collect::>(); + let accs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext_algebra(self.wires_accs(i))) + .collect::>(); + + let mut constraints = Vec::with_capacity(>::num_constraints(self)); + let mut acc = old_acc; + for i in 0..self.num_coeffs { + constraints.push(acc * alpha + coeffs[i] - accs[i]); + acc = accs[i]; + } + + constraints + .into_iter() + .flat_map(|alg| alg.to_basefield_array()) + .collect() + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let alpha = vars.get_local_ext(Self::wires_alpha()); + let old_acc = vars.get_local_ext(Self::wires_old_acc()); + let coeffs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext(Self::wires_coeff(i))) + .collect::>(); + let accs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext(self.wires_accs(i))) + .collect::>(); + + let mut constraints = Vec::with_capacity(>::num_constraints(self)); + let mut acc = old_acc; + for i in 0..self.num_coeffs { + constraints.extend((acc * alpha + coeffs[i] - accs[i]).to_basefield_array()); + acc = accs[i]; + } + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let alpha = vars.get_local_ext_algebra(Self::wires_alpha()); + let old_acc = vars.get_local_ext_algebra(Self::wires_old_acc()); + let coeffs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext_algebra(Self::wires_coeff(i))) + .collect::>(); + let accs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext_algebra(self.wires_accs(i))) + .collect::>(); + + let mut constraints = Vec::with_capacity(>::num_constraints(self)); + let mut acc = old_acc; + for i in 0..self.num_coeffs { + let coeff = coeffs[i]; + let mut tmp = builder.mul_add_ext_algebra(acc, alpha, coeff); + tmp = builder.sub_ext_algebra(tmp, accs[i]); + constraints.push(tmp); + acc = accs[i]; + } + + constraints + .into_iter() + .flat_map(|alg| alg.to_ext_target_array()) + .collect() + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + vec![Box::new( + ReducingGenerator { + gate_index, + gate: self.clone(), + } + .adapter(), + )] + } + + fn num_wires(&self) -> usize { + 2 * D + 2 * D * self.num_coeffs + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 2 + } + + fn num_constraints(&self) -> usize { + D * self.num_coeffs + } +} + +#[derive(Debug)] +struct ReducingGenerator { + gate_index: usize, + gate: ReducingExtGate, +} + +impl, const D: usize> SimpleGenerator for ReducingGenerator { + fn dependencies(&self) -> Vec { + ReducingExtGate::::wires_alpha() + .chain(ReducingExtGate::::wires_old_acc()) + .chain((0..self.gate.num_coeffs).flat_map(|i| ReducingExtGate::::wires_coeff(i))) + .map(|i| Target::wire(self.gate_index, i)) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let extract_extension = |range: Range| -> F::Extension { + let t = ExtensionTarget::from_range(self.gate_index, range); + witness.get_extension_target(t) + }; + + let alpha = extract_extension(ReducingExtGate::::wires_alpha()); + let old_acc = extract_extension(ReducingExtGate::::wires_old_acc()); + let coeffs = (0..self.gate.num_coeffs) + .map(|i| extract_extension(ReducingExtGate::::wires_coeff(i))) + .collect::>(); + let accs = (0..self.gate.num_coeffs) + .map(|i| ExtensionTarget::from_range(self.gate_index, self.gate.wires_accs(i))) + .collect::>(); + let output = + ExtensionTarget::from_range(self.gate_index, ReducingExtGate::::wires_output()); + + let mut acc = old_acc; + for i in 0..self.gate.num_coeffs { + let computed_acc = acc * alpha + coeffs[i]; + out_buffer.set_extension_target(accs[i], computed_acc); + acc = computed_acc; + } + out_buffer.set_extension_target(output, acc); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::reducing_extension::ReducingExtGate; + + #[test] + fn low_degree() { + test_low_degree::(ReducingExtGate::new(22)); + } + + #[test] + fn eval_fns() -> Result<()> { + test_eval_fns::(ReducingExtGate::new(22)) + } +} diff --git a/src/util/reducing.rs b/src/util/reducing.rs index 3e00602c..ab3e2771 100644 --- a/src/util/reducing.rs +++ b/src/util/reducing.rs @@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::gates::reducing::ReducingGate; +use crate::gates::reducing_extension::ReducingExtGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; use crate::polynomial::polynomial::PolynomialCoeffs; @@ -145,16 +146,52 @@ impl ReducingFactorTarget { where F: RichField + Extendable, { - let l = terms.len(); - self.count += l as u64; - - let mut terms_vec = terms.to_vec(); - let mut acc = builder.zero_extension(); - terms_vec.reverse(); - - for x in terms_vec { - acc = builder.mul_add_extension(self.base, acc, x); + // let l = terms.len(); + // self.count += l as u64; + // + // let mut terms_vec = terms.to_vec(); + // let mut acc = builder.zero_extension(); + // terms_vec.reverse(); + // + // for x in terms_vec { + // acc = builder.mul_add_extension(self.base, acc, x); + // } + // acc + let max_coeffs_len = ReducingExtGate::::max_coeffs_len( + builder.config.num_wires, + builder.config.num_routed_wires, + ); + self.count += terms.len() as u64; + let zero = builder.zero(); + let zero_ext = builder.zero_extension(); + let mut acc = zero_ext; + let mut reversed_terms = terms.to_vec(); + while reversed_terms.len() % max_coeffs_len != 0 { + reversed_terms.push(zero_ext); } + reversed_terms.reverse(); + for chunk in reversed_terms.chunks_exact(max_coeffs_len) { + let gate = ReducingExtGate::new(max_coeffs_len); + let gate_index = builder.add_gate(gate.clone(), Vec::new()); + + builder.connect_extension( + self.base, + ExtensionTarget::from_range(gate_index, ReducingExtGate::::wires_alpha()), + ); + builder.connect_extension( + acc, + ExtensionTarget::from_range(gate_index, ReducingExtGate::::wires_old_acc()), + ); + for (i, &t) in chunk.iter().enumerate() { + builder.connect_extension( + t, + ExtensionTarget::from_range(gate_index, ReducingExtGate::::wires_coeff(i)), + ); + } + + acc = ExtensionTarget::from_range(gate_index, ReducingExtGate::::wires_output()); + } + acc } From 857b74bac538e5bb5d28d69a5c3f396c1a2fcf05 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Fri, 12 Nov 2021 09:48:27 -0800 Subject: [PATCH 088/202] Bring back the base field arithmetic gate (#343) * Bring back the base field arithmetic gate * fix --- src/bin/bench_recursion.rs | 1 + src/gadgets/arithmetic.rs | 136 +++++++++-- src/gadgets/arithmetic_extension.rs | 10 +- src/gates/arithmetic_base.rs | 212 ++++++++++++++++++ ...{arithmetic.rs => arithmetic_extension.rs} | 5 +- src/gates/gate_tree.rs | 2 +- src/gates/mod.rs | 3 +- src/iop/generator.rs | 2 +- src/plonk/circuit_builder.rs | 97 +++++--- src/plonk/circuit_data.rs | 4 + 10 files changed, 420 insertions(+), 52 deletions(-) create mode 100644 src/gates/arithmetic_base.rs rename src/gates/{arithmetic.rs => arithmetic_extension.rs} (96%) diff --git a/src/bin/bench_recursion.rs b/src/bin/bench_recursion.rs index e9fc25a4..fc1ad37e 100644 --- a/src/bin/bench_recursion.rs +++ b/src/bin/bench_recursion.rs @@ -26,6 +26,7 @@ fn bench_prove, const D: usize>() -> Result<()> { num_wires: 126, num_routed_wires: 33, constant_gate_size: 6, + use_base_arithmetic_gate: false, security_bits: 128, rate_bits: 3, num_challenges: 3, diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 103857c5..3fe90019 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -1,8 +1,8 @@ use std::borrow::Borrow; use crate::field::extension_field::Extendable; -use crate::field::field_types::RichField; -use crate::gates::arithmetic::ArithmeticExtensionGate; +use crate::field::field_types::{PrimeField, RichField}; +use crate::gates::arithmetic_base::ArithmeticGate; use crate::gates::exponentiation::ExponentiationGate; use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; @@ -33,18 +33,117 @@ impl, const D: usize> CircuitBuilder { multiplicand_1: Target, addend: Target, ) -> Target { - let multiplicand_0_ext = self.convert_to_ext(multiplicand_0); - let multiplicand_1_ext = self.convert_to_ext(multiplicand_1); - let addend_ext = self.convert_to_ext(addend); + // If we're not configured to use the base arithmetic gate, just call arithmetic_extension. + if !self.config.use_base_arithmetic_gate { + let multiplicand_0_ext = self.convert_to_ext(multiplicand_0); + let multiplicand_1_ext = self.convert_to_ext(multiplicand_1); + let addend_ext = self.convert_to_ext(addend); - self.arithmetic_extension( + return self + .arithmetic_extension( + const_0, + const_1, + multiplicand_0_ext, + multiplicand_1_ext, + addend_ext, + ) + .0[0]; + } + + // See if we can determine the result without adding an `ArithmeticGate`. + if let Some(result) = + self.arithmetic_special_cases(const_0, const_1, multiplicand_0, multiplicand_1, addend) + { + return result; + } + + // See if we've already computed the same operation. + let operation = BaseArithmeticOperation { const_0, const_1, - multiplicand_0_ext, - multiplicand_1_ext, - addend_ext, - ) - .0[0] + multiplicand_0, + multiplicand_1, + addend, + }; + if let Some(&result) = self.base_arithmetic_results.get(&operation) { + return result; + } + + // Otherwise, we must actually perform the operation using an ArithmeticExtensionGate slot. + let result = self.add_base_arithmetic_operation(operation); + self.base_arithmetic_results.insert(operation, result); + result + } + + fn add_base_arithmetic_operation(&mut self, operation: BaseArithmeticOperation) -> Target { + let (gate, i) = self.find_base_arithmetic_gate(operation.const_0, operation.const_1); + let wires_multiplicand_0 = Target::wire(gate, ArithmeticGate::wire_ith_multiplicand_0(i)); + let wires_multiplicand_1 = Target::wire(gate, ArithmeticGate::wire_ith_multiplicand_1(i)); + let wires_addend = Target::wire(gate, ArithmeticGate::wire_ith_addend(i)); + + self.connect(operation.multiplicand_0, wires_multiplicand_0); + self.connect(operation.multiplicand_1, wires_multiplicand_1); + self.connect(operation.addend, wires_addend); + + Target::wire(gate, ArithmeticGate::wire_ith_output(i)) + } + + /// Checks for special cases where the value of + /// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend` + /// can be determined without adding an `ArithmeticGate`. + fn arithmetic_special_cases( + &mut self, + const_0: F, + const_1: F, + multiplicand_0: Target, + multiplicand_1: Target, + addend: Target, + ) -> Option { + let zero = self.zero(); + + let mul_0_const = self.target_as_constant(multiplicand_0); + let mul_1_const = self.target_as_constant(multiplicand_1); + let addend_const = self.target_as_constant(addend); + + let first_term_zero = + const_0 == F::ZERO || multiplicand_0 == zero || multiplicand_1 == zero; + let second_term_zero = const_1 == F::ZERO || addend == zero; + + // If both terms are constant, return their (constant) sum. + let first_term_const = if first_term_zero { + Some(F::ZERO) + } else if let (Some(x), Some(y)) = (mul_0_const, mul_1_const) { + Some(x * y * const_0) + } else { + None + }; + let second_term_const = if second_term_zero { + Some(F::ZERO) + } else { + addend_const.map(|x| x * const_1) + }; + if let (Some(x), Some(y)) = (first_term_const, second_term_const) { + return Some(self.constant(x + y)); + } + + if first_term_zero && const_1.is_one() { + return Some(addend); + } + + if second_term_zero { + if let Some(x) = mul_0_const { + if (x * const_0).is_one() { + return Some(multiplicand_1); + } + } + if let Some(x) = mul_1_const { + if (x * const_0).is_one() { + return Some(multiplicand_0); + } + } + } + + None } /// Computes `x * y + z`. @@ -116,7 +215,7 @@ impl, const D: usize> CircuitBuilder { /// Exponentiate `base` to the power of `2^power_log`. pub fn exp_power_of_2(&mut self, base: Target, power_log: usize) -> Target { - if power_log > ArithmeticExtensionGate::::new_from_config(&self.config).num_ops { + if power_log > ArithmeticGate::new_from_config(&self.config).num_ops { // Cheaper to just use `ExponentiateGate`. return self.exp_u64(base, 1 << power_log); } @@ -170,8 +269,7 @@ impl, const D: usize> CircuitBuilder { let base_t = self.constant(base); let exponent_bits: Vec<_> = exponent_bits.into_iter().map(|b| *b.borrow()).collect(); - if exponent_bits.len() > ArithmeticExtensionGate::::new_from_config(&self.config).num_ops - { + if exponent_bits.len() > ArithmeticGate::new_from_config(&self.config).num_ops { // Cheaper to just use `ExponentiateGate`. return self.exp_from_bits(base_t, exponent_bits); } @@ -221,3 +319,13 @@ impl, const D: usize> CircuitBuilder { self.inverse_extension(x_ext).0[0] } } + +/// Represents a base arithmetic operation in the circuit. Used to memoize results. +#[derive(Copy, Clone, Eq, PartialEq, Hash)] +pub(crate) struct BaseArithmeticOperation { + const_0: F, + const_1: F, + multiplicand_0: Target, + multiplicand_1: Target, + addend: Target, +} diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index e2654dcc..9fbffad3 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -4,7 +4,7 @@ use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTar use crate::field::extension_field::FieldExtension; use crate::field::extension_field::{Extendable, OEF}; use crate::field::field_types::{Field, PrimeField, RichField}; -use crate::gates::arithmetic::ArithmeticExtensionGate; +use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::Target; use crate::iop::witness::{PartitionWitness, Witness}; @@ -32,7 +32,7 @@ impl, const D: usize> CircuitBuilder { } // See if we've already computed the same operation. - let operation = ArithmeticOperation { + let operation = ExtensionArithmeticOperation { const_0, const_1, multiplicand_0, @@ -51,7 +51,7 @@ impl, const D: usize> CircuitBuilder { fn add_arithmetic_extension_operation( &mut self, - operation: ArithmeticOperation, + operation: ExtensionArithmeticOperation, ) -> ExtensionTarget { let (gate, i) = self.find_arithmetic_gate(operation.const_0, operation.const_1); let wires_multiplicand_0 = ExtensionTarget::from_range( @@ -519,9 +519,9 @@ impl, const D: usize> CircuitBuilder { } } -/// Represents an arithmetic operation in the circuit. Used to memoize results. +/// Represents an extension arithmetic operation in the circuit. Used to memoize results. #[derive(Copy, Clone, Eq, PartialEq, Hash)] -pub(crate) struct ArithmeticOperation, const D: usize> { +pub(crate) struct ExtensionArithmeticOperation, const D: usize> { const_0: F, const_1: F, multiplicand_0: ExtensionTarget, diff --git a/src/gates/arithmetic_base.rs b/src/gates/arithmetic_base.rs new file mode 100644 index 00000000..d5c131a5 --- /dev/null +++ b/src/gates/arithmetic_base.rs @@ -0,0 +1,212 @@ +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::field_types::RichField; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::CircuitConfig; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// A gate which can perform a weighted multiply-add, i.e. `result = c0 x y + c1 z`. If the config +/// supports enough routed wires, it can support several such operations in one gate. +#[derive(Debug)] +pub struct ArithmeticGate { + /// Number of arithmetic operations performed by an arithmetic gate. + pub num_ops: usize, +} + +impl ArithmeticGate { + pub fn new_from_config(config: &CircuitConfig) -> Self { + Self { + num_ops: Self::num_ops(config), + } + } + + /// Determine the maximum number of operations that can fit in one gate for the given config. + pub(crate) fn num_ops(config: &CircuitConfig) -> usize { + let wires_per_op = 4; + config.num_routed_wires / wires_per_op + } + + pub fn wire_ith_multiplicand_0(i: usize) -> usize { + 4 * i + } + pub fn wire_ith_multiplicand_1(i: usize) -> usize { + 4 * i + 1 + } + pub fn wire_ith_addend(i: usize) -> usize { + 4 * i + 2 + } + pub fn wire_ith_output(i: usize) -> usize { + 4 * i + 3 + } +} + +impl, const D: usize> Gate for ArithmeticGate { + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let const_0 = vars.local_constants[0]; + let const_1 = vars.local_constants[1]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[Self::wire_ith_addend(i)]; + let output = vars.local_wires[Self::wire_ith_output(i)]; + let computed_output = multiplicand_0 * multiplicand_1 * const_0 + addend * const_1; + + constraints.push(output - computed_output); + } + + constraints + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let const_0 = vars.local_constants[0]; + let const_1 = vars.local_constants[1]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[Self::wire_ith_addend(i)]; + let output = vars.local_wires[Self::wire_ith_output(i)]; + let computed_output = multiplicand_0 * multiplicand_1 * const_0 + addend * const_1; + + constraints.push(output - computed_output); + } + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let const_0 = vars.local_constants[0]; + let const_1 = vars.local_constants[1]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[Self::wire_ith_addend(i)]; + let output = vars.local_wires[Self::wire_ith_output(i)]; + let computed_output = { + let scaled_mul = + builder.mul_many_extension(&[const_0, multiplicand_0, multiplicand_1]); + builder.mul_add_extension(const_1, addend, scaled_mul) + }; + + let diff = builder.sub_extension(output, computed_output); + constraints.push(diff); + } + + constraints + } + + fn generators( + &self, + gate_index: usize, + local_constants: &[F], + ) -> Vec>> { + (0..self.num_ops) + .map(|i| { + let g: Box> = Box::new( + ArithmeticBaseGenerator { + gate_index, + const_0: local_constants[0], + const_1: local_constants[1], + i, + } + .adapter(), + ); + g + }) + .collect::>() + } + + fn num_wires(&self) -> usize { + self.num_ops * 4 + } + + fn num_constants(&self) -> usize { + 2 + } + + fn degree(&self) -> usize { + 3 + } + + fn num_constraints(&self) -> usize { + self.num_ops + } +} + +#[derive(Clone, Debug)] +struct ArithmeticBaseGenerator, const D: usize> { + gate_index: usize, + const_0: F, + const_1: F, + i: usize, +} + +impl, const D: usize> SimpleGenerator + for ArithmeticBaseGenerator +{ + fn dependencies(&self) -> Vec { + [ + ArithmeticGate::wire_ith_multiplicand_0(self.i), + ArithmeticGate::wire_ith_multiplicand_1(self.i), + ArithmeticGate::wire_ith_addend(self.i), + ] + .iter() + .map(|&i| Target::wire(self.gate_index, i)) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let get_wire = + |wire: usize| -> F { witness.get_target(Target::wire(self.gate_index, wire)) }; + + let multiplicand_0 = get_wire(ArithmeticGate::wire_ith_multiplicand_0(self.i)); + let multiplicand_1 = get_wire(ArithmeticGate::wire_ith_multiplicand_1(self.i)); + let addend = get_wire(ArithmeticGate::wire_ith_addend(self.i)); + + let output_target = Target::wire(self.gate_index, ArithmeticGate::wire_ith_output(self.i)); + + let computed_output = + multiplicand_0 * multiplicand_1 * self.const_0 + addend * self.const_1; + + out_buffer.set_target(output_target, computed_output) + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::arithmetic_base::ArithmeticGate; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::plonk::circuit_data::CircuitConfig; + + #[test] + fn low_degree() { + let gate = ArithmeticGate::new_from_config(&CircuitConfig::standard_recursion_config()); + test_low_degree::(gate); + } + + #[test] + fn eval_fns() -> Result<()> { + let gate = ArithmeticGate::new_from_config(&CircuitConfig::standard_recursion_config()); + test_eval_fns::(gate) + } +} diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic_extension.rs similarity index 96% rename from src/gates/arithmetic.rs rename to src/gates/arithmetic_extension.rs index 95b48e2f..dbde7535 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic_extension.rs @@ -12,7 +12,8 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -/// A gate which can a linear combination `c0*x*y+c1*z` twice with the same `x`. +/// A gate which can perform a weighted multiply-add, i.e. `result = c0 x y + c1 z`. If the config +/// supports enough routed wires, it can support several such operations in one gate. #[derive(Debug)] pub struct ArithmeticExtensionGate { /// Number of arithmetic operations performed by an arithmetic gate. @@ -206,7 +207,7 @@ mod tests { use anyhow::Result; use crate::field::goldilocks_field::GoldilocksField; - use crate::gates::arithmetic::ArithmeticExtensionGate; + use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::plonk::circuit_data::CircuitConfig; diff --git a/src/gates/gate_tree.rs b/src/gates/gate_tree.rs index ed9a73ac..aaba41c7 100644 --- a/src/gates/gate_tree.rs +++ b/src/gates/gate_tree.rs @@ -223,7 +223,7 @@ impl, const D: usize> Tree> { mod tests { use super::*; use crate::field::goldilocks_field::GoldilocksField; - use crate::gates::arithmetic::ArithmeticExtensionGate; + use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::gates::base_sum::BaseSumGate; use crate::gates::constant::ConstantGate; use crate::gates::gmimc::GMiMCGate; diff --git a/src/gates/mod.rs b/src/gates/mod.rs index b1a6028e..7f7ee32b 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -1,7 +1,8 @@ // Gates have `new` methods that return `GateRef`s. #![allow(clippy::new_ret_no_self)] -pub mod arithmetic; +pub mod arithmetic_base; +pub mod arithmetic_extension; pub mod arithmetic_u32; pub mod assert_le; pub mod base_sum; diff --git a/src/iop/generator.rs b/src/iop/generator.rs index 8c6cb294..ae973d7c 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -87,7 +87,7 @@ pub(crate) fn generate_partial_witness<'a, F: RichField + Extendable, const D assert_eq!( remaining_generators, 0, "{} generators weren't run", - remaining_generators + remaining_generators, ); witness diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index c6c49826..aac9d42e 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -12,9 +12,11 @@ use crate::field::fft::fft_root_table; use crate::field::field_types::{Field, RichField}; use crate::fri::commitment::PolynomialBatchCommitment; use crate::fri::{FriConfig, FriParams}; -use crate::gadgets::arithmetic_extension::ArithmeticOperation; +use crate::gadgets::arithmetic::BaseArithmeticOperation; +use crate::gadgets::arithmetic_extension::ExtensionArithmeticOperation; use crate::gadgets::arithmetic_u32::U32Target; -use crate::gates::arithmetic::ArithmeticExtensionGate; +use crate::gates::arithmetic_base::ArithmeticGate; +use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; use crate::gates::constant::ConstantGate; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; @@ -74,8 +76,11 @@ pub struct CircuitBuilder, const D: usize> { constants_to_targets: HashMap, targets_to_constants: HashMap, + /// Memoized results of `arithmetic` calls. + pub(crate) base_arithmetic_results: HashMap, Target>, + /// Memoized results of `arithmetic_extension` calls. - pub(crate) arithmetic_results: HashMap, ExtensionTarget>, + pub(crate) arithmetic_results: HashMap, ExtensionTarget>, batched_gates: BatchedGates, } @@ -93,6 +98,7 @@ impl, const D: usize> CircuitBuilder { marked_targets: Vec::new(), generators: Vec::new(), constants_to_targets: HashMap::new(), + base_arithmetic_results: HashMap::new(), arithmetic_results: HashMap::new(), targets_to_constants: HashMap::new(), batched_gates: BatchedGates::new(), @@ -742,11 +748,13 @@ impl, const D: usize> CircuitBuilder { } } -/// Various gate types can contain multiple copies in a single Gate. This helper struct lets a CircuitBuilder track such gates that are currently being "filled up." +/// Various gate types can contain multiple copies in a single Gate. This helper struct lets a +/// CircuitBuilder track such gates that are currently being "filled up." pub struct BatchedGates, const D: usize> { /// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using /// these constants with gate index `g` and already using `i` arithmetic operations. pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>, + pub(crate) free_base_arithmetic: HashMap<(F, F), (usize, usize)>, /// A map `(c0, c1) -> (g, i)` from constants `vec_size` to an available arithmetic gate using /// these constants with gate index `g` and already using `i` random accesses. @@ -771,6 +779,7 @@ impl, const D: usize> BatchedGates { pub fn new() -> Self { Self { free_arithmetic: HashMap::new(), + free_base_arithmetic: HashMap::new(), free_random_access: HashMap::new(), current_switch_gates: Vec::new(), current_u32_arithmetic_gate: None, @@ -781,6 +790,37 @@ impl, const D: usize> BatchedGates { } impl, const D: usize> CircuitBuilder { + /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. + /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index + /// `g` and the gate's `i`-th operation is available. + pub(crate) fn find_base_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_base_arithmetic + .get(&(const_0, const_1)) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + ArithmeticGate::new_from_config(&self.config), + vec![const_0, const_1], + ); + (gate, 0) + }); + + // Update `free_arithmetic` with new values. + if i < ArithmeticGate::num_ops(&self.config) - 1 { + self.batched_gates + .free_base_arithmetic + .insert((const_0, const_1), (gate, i + 1)); + } else { + self.batched_gates + .free_base_arithmetic + .remove(&(const_0, const_1)); + } + + (gate, i) + } + /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index /// `g` and the gate's `i`-th operation is available. @@ -941,36 +981,36 @@ impl, const D: usize> CircuitBuilder { (gate, instance) } + /// Fill the remaining unused arithmetic operations with zeros, so that all + /// `ArithmeticGate` are run. + fn fill_base_arithmetic_gates(&mut self) { + let zero = self.zero(); + for ((c0, c1), (_gate, i)) in self.batched_gates.free_base_arithmetic.clone() { + for _ in i..ArithmeticGate::num_ops(&self.config) { + // If we directly wire in zero, an optimization will skip doing anything and return + // zero. So we pass in a virtual target and connect it to zero afterward. + let dummy = self.add_virtual_target(); + self.arithmetic(c0, c1, dummy, dummy, dummy); + self.connect(dummy, zero); + } + } + assert!(self.batched_gates.free_base_arithmetic.is_empty()); + } + /// Fill the remaining unused arithmetic operations with zeros, so that all /// `ArithmeticExtensionGenerator`s are run. fn fill_arithmetic_gates(&mut self) { let zero = self.zero_extension(); - let remaining_arithmetic_gates = self - .batched_gates - .free_arithmetic - .values() - .copied() - .collect::>(); - for (gate, i) in remaining_arithmetic_gates { - for j in i..ArithmeticExtensionGate::::num_ops(&self.config) { - let wires_multiplicand_0 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_ith_multiplicand_0(j), - ); - let wires_multiplicand_1 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_ith_multiplicand_1(j), - ); - let wires_addend = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_ith_addend(j), - ); - - self.connect_extension(zero, wires_multiplicand_0); - self.connect_extension(zero, wires_multiplicand_1); - self.connect_extension(zero, wires_addend); + for ((c0, c1), (_gate, i)) in self.batched_gates.free_arithmetic.clone() { + for _ in i..ArithmeticExtensionGate::::num_ops(&self.config) { + // If we directly wire in zero, an optimization will skip doing anything and return + // zero. So we pass in a virtual target and connect it to zero afterward. + let dummy = self.add_virtual_extension_target(); + self.arithmetic_extension(c0, c1, dummy, dummy, dummy); + self.connect_extension(dummy, zero); } } + assert!(self.batched_gates.free_arithmetic.is_empty()); } /// Fill the remaining unused random access operations with zeros, so that all @@ -1064,6 +1104,7 @@ impl, const D: usize> CircuitBuilder { fn fill_batched_gates(&mut self) { self.fill_arithmetic_gates(); + self.fill_base_arithmetic_gates(); self.fill_random_access_gates(); self.fill_switch_gates(); self.fill_u32_arithmetic_gates(); diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 869543af..d54d327d 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -26,6 +26,9 @@ pub struct CircuitConfig { pub num_wires: usize, pub num_routed_wires: usize, pub constant_gate_size: usize, + /// Whether to use a dedicated gate for base field arithmetic, rather than using a single gate + /// for both base field and extension field arithmetic. + pub use_base_arithmetic_gate: bool, pub security_bits: usize, pub rate_bits: usize, /// The number of challenge points to generate, for IOPs that have soundness errors of (roughly) @@ -55,6 +58,7 @@ impl CircuitConfig { num_wires: 143, num_routed_wires: 25, constant_gate_size: 6, + use_base_arithmetic_gate: true, security_bits: 100, rate_bits: 3, num_challenges: 2, From 4a5123de81184ec7a8fdd52719a3627e05a93349 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 12 Nov 2021 12:12:58 -0800 Subject: [PATCH 089/202] reduced test sizes --- src/gadgets/multiple_comparison.rs | 4 ++-- src/gates/assert_le.rs | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs index 77e660e6..3a5f2421 100644 --- a/src/gadgets/multiple_comparison.rs +++ b/src/gadgets/multiple_comparison.rs @@ -127,8 +127,8 @@ mod tests { #[test] fn test_multiple_comparison() -> Result<()> { - for size in [1, 3, 6, 10] { - for num_bits in [20, 32, 40, 50] { + for size in [1, 3, 6] { + for num_bits in [20, 32, 40, 44] { test_list_le(size, num_bits).unwrap(); } } diff --git a/src/gates/assert_le.rs b/src/gates/assert_le.rs index 4d33a867..ffbc043a 100644 --- a/src/gates/assert_le.rs +++ b/src/gates/assert_le.rs @@ -470,8 +470,8 @@ mod tests { #[test] fn low_degree() { - let num_bits = 40; - let num_chunks = 5; + let num_bits = 20; + let num_chunks = 4; test_low_degree::(AssertLessThanGate::<_, 4>::new( num_bits, num_chunks, @@ -480,8 +480,8 @@ mod tests { #[test] fn eval_fns() -> Result<()> { - let num_bits = 40; - let num_chunks = 5; + let num_bits = 20; + let num_chunks = 4; test_eval_fns::(AssertLessThanGate::<_, 4>::new( num_bits, num_chunks, From 26a222bbdf63f6d481ee7830d14b01c4fe9d9e1e Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Sun, 14 Nov 2021 11:57:36 -0800 Subject: [PATCH 090/202] Fewer wires in `PoseidonGate` (#356) Closes #345. --- src/gates/poseidon.rs | 213 ++++++++++++++++++++++++-------------- src/plonk/circuit_data.rs | 2 +- 2 files changed, 135 insertions(+), 80 deletions(-) diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 6e1eb69a..59c23b44 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -56,35 +56,49 @@ where /// is useful for ordering hashes in Merkle proofs. Otherwise, this should be set to 0. pub const WIRE_SWAP: usize = 2 * WIDTH; + const START_DELTA: usize = 2 * WIDTH + 1; + + /// A wire which stores `swap * (input[i + 4] - input[i])`; used to compute the swapped inputs. + fn wire_delta(i: usize) -> usize { + assert!(i < 4); + Self::START_DELTA + i + } + + const START_FULL_0: usize = Self::START_DELTA + 4; + /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the first set /// of full rounds. fn wire_full_sbox_0(round: usize, i: usize) -> usize { + debug_assert!( + round != 0, + "First round S-box inputs are not stored as wires" + ); debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); debug_assert!(i < WIDTH); - 2 * WIDTH + 1 + WIDTH * round + i + Self::START_FULL_0 + WIDTH * (round - 1) + i } + const START_PARTIAL: usize = Self::START_FULL_0 + WIDTH * (poseidon::HALF_N_FULL_ROUNDS - 1); + /// A wire which stores the input of the S-box of the `round`-th round of the partial rounds. fn wire_partial_sbox(round: usize) -> usize { debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); - 2 * WIDTH + 1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS + round + Self::START_PARTIAL + round } + const START_FULL_1: usize = Self::START_PARTIAL + poseidon::N_PARTIAL_ROUNDS; + /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the second set /// of full rounds. fn wire_full_sbox_1(round: usize, i: usize) -> usize { debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); debug_assert!(i < WIDTH); - 2 * WIDTH - + 1 - + WIDTH * (poseidon::HALF_N_FULL_ROUNDS + round) - + poseidon::N_PARTIAL_ROUNDS - + i + Self::START_FULL_1 + WIDTH * round + i } /// End of wire indices, exclusive. fn end() -> usize { - 2 * WIDTH + 1 + WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + Self::START_FULL_1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS } } @@ -104,31 +118,38 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(swap * (swap - F::Extension::ONE)); - let mut state = Vec::with_capacity(WIDTH); + // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. for i in 0..4 { - let a = vars.local_wires[i]; - let b = vars.local_wires[i + 4]; - state.push(a + swap * (b - a)); - } - for i in 0..4 { - let a = vars.local_wires[i + 4]; - let b = vars.local_wires[i]; - state.push(a + swap * (b - a)); - } - for i in 8..WIDTH { - state.push(vars.local_wires[i]); + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + let delta_i = vars.local_wires[Self::wire_delta(i)]; + constraints.push(swap * (input_rhs - input_lhs) - delta_i); + } + + // Compute the possibly-swapped input layer. + let mut state = [F::Extension::ZERO; WIDTH]; + for i in 0..4 { + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let input_lhs = Self::wire_input(i); + let input_rhs = Self::wire_input(i + 4); + state[i] = vars.local_wires[input_lhs] + delta_i; + state[i + 4] = vars.local_wires[input_rhs] - delta_i; + } + for i in 8..WIDTH { + state[i] = vars.local_wires[Self::wire_input(i)]; } - let mut state: [F::Extension; WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; // First set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { >::constant_layer_field(&mut state, round_ctr); - for i in 0..WIDTH { - let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; - constraints.push(state[i] - sbox_in); - state[i] = sbox_in; + if r != 0 { + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } } >::sbox_layer_field(&mut state); state = >::mds_layer_field(&state); @@ -183,31 +204,38 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(swap * swap.sub_one()); - let mut state = Vec::with_capacity(WIDTH); + // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. for i in 0..4 { - let a = vars.local_wires[i]; - let b = vars.local_wires[i + 4]; - state.push(a + swap * (b - a)); - } - for i in 0..4 { - let a = vars.local_wires[i + 4]; - let b = vars.local_wires[i]; - state.push(a + swap * (b - a)); - } - for i in 8..WIDTH { - state.push(vars.local_wires[i]); + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + let delta_i = vars.local_wires[Self::wire_delta(i)]; + constraints.push(swap * (input_rhs - input_lhs) - delta_i); + } + + // Compute the possibly-swapped input layer. + let mut state = [F::ZERO; WIDTH]; + for i in 0..4 { + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let input_lhs = Self::wire_input(i); + let input_rhs = Self::wire_input(i + 4); + state[i] = vars.local_wires[input_lhs] + delta_i; + state[i + 4] = vars.local_wires[input_rhs] - delta_i; + } + for i in 8..WIDTH { + state[i] = vars.local_wires[Self::wire_input(i)]; } - let mut state: [F; WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; // First set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { >::constant_layer(&mut state, round_ctr); - for i in 0..WIDTH { - let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; - constraints.push(state[i] - sbox_in); - state[i] = sbox_in; + if r != 0 { + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } } >::sbox_layer(&mut state); state = >::mds_layer(&state); @@ -267,38 +295,39 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(builder.mul_sub_extension(swap, swap, swap)); - let mut state = Vec::with_capacity(WIDTH); - // We need to compute both `if swap {b} else {a}` and `if swap {a} else {b}`. - // We will arithmetize them as - // swap (b - a) + a - // -swap (b - a) + b - // so that `b - a` can be used for both. - let mut state_first_4 = vec![]; - let mut state_next_4 = vec![]; + // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. for i in 0..4 { - let a = vars.local_wires[i]; - let b = vars.local_wires[i + 4]; - let delta = builder.sub_extension(b, a); - state_first_4.push(builder.mul_add_extension(swap, delta, a)); - state_next_4.push(builder.arithmetic_extension(F::NEG_ONE, F::ONE, swap, delta, b)); + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let diff = builder.sub_extension(input_rhs, input_lhs); + constraints.push(builder.mul_sub_extension(swap, diff, delta_i)); } - state.extend(state_first_4); - state.extend(state_next_4); + // Compute the possibly-swapped input layer. + let mut state = [builder.zero_extension(); WIDTH]; + for i in 0..4 { + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + state[i] = builder.add_extension(input_lhs, delta_i); + state[i + 4] = builder.sub_extension(input_rhs, delta_i); + } for i in 8..WIDTH { - state.push(vars.local_wires[i]); + state[i] = vars.local_wires[Self::wire_input(i)]; } - let mut state: [ExtensionTarget; WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; // First set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { >::constant_layer_recursive(builder, &mut state, round_ctr); - for i in 0..WIDTH { - let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; - constraints.push(builder.sub_extension(state[i], sbox_in)); - state[i] = sbox_in; + if r != 0 { + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(builder.sub_extension(state[i], sbox_in)); + state[i] = sbox_in; + } } >::sbox_layer_recursive(builder, &mut state); state = >::mds_layer_recursive(builder, &state); @@ -386,7 +415,7 @@ where } fn num_constraints(&self) -> usize { - WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + WIDTH + 1 + WIDTH * (poseidon::N_FULL_ROUNDS_TOTAL - 1) + poseidon::N_PARTIAL_ROUNDS + WIDTH + 1 + 4 } } @@ -422,19 +451,20 @@ where }; let mut state = (0..WIDTH) - .map(|i| { - witness.get_wire(Wire { - gate: self.gate_index, - input: PoseidonGate::::wire_input(i), - }) - }) + .map(|i| witness.get_wire(local_wire(PoseidonGate::::wire_input(i)))) .collect::>(); - let swap_value = witness.get_wire(Wire { - gate: self.gate_index, - input: PoseidonGate::::WIRE_SWAP, - }); + let swap_value = witness.get_wire(local_wire(PoseidonGate::::WIRE_SWAP)); debug_assert!(swap_value == F::ZERO || swap_value == F::ONE); + + for i in 0..4 { + let delta_i = swap_value * (state[i + 4] - state[i]); + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_delta(i)), + delta_i, + ); + } + if swap_value == F::ONE { for i in 0..4 { state.swap(i, 4 + i); @@ -446,11 +476,13 @@ where for r in 0..poseidon::HALF_N_FULL_ROUNDS { >::constant_layer_field(&mut state, round_ctr); - for i in 0..WIDTH { - out_buffer.set_wire( - local_wire(PoseidonGate::::wire_full_sbox_0(r, i)), - state[i], - ); + if r != 0 { + for i in 0..WIDTH { + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_full_sbox_0(r, i)), + state[i], + ); + } } >::sbox_layer_field(&mut state); state = >::mds_layer_field(&state); @@ -522,6 +554,29 @@ mod tests { use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; + #[test] + fn wire_indices() { + type F = GoldilocksField; + const WIDTH: usize = 12; + type Gate = PoseidonGate; + + assert_eq!(Gate::wire_input(0), 0); + assert_eq!(Gate::wire_input(11), 11); + assert_eq!(Gate::wire_output(0), 12); + assert_eq!(Gate::wire_output(11), 23); + assert_eq!(Gate::WIRE_SWAP, 24); + assert_eq!(Gate::wire_delta(0), 25); + assert_eq!(Gate::wire_delta(3), 28); + assert_eq!(Gate::wire_full_sbox_0(1, 0), 29); + assert_eq!(Gate::wire_full_sbox_0(3, 0), 53); + assert_eq!(Gate::wire_full_sbox_0(3, 11), 64); + assert_eq!(Gate::wire_partial_sbox(0), 65); + assert_eq!(Gate::wire_partial_sbox(21), 86); + assert_eq!(Gate::wire_full_sbox_1(0, 0), 87); + assert_eq!(Gate::wire_full_sbox_1(3, 0), 123); + assert_eq!(Gate::wire_full_sbox_1(3, 11), 134); + } + #[test] fn generated_output() { type F = GoldilocksField; diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index d54d327d..564d558d 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -55,7 +55,7 @@ impl CircuitConfig { /// A typical recursion config, without zero-knowledge, targeting ~100 bit security. pub(crate) fn standard_recursion_config() -> Self { Self { - num_wires: 143, + num_wires: 135, num_routed_wires: 25, constant_gate_size: 6, use_base_arithmetic_gate: true, From fe1e67165a574b836731a1cefc7e9d17be6f5d55 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Sun, 14 Nov 2021 11:58:14 -0800 Subject: [PATCH 091/202] 256 bit salts (#352) I believe I was mistaken earlier, and hash-based commitments actually call for `r = 2*security_bits` bits of randomness. I.e. I believe breaking a particular commitment requires `O(2^r)` work (more if the committed value adds entropy, but assume it doesn't), but breaking one of `n` commitments requires less work. It seems like this should be a well-known thing, but I can't find much in the literature. The IOP paper does mention using `2*security_bits` of randomness though. --- src/fri/commitment.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fri/commitment.rs b/src/fri/commitment.rs index 8233a293..c8a13cac 100644 --- a/src/fri/commitment.rs +++ b/src/fri/commitment.rs @@ -16,8 +16,8 @@ use crate::util::reducing::ReducingFactor; use crate::util::timing::TimingTree; use crate::util::{log2_strict, reverse_bits, reverse_index_bits_in_place, transpose}; -/// Two (~64 bit) field elements gives ~128 bit security. -pub const SALT_SIZE: usize = 2; +/// Four (~64 bit) field elements gives ~128 bit security. +pub const SALT_SIZE: usize = 4; /// Represents a batch FRI based commitment to a list of polynomials. pub struct PolynomialBatchCommitment { From 7185c2d7d2fd6a7bc3a1eb0ade8f84421435379e Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Sun, 14 Nov 2021 11:58:44 -0800 Subject: [PATCH 092/202] Fix & cleanup partial products (#355) My previous change introduced a bug -- when `num_routed_wires` was a multiple of 8, the partial products "consumed" all `num_routed_wires` terms, whereas we actually want to leave 8 terms for the final product. This also changes `check_partial_products` to include the final product constraint, and merges `vanishing_v_shift_terms` into `vanishing_partial_products_terms`. I think this is natural since `Z(x)`, partial products, and `Z(g x)` are all part of the product accumulator chain. --- src/plonk/vanishing_poly.rs | 45 ++------------ src/util/partial_products.rs | 114 ++++++++++++++++++----------------- 2 files changed, 64 insertions(+), 95 deletions(-) diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 2be91b40..ef322c9f 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -28,7 +28,7 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( alphas: &[F], ) -> Vec { let max_degree = common_data.quotient_degree_factor; - let (num_prods, final_num_prod) = common_data.num_partial_products; + let (num_prods, _final_num_prod) = common_data.num_partial_products; let constraint_terms = evaluate_gate_constraints(&common_data.gates, common_data.num_gate_constraints, vars); @@ -37,8 +37,6 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( let mut vanishing_z_1_terms = Vec::new(); // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); - // The Z(x) f'(x) - g'(x) Z(g x) terms. - let mut vanishing_v_shift_terms = Vec::new(); let l1_x = plonk_common::eval_l_1(common_data.degree(), x); @@ -71,24 +69,15 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( &denominator_values, current_partial_products, z_x, + z_gz, max_degree, ); vanishing_partial_products_terms.extend(partial_product_checks); - - let final_nume_product = numerator_values[final_num_prod..].iter().copied().product(); - let final_deno_product = denominator_values[final_num_prod..] - .iter() - .copied() - .product(); - let last_partial = *current_partial_products.last().unwrap(); - let v_shift_term = last_partial * final_nume_product - z_gz * final_deno_product; - vanishing_v_shift_terms.push(v_shift_term); } let vanishing_terms = [ vanishing_z_1_terms, vanishing_partial_products_terms, - vanishing_v_shift_terms, constraint_terms, ] .concat(); @@ -121,7 +110,7 @@ pub(crate) fn eval_vanishing_poly_base_batch, const assert_eq!(s_sigmas_batch.len(), n); let max_degree = common_data.quotient_degree_factor; - let (num_prods, final_num_prod) = common_data.num_partial_products; + let (num_prods, _final_num_prod) = common_data.num_partial_products; let num_gate_constraints = common_data.num_gate_constraints; @@ -139,8 +128,6 @@ pub(crate) fn eval_vanishing_poly_base_batch, const let mut vanishing_z_1_terms = Vec::with_capacity(num_challenges); // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); - // The Z(x) f'(x) - g'(x) Z(g x) terms. - let mut vanishing_v_shift_terms = Vec::with_capacity(num_challenges); let mut res_batch: Vec> = Vec::with_capacity(n); for k in 0..n { @@ -181,19 +168,11 @@ pub(crate) fn eval_vanishing_poly_base_batch, const &denominator_values, current_partial_products, z_x, + z_gz, max_degree, ); vanishing_partial_products_terms.extend(partial_product_checks); - let final_nume_product = numerator_values[final_num_prod..].iter().copied().product(); - let final_deno_product = denominator_values[final_num_prod..] - .iter() - .copied() - .product(); - let last_partial = *current_partial_products.last().unwrap(); - let v_shift_term = last_partial * final_nume_product - z_gz * final_deno_product; - vanishing_v_shift_terms.push(v_shift_term); - numerator_values.clear(); denominator_values.clear(); } @@ -201,14 +180,12 @@ pub(crate) fn eval_vanishing_poly_base_batch, const let vanishing_terms = vanishing_z_1_terms .iter() .chain(vanishing_partial_products_terms.iter()) - .chain(vanishing_v_shift_terms.iter()) .chain(constraint_terms); let res = plonk_common::reduce_with_powers_multi(vanishing_terms, alphas); res_batch.push(res); vanishing_z_1_terms.clear(); vanishing_partial_products_terms.clear(); - vanishing_v_shift_terms.clear(); } res_batch } @@ -314,7 +291,7 @@ pub(crate) fn eval_vanishing_poly_recursively, cons alphas: &[Target], ) -> Vec> { let max_degree = common_data.quotient_degree_factor; - let (num_prods, final_num_prod) = common_data.num_partial_products; + let (num_prods, _final_num_prod) = common_data.num_partial_products; let constraint_terms = with_context!( builder, @@ -331,8 +308,6 @@ pub(crate) fn eval_vanishing_poly_recursively, cons let mut vanishing_z_1_terms = Vec::new(); // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); - // The Z(x) f'(x) - g'(x) Z(g x) terms. - let mut vanishing_v_shift_terms = Vec::new(); let l1_x = eval_l_1_recursively(builder, common_data.degree(), x, x_pow_deg); @@ -377,23 +352,15 @@ pub(crate) fn eval_vanishing_poly_recursively, cons &denominator_values, current_partial_products, z_x, + z_gz, max_degree, ); vanishing_partial_products_terms.extend(partial_product_checks); - - let final_nume_product = builder.mul_many_extension(&numerator_values[final_num_prod..]); - let final_deno_product = builder.mul_many_extension(&denominator_values[final_num_prod..]); - let z_gz_denominators = builder.mul_extension(z_gz, final_deno_product); - let last_partial = *current_partial_products.last().unwrap(); - let v_shift_term = - builder.mul_sub_extension(last_partial, final_nume_product, z_gz_denominators); - vanishing_v_shift_terms.push(v_shift_term); } let vanishing_terms = [ vanishing_z_1_terms, vanishing_partial_products_terms, - vanishing_v_shift_terms, constraint_terms, ] .concat(); diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index c4133b4d..0f3c9bfa 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -1,9 +1,12 @@ +use std::iter; + use itertools::Itertools; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::util::ceil_div_usize; pub(crate) fn quotient_chunk_products( quotient_values: &[F], @@ -33,70 +36,74 @@ pub(crate) fn partial_products_and_z_gx(z_x: F, quotient_chunk_product /// Returns a tuple `(a,b)`, where `a` is the length of the output of `partial_products()` on a /// vector of length `n`, and `b` is the number of original elements consumed in `partial_products()`. -pub fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { +pub(crate) fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { debug_assert!(max_degree > 1); let chunk_size = max_degree; - let num_chunks = n / chunk_size; - + // We'll split the product into `ceil_div_usize(n, chunk_size)` chunks, but the last chunk will + // be associated with Z(gx) itself. Thus we subtract one to get the chunks associated with + // partial products. + let num_chunks = ceil_div_usize(n, chunk_size) - 1; (num_chunks, num_chunks * chunk_size) } -/// Checks that the partial products of `numerators/denominators` are coherent with those in `partials` by only computing -/// products of size `max_degree` or less. +/// Checks the relationship between each pair of partial product accumulators. In particular, this +/// sequence of accumulators starts with `Z(x)`, then contains each partial product polynomials +/// `p_i(x)`, and finally `Z(g x)`. See the partial products section of the Plonky2 paper. pub(crate) fn check_partial_products( numerators: &[F], denominators: &[F], partials: &[F], z_x: F, + z_gx: F, max_degree: usize, ) -> Vec { debug_assert!(max_degree > 1); - let mut acc = z_x; - let mut partials = partials.iter(); - let mut res = Vec::new(); + let product_accs = iter::once(&z_x) + .chain(partials.iter()) + .chain(iter::once(&z_gx)); let chunk_size = max_degree; - for (nume_chunk, deno_chunk) in numerators - .chunks_exact(chunk_size) - .zip_eq(denominators.chunks_exact(chunk_size)) - { - let num_chunk_product = nume_chunk.iter().copied().product(); - let den_chunk_product = deno_chunk.iter().copied().product(); - let new_acc = *partials.next().unwrap(); - res.push(acc * num_chunk_product - new_acc * den_chunk_product); - acc = new_acc; - } - debug_assert!(partials.next().is_none()); - - res + numerators + .chunks(chunk_size) + .zip_eq(denominators.chunks(chunk_size)) + .zip_eq(product_accs.tuple_windows()) + .map(|((nume_chunk, deno_chunk), (&prev_acc, &next_acc))| { + let num_chunk_product = nume_chunk.iter().copied().product(); + let den_chunk_product = deno_chunk.iter().copied().product(); + // Assert that next_acc * deno_product = prev_acc * nume_product. + prev_acc * num_chunk_product - next_acc * den_chunk_product + }) + .collect() } +/// Checks the relationship between each pair of partial product accumulators. In particular, this +/// sequence of accumulators starts with `Z(x)`, then contains each partial product polynomials +/// `p_i(x)`, and finally `Z(g x)`. See the partial products section of the Plonky2 paper. pub(crate) fn check_partial_products_recursively, const D: usize>( builder: &mut CircuitBuilder, numerators: &[ExtensionTarget], denominators: &[ExtensionTarget], partials: &[ExtensionTarget], - mut acc: ExtensionTarget, + z_x: ExtensionTarget, + z_gx: ExtensionTarget, max_degree: usize, ) -> Vec> { debug_assert!(max_degree > 1); - let mut partials = partials.iter(); - let mut res = Vec::new(); + let product_accs = iter::once(&z_x) + .chain(partials.iter()) + .chain(iter::once(&z_gx)); let chunk_size = max_degree; - for (nume_chunk, deno_chunk) in numerators - .chunks_exact(chunk_size) - .zip(denominators.chunks_exact(chunk_size)) - { - let nume_product = builder.mul_many_extension(nume_chunk); - let deno_product = builder.mul_many_extension(deno_chunk); - let new_acc = *partials.next().unwrap(); - let new_acc_deno = builder.mul_extension(new_acc, deno_product); - // Assert that new_acc*deno_product = acc * nume_product. - res.push(builder.mul_sub_extension(acc, nume_product, new_acc_deno)); - acc = new_acc; - } - debug_assert!(partials.next().is_none()); - - res + numerators + .chunks(chunk_size) + .zip_eq(denominators.chunks(chunk_size)) + .zip_eq(product_accs.tuple_windows()) + .map(|((nume_chunk, deno_chunk), (&prev_acc, &next_acc))| { + let nume_product = builder.mul_many_extension(nume_chunk); + let deno_product = builder.mul_many_extension(deno_chunk); + let next_acc_deno = builder.mul_extension(next_acc, deno_product); + // Assert that next_acc * deno_product = prev_acc * nume_product. + builder.mul_sub_extension(prev_acc, nume_product, next_acc_deno) + }) + .collect() } #[cfg(test)] @@ -108,36 +115,31 @@ mod tests { fn test_partial_products() { type F = GoldilocksField; let denominators = vec![F::ONE; 6]; + let z_x = F::ONE; let v = field_vec(&[1, 2, 3, 4, 5, 6]); + let z_gx = F::from_canonical_u64(720); let quotient_chunks_prods = quotient_chunk_products(&v, 2); assert_eq!(quotient_chunks_prods, field_vec(&[2, 12, 30])); - let p = partial_products_and_z_gx(F::ONE, "ient_chunks_prods); - assert_eq!(p, field_vec(&[2, 24, 720])); + let pps_and_z_gx = partial_products_and_z_gx(z_x, "ient_chunks_prods); + let pps = &pps_and_z_gx[..pps_and_z_gx.len() - 1]; + assert_eq!(pps_and_z_gx, field_vec(&[2, 24, 720])); let nums = num_partial_products(v.len(), 2); - assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &denominators, &p, F::ONE, 2) + assert_eq!(pps.len(), nums.0); + assert!(check_partial_products(&v, &denominators, pps, z_x, z_gx, 2) .iter() .all(|x| x.is_zero())); - assert_eq!( - *p.last().unwrap() * v[nums.1..].iter().copied().product::(), - v.into_iter().product::(), - ); - let v = field_vec(&[1, 2, 3, 4, 5, 6]); let quotient_chunks_prods = quotient_chunk_products(&v, 3); assert_eq!(quotient_chunks_prods, field_vec(&[6, 120])); - let p = partial_products_and_z_gx(F::ONE, "ient_chunks_prods); - assert_eq!(p, field_vec(&[6, 720])); + let pps_and_z_gx = partial_products_and_z_gx(z_x, "ient_chunks_prods); + let pps = &pps_and_z_gx[..pps_and_z_gx.len() - 1]; + assert_eq!(pps_and_z_gx, field_vec(&[6, 720])); let nums = num_partial_products(v.len(), 3); - assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &denominators, &p, F::ONE, 3) + assert_eq!(pps.len(), nums.0); + assert!(check_partial_products(&v, &denominators, pps, z_x, z_gx, 3) .iter() .all(|x| x.is_zero())); - assert_eq!( - *p.last().unwrap() * v[nums.1..].iter().copied().product::(), - v.into_iter().product::(), - ); } fn field_vec(xs: &[usize]) -> Vec { From 66719b0cfc2ef935c539eed293341247e8af5650 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 15 Nov 2021 10:33:27 +0100 Subject: [PATCH 093/202] Remove comments --- src/util/reducing.rs | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/util/reducing.rs b/src/util/reducing.rs index ab3e2771..12be80f6 100644 --- a/src/util/reducing.rs +++ b/src/util/reducing.rs @@ -146,17 +146,6 @@ impl ReducingFactorTarget { where F: RichField + Extendable, { - // let l = terms.len(); - // self.count += l as u64; - // - // let mut terms_vec = terms.to_vec(); - // let mut acc = builder.zero_extension(); - // terms_vec.reverse(); - // - // for x in terms_vec { - // acc = builder.mul_add_extension(self.base, acc, x); - // } - // acc let max_coeffs_len = ReducingExtGate::::max_coeffs_len( builder.config.num_wires, builder.config.num_routed_wires, From a54db66f68ef58bd2a7802af05fad2b217db2c02 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 15 Nov 2021 11:38:48 +0100 Subject: [PATCH 094/202] Use arithmetic gate for small reductions --- src/gates/reducing_extension.rs | 34 ++++++++------- src/util/reducing.rs | 77 ++++++++++++++++++++++++++++----- 2 files changed, 85 insertions(+), 26 deletions(-) diff --git a/src/gates/reducing_extension.rs b/src/gates/reducing_extension.rs index 93c981a6..9ee2134f 100644 --- a/src/gates/reducing_extension.rs +++ b/src/gates/reducing_extension.rs @@ -13,11 +13,11 @@ use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; /// Computes `sum alpha^i c_i` for a vector `c_i` of `num_coeffs` elements of the extension field. #[derive(Debug, Clone)] -pub struct ReducingExtGate { +pub struct ReducingExtensionGate { pub num_coeffs: usize, } -impl ReducingExtGate { +impl ReducingExtensionGate { pub fn new(num_coeffs: usize) -> Self { Self { num_coeffs } } @@ -51,7 +51,7 @@ impl ReducingExtGate { } } -impl, const D: usize> Gate for ReducingExtGate { +impl, const D: usize> Gate for ReducingExtensionGate { fn id(&self) -> String { format!("{:?}", self) } @@ -163,14 +163,16 @@ impl, const D: usize> Gate for ReducingExtGat #[derive(Debug)] struct ReducingGenerator { gate_index: usize, - gate: ReducingExtGate, + gate: ReducingExtensionGate, } impl, const D: usize> SimpleGenerator for ReducingGenerator { fn dependencies(&self) -> Vec { - ReducingExtGate::::wires_alpha() - .chain(ReducingExtGate::::wires_old_acc()) - .chain((0..self.gate.num_coeffs).flat_map(|i| ReducingExtGate::::wires_coeff(i))) + ReducingExtensionGate::::wires_alpha() + .chain(ReducingExtensionGate::::wires_old_acc()) + .chain( + (0..self.gate.num_coeffs).flat_map(|i| ReducingExtensionGate::::wires_coeff(i)), + ) .map(|i| Target::wire(self.gate_index, i)) .collect() } @@ -181,16 +183,18 @@ impl, const D: usize> SimpleGenerator for ReducingGenerator< witness.get_extension_target(t) }; - let alpha = extract_extension(ReducingExtGate::::wires_alpha()); - let old_acc = extract_extension(ReducingExtGate::::wires_old_acc()); + let alpha = extract_extension(ReducingExtensionGate::::wires_alpha()); + let old_acc = extract_extension(ReducingExtensionGate::::wires_old_acc()); let coeffs = (0..self.gate.num_coeffs) - .map(|i| extract_extension(ReducingExtGate::::wires_coeff(i))) + .map(|i| extract_extension(ReducingExtensionGate::::wires_coeff(i))) .collect::>(); let accs = (0..self.gate.num_coeffs) .map(|i| ExtensionTarget::from_range(self.gate_index, self.gate.wires_accs(i))) .collect::>(); - let output = - ExtensionTarget::from_range(self.gate_index, ReducingExtGate::::wires_output()); + let output = ExtensionTarget::from_range( + self.gate_index, + ReducingExtensionGate::::wires_output(), + ); let mut acc = old_acc; for i in 0..self.gate.num_coeffs { @@ -208,15 +212,15 @@ mod tests { use crate::field::goldilocks_field::GoldilocksField; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; - use crate::gates::reducing_extension::ReducingExtGate; + use crate::gates::reducing_extension::ReducingExtensionGate; #[test] fn low_degree() { - test_low_degree::(ReducingExtGate::new(22)); + test_low_degree::(ReducingExtensionGate::new(22)); } #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(ReducingExtGate::new(22)) + test_eval_fns::(ReducingExtensionGate::new(22)) } } diff --git a/src/util/reducing.rs b/src/util/reducing.rs index 12be80f6..f2cd3d55 100644 --- a/src/util/reducing.rs +++ b/src/util/reducing.rs @@ -3,8 +3,9 @@ use std::borrow::Borrow; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; +use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::gates::reducing::ReducingGate; -use crate::gates::reducing_extension::ReducingExtGate; +use crate::gates::reducing_extension::ReducingExtensionGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; use crate::polynomial::polynomial::PolynomialCoeffs; @@ -94,7 +95,7 @@ impl ReducingFactorTarget { Self { base, count: 0 } } - /// Reduces a length `n` vector of `Target`s using `n/21` `ReducingGate`s (with 33 routed wires and 126 wires). + /// Reduces a vector of `Target`s using `ReducingGate`s. pub fn reduce_base( &mut self, terms: &[Target], @@ -103,11 +104,16 @@ impl ReducingFactorTarget { where F: RichField + Extendable, { + let l = terms.len(); + // For small reductions, use an arithmetic gate. + if l - 1 <= ArithmeticExtensionGate::::new_from_config(&builder.config).num_ops { + return self.reduce_base_arithmetic(terms, builder); + } let max_coeffs_len = ReducingGate::::max_coeffs_len( builder.config.num_wires, builder.config.num_routed_wires, ); - self.count += terms.len() as u64; + self.count += l as u64; let zero = builder.zero(); let zero_ext = builder.zero_extension(); let mut acc = zero_ext; @@ -138,6 +144,26 @@ impl ReducingFactorTarget { acc } + /// Reduces a vector of `Target`s using `ArithmeticGate`s. + fn reduce_base_arithmetic( + &mut self, + terms: &[Target], + builder: &mut CircuitBuilder, + ) -> ExtensionTarget + where + F: RichField + Extendable, + { + self.count += terms.len() as u64; + terms + .iter() + .rev() + .fold(builder.zero_extension(), |acc, &t| { + let et = builder.convert_to_ext(t); + builder.mul_add_extension(self.base, acc, et) + }) + } + + /// Reduces a vector of `ExtensionTarget`s using `ReducingExtensionGate`s. pub fn reduce( &mut self, terms: &[ExtensionTarget], // Could probably work with a `DoubleEndedIterator` too. @@ -146,12 +172,16 @@ impl ReducingFactorTarget { where F: RichField + Extendable, { - let max_coeffs_len = ReducingExtGate::::max_coeffs_len( + let l = terms.len(); + // For small reductions, use an arithmetic gate. + if l - 1 <= ArithmeticExtensionGate::::new_from_config(&builder.config).num_ops { + return self.reduce_arithmetic(terms, builder); + } + let max_coeffs_len = ReducingExtensionGate::::max_coeffs_len( builder.config.num_wires, builder.config.num_routed_wires, ); - self.count += terms.len() as u64; - let zero = builder.zero(); + self.count += l as u64; let zero_ext = builder.zero_extension(); let mut acc = zero_ext; let mut reversed_terms = terms.to_vec(); @@ -160,30 +190,55 @@ impl ReducingFactorTarget { } reversed_terms.reverse(); for chunk in reversed_terms.chunks_exact(max_coeffs_len) { - let gate = ReducingExtGate::new(max_coeffs_len); + let gate = ReducingExtensionGate::new(max_coeffs_len); let gate_index = builder.add_gate(gate.clone(), Vec::new()); builder.connect_extension( self.base, - ExtensionTarget::from_range(gate_index, ReducingExtGate::::wires_alpha()), + ExtensionTarget::from_range(gate_index, ReducingExtensionGate::::wires_alpha()), ); builder.connect_extension( acc, - ExtensionTarget::from_range(gate_index, ReducingExtGate::::wires_old_acc()), + ExtensionTarget::from_range( + gate_index, + ReducingExtensionGate::::wires_old_acc(), + ), ); for (i, &t) in chunk.iter().enumerate() { builder.connect_extension( t, - ExtensionTarget::from_range(gate_index, ReducingExtGate::::wires_coeff(i)), + ExtensionTarget::from_range( + gate_index, + ReducingExtensionGate::::wires_coeff(i), + ), ); } - acc = ExtensionTarget::from_range(gate_index, ReducingExtGate::::wires_output()); + acc = + ExtensionTarget::from_range(gate_index, ReducingExtensionGate::::wires_output()); } acc } + /// Reduces a vector of `ExtensionTarget`s using `ArithmeticGate`s. + fn reduce_arithmetic( + &mut self, + terms: &[ExtensionTarget], + builder: &mut CircuitBuilder, + ) -> ExtensionTarget + where + F: RichField + Extendable, + { + self.count += terms.len() as u64; + terms + .iter() + .rev() + .fold(builder.zero_extension(), |acc, &et| { + builder.mul_add_extension(self.base, acc, et) + }) + } + pub fn shift( &mut self, x: ExtensionTarget, From f787c5385f1591cd23c59ebd450b243270052ed3 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 15 Nov 2021 11:50:27 +0100 Subject: [PATCH 095/202] Simplify --- src/util/reducing.rs | 29 +++++++++-------------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/src/util/reducing.rs b/src/util/reducing.rs index f2cd3d55..f10f412d 100644 --- a/src/util/reducing.rs +++ b/src/util/reducing.rs @@ -105,10 +105,16 @@ impl ReducingFactorTarget { F: RichField + Extendable, { let l = terms.len(); + // For small reductions, use an arithmetic gate. if l - 1 <= ArithmeticExtensionGate::::new_from_config(&builder.config).num_ops { - return self.reduce_base_arithmetic(terms, builder); + let terms_ext = terms + .iter() + .map(|&t| builder.convert_to_ext(t)) + .collect::>(); + return self.reduce_arithmetic(&terms_ext, builder); } + let max_coeffs_len = ReducingGate::::max_coeffs_len( builder.config.num_wires, builder.config.num_routed_wires, @@ -144,25 +150,6 @@ impl ReducingFactorTarget { acc } - /// Reduces a vector of `Target`s using `ArithmeticGate`s. - fn reduce_base_arithmetic( - &mut self, - terms: &[Target], - builder: &mut CircuitBuilder, - ) -> ExtensionTarget - where - F: RichField + Extendable, - { - self.count += terms.len() as u64; - terms - .iter() - .rev() - .fold(builder.zero_extension(), |acc, &t| { - let et = builder.convert_to_ext(t); - builder.mul_add_extension(self.base, acc, et) - }) - } - /// Reduces a vector of `ExtensionTarget`s using `ReducingExtensionGate`s. pub fn reduce( &mut self, @@ -173,10 +160,12 @@ impl ReducingFactorTarget { F: RichField + Extendable, { let l = terms.len(); + // For small reductions, use an arithmetic gate. if l - 1 <= ArithmeticExtensionGate::::new_from_config(&builder.config).num_ops { return self.reduce_arithmetic(terms, builder); } + let max_coeffs_len = ReducingExtensionGate::::max_coeffs_len( builder.config.num_wires, builder.config.num_routed_wires, From 3efe2068bc99703e033f75ee34145fc102ebe44b Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 15 Nov 2021 11:59:54 +0100 Subject: [PATCH 096/202] Minor --- src/gates/reducing_extension.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/gates/reducing_extension.rs b/src/gates/reducing_extension.rs index 9ee2134f..99d40fba 100644 --- a/src/gates/reducing_extension.rs +++ b/src/gates/reducing_extension.rs @@ -7,7 +7,7 @@ use crate::field::field_types::RichField; use crate::gates::gate::Gate; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; -use crate::iop::witness::{PartialWitness, PartitionWitness, Witness}; +use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; @@ -170,9 +170,7 @@ impl, const D: usize> SimpleGenerator for ReducingGenerator< fn dependencies(&self) -> Vec { ReducingExtensionGate::::wires_alpha() .chain(ReducingExtensionGate::::wires_old_acc()) - .chain( - (0..self.gate.num_coeffs).flat_map(|i| ReducingExtensionGate::::wires_coeff(i)), - ) + .chain((0..self.gate.num_coeffs).flat_map(ReducingExtensionGate::::wires_coeff)) .map(|i| Target::wire(self.gate_index, i)) .collect() } From 49e4307820bb298e9b1bcbff2a8ca91905b5bf23 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 15 Nov 2021 13:35:21 +0100 Subject: [PATCH 097/202] Comments + test for reducing 100 extension elements --- src/gates/reducing_extension.rs | 16 +++++++--------- src/util/reducing.rs | 5 +++++ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/gates/reducing_extension.rs b/src/gates/reducing_extension.rs index 99d40fba..532b484f 100644 --- a/src/gates/reducing_extension.rs +++ b/src/gates/reducing_extension.rs @@ -23,6 +23,8 @@ impl ReducingExtensionGate { } pub fn max_coeffs_len(num_wires: usize, num_routed_wires: usize) -> usize { + // `3*D` routed wires are used for the output, alpha and old accumulator. + // Need `num_coeffs*D` routed wires for coeffs, and `(num_coeffs-1)*D` wires for accumulators. ((num_routed_wires - 3 * D) / D).min((num_wires - 2 * D) / (D * 2)) } @@ -43,6 +45,7 @@ impl ReducingExtensionGate { Self::START_COEFFS + self.num_coeffs * D } fn wires_accs(&self, i: usize) -> Range { + debug_assert!(i < self.num_coeffs); if i == self.num_coeffs - 1 { // The last accumulator is the output. return Self::wires_output(); @@ -176,23 +179,19 @@ impl, const D: usize> SimpleGenerator for ReducingGenerator< } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let extract_extension = |range: Range| -> F::Extension { + let local_extension = |range: Range| -> F::Extension { let t = ExtensionTarget::from_range(self.gate_index, range); witness.get_extension_target(t) }; - let alpha = extract_extension(ReducingExtensionGate::::wires_alpha()); - let old_acc = extract_extension(ReducingExtensionGate::::wires_old_acc()); + let alpha = local_extension(ReducingExtensionGate::::wires_alpha()); + let old_acc = local_extension(ReducingExtensionGate::::wires_old_acc()); let coeffs = (0..self.gate.num_coeffs) - .map(|i| extract_extension(ReducingExtensionGate::::wires_coeff(i))) + .map(|i| local_extension(ReducingExtensionGate::::wires_coeff(i))) .collect::>(); let accs = (0..self.gate.num_coeffs) .map(|i| ExtensionTarget::from_range(self.gate_index, self.gate.wires_accs(i))) .collect::>(); - let output = ExtensionTarget::from_range( - self.gate_index, - ReducingExtensionGate::::wires_output(), - ); let mut acc = old_acc; for i in 0..self.gate.num_coeffs { @@ -200,7 +199,6 @@ impl, const D: usize> SimpleGenerator for ReducingGenerator< out_buffer.set_extension_target(accs[i], computed_acc); acc = computed_acc; } - out_buffer.set_extension_target(output, acc); } } diff --git a/src/util/reducing.rs b/src/util/reducing.rs index f10f412d..1cddc3e7 100644 --- a/src/util/reducing.rs +++ b/src/util/reducing.rs @@ -330,4 +330,9 @@ mod tests { fn test_reduce_gadget_base_100() -> Result<()> { test_reduce_gadget_base(100) } + + #[test] + fn test_reduce_gadget_100() -> Result<()> { + test_reduce_gadget(100) + } } From efab3177ce45a37dea9f0851b70b50b0f8527d3d Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Mon, 15 Nov 2021 09:55:06 -0800 Subject: [PATCH 098/202] Have `le_sum` use arithmetic ops if it's cheaper (#362) * Have le_sum use arithmetic ops if it's cheaper * fmt --- src/gadgets/arithmetic.rs | 4 ++-- src/gadgets/split_base.rs | 33 ++++++++++++++++++++------------- src/plonk/circuit_builder.rs | 14 ++++++++++++++ 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 3fe90019..a053b761 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -215,7 +215,7 @@ impl, const D: usize> CircuitBuilder { /// Exponentiate `base` to the power of `2^power_log`. pub fn exp_power_of_2(&mut self, base: Target, power_log: usize) -> Target { - if power_log > ArithmeticGate::new_from_config(&self.config).num_ops { + if power_log > self.num_base_arithmetic_ops_per_gate() { // Cheaper to just use `ExponentiateGate`. return self.exp_u64(base, 1 << power_log); } @@ -269,7 +269,7 @@ impl, const D: usize> CircuitBuilder { let base_t = self.constant(base); let exponent_bits: Vec<_> = exponent_bits.into_iter().map(|b| *b.borrow()).collect(); - if exponent_bits.len() > ArithmeticGate::new_from_config(&self.config).num_ops { + if exponent_bits.len() > self.num_base_arithmetic_ops_per_gate() { // Cheaper to just use `ExponentiateGate`. return self.exp_from_bits(base_t, exponent_bits); } diff --git a/src/gadgets/split_base.rs b/src/gadgets/split_base.rs index 30bdea6a..d60324ce 100644 --- a/src/gadgets/split_base.rs +++ b/src/gadgets/split_base.rs @@ -1,5 +1,7 @@ use std::borrow::Borrow; +use itertools::Itertools; + use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::gates::base_sum::BaseSumGate; @@ -29,21 +31,26 @@ impl, const D: usize> CircuitBuilder { /// the number with little-endian bit representation given by `bits`. pub(crate) fn le_sum( &mut self, - bits: impl ExactSizeIterator> + Clone, + mut bits: impl Iterator>, ) -> Target { + let bits = bits.map(|b| *b.borrow()).collect_vec(); let num_bits = bits.len(); if num_bits == 0 { return self.zero(); - } else if num_bits == 1 { - let mut bits = bits; - return bits.next().unwrap().borrow().target; - } else if num_bits == 2 { - let two = self.two(); - let mut bits = bits; - let b0 = bits.next().unwrap().borrow().target; - let b1 = bits.next().unwrap().borrow().target; - return self.mul_add(two, b1, b0); } + + // Check if it's cheaper to just do this with arithmetic operations. + let arithmetic_ops = num_bits - 1; + if arithmetic_ops <= self.num_base_arithmetic_ops_per_gate() { + let two = self.two(); + let mut rev_bits = bits.iter().rev(); + let mut sum = rev_bits.next().unwrap().target; + for &bit in rev_bits { + sum = self.mul_add(two, sum, bit.target); + } + return sum; + } + debug_assert!( BaseSumGate::<2>::START_LIMBS + num_bits <= self.config.num_routed_wires, "Not enough routed wires." @@ -51,10 +58,10 @@ impl, const D: usize> CircuitBuilder { let gate_type = BaseSumGate::<2>::new_from_config::(&self.config); let gate_index = self.add_gate(gate_type, vec![]); for (limb, wire) in bits - .clone() + .iter() .zip(BaseSumGate::<2>::START_LIMBS..BaseSumGate::<2>::START_LIMBS + num_bits) { - self.connect(limb.borrow().target, Target::wire(gate_index, wire)); + self.connect(limb.target, Target::wire(gate_index, wire)); } for l in gate_type.limbs().skip(num_bits) { self.assert_zero(Target::wire(gate_index, l)); @@ -62,7 +69,7 @@ impl, const D: usize> CircuitBuilder { self.add_simple_generator(BaseSumGenerator::<2> { gate_index, - limbs: bits.map(|l| *l.borrow()).collect(), + limbs: bits, }); Target::wire(gate_index, BaseSumGate::<2>::WIRE_SUM) diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index aac9d42e..4425f193 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -379,6 +379,20 @@ impl, const D: usize> CircuitBuilder { } } + /// The number of (base field) `arithmetic` operations that can be performed in a single gate. + pub(crate) fn num_base_arithmetic_ops_per_gate(&self) -> usize { + if self.config.use_base_arithmetic_gate { + ArithmeticGate::new_from_config(&self.config).num_ops + } else { + self.num_ext_arithmetic_ops_per_gate() + } + } + + /// The number of `arithmetic_extension` operations that can be performed in a single gate. + pub(crate) fn num_ext_arithmetic_ops_per_gate(&self) -> usize { + ArithmeticExtensionGate::::new_from_config(&self.config).num_ops + } + /// The number of polynomial values that will be revealed per opening, both for the "regular" /// polynomials and for the Z polynomials. Because calculating these values involves a recursive /// dependence (the amount of blinding depends on the degree, which depends on the blinding), From 07d03465b1db5cf609f34fa2ce564e0a3f490c5d Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Mon, 15 Nov 2021 10:03:13 -0800 Subject: [PATCH 099/202] Verify that non-canonical splits are OK (#357) The effect on soundness error is negligible for our current field, but this introduces an assertion that could fail if we changed to a field with more elements in the "ambiguous" range. --- src/field/field_types.rs | 5 +++++ src/fri/recursive_verifier.rs | 30 +++++++++++++++++++++++++++--- src/plonk/circuit_data.rs | 4 ++++ 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/src/field/field_types.rs b/src/field/field_types.rs index f5d06fdb..aed654d5 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -345,6 +345,11 @@ pub trait Field: pub trait PrimeField: Field { const ORDER: u64; + /// The number of bits required to encode any field element. + fn bits() -> usize { + bits_u64(Self::NEG_ONE.to_canonical_u64()) + } + fn to_canonical_u64(&self) -> u64; fn to_noncanonical_u64(&self) -> u64; diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index fc0a7341..fd065288 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -10,7 +10,7 @@ use crate::hash::hash_types::MerkleCapTarget; use crate::iop::challenger::RecursiveChallenger; use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::circuit_data::CommonCircuitData; +use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData}; use crate::plonk::plonk_common::PlonkPolynomials; use crate::plonk::proof::OpeningSetTarget; use crate::util::reducing::ReducingFactorTarget; @@ -305,9 +305,13 @@ impl, const D: usize> CircuitBuilder { common_data: &CommonCircuitData, ) { let n_log = log2_strict(n); - // TODO: Do we need to range check `x_index` to a target smaller than `p`? + + // Note that this `low_bits` decomposition permits non-canonical binary encodings. Here we + // verify that this has a negligible impact on soundness error. + Self::assert_noncanonical_indices_ok(&common_data.config); let x_index = challenger.get_challenge(self); - let mut x_index_bits = self.low_bits(x_index, n_log, 64); + let mut x_index_bits = self.low_bits(x_index, n_log, F::bits()); + let cap_index = self.le_sum(x_index_bits[x_index_bits.len() - common_data.config.cap_height..].iter()); with_context!( @@ -408,6 +412,26 @@ impl, const D: usize> CircuitBuilder { ); self.connect_extension(eval, old_eval); } + + /// We decompose FRI query indices into bits without verifying that the decomposition given by + /// the prover is the canonical one. In particular, if `x_index < 2^field_bits - p`, then the + /// prover could supply the binary encoding of either `x_index` or `x_index + p`, since the are + /// congruent mod `p`. However, this only occurs with probability + /// p_ambiguous = (2^field_bits - p) / p + /// which is small for the field that we use in practice. + /// + /// In particular, the soundness error of one FRI query is roughly the codeword rate, which + /// is much larger than this ambiguous-element probability given any reasonable parameters. + /// Thus ambiguous elements contribute a negligible amount to soundness error. + /// + /// Here we compare the probabilities as a sanity check, to verify the claim above. + fn assert_noncanonical_indices_ok(config: &CircuitConfig) { + let num_ambiguous_elems = u64::MAX - F::ORDER + 1; + let query_error = config.rate(); + let p_ambiguous = (num_ambiguous_elems as f64) / (F::ORDER as f64); + assert!(p_ambiguous < query_error * 1e-5, + "A non-negligible portion of field elements are in the range that permits non-canonical encodings. Need to do more analysis or enforce canonical encodings."); + } } #[derive(Copy, Clone)] diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 564d558d..a26814c3 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -48,6 +48,10 @@ impl Default for CircuitConfig { } impl CircuitConfig { + pub fn rate(&self) -> f64 { + 1.0 / ((1 << self.rate_bits) as f64) + } + pub fn num_advice_wires(&self) -> usize { self.num_wires - self.num_routed_wires } From 640997639a716459289453fdf1ffd296f46da56e Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Mon, 15 Nov 2021 10:10:19 -0800 Subject: [PATCH 100/202] Rename z_gz -> z_gx (#359) Elsewhere we refer to the point we're evaluating at as `x` --- src/plonk/vanishing_poly.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index ef322c9f..58820725 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -42,7 +42,7 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( for i in 0..common_data.config.num_challenges { let z_x = local_zs[i]; - let z_gz = next_zs[i]; + let z_gx = next_zs[i]; vanishing_z_1_terms.push(l1_x * (z_x - F::Extension::ONE)); let numerator_values = (0..common_data.config.num_routed_wires) @@ -69,7 +69,7 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( &denominator_values, current_partial_products, z_x, - z_gz, + z_gx, max_degree, ); vanishing_partial_products_terms.extend(partial_product_checks); @@ -145,7 +145,7 @@ pub(crate) fn eval_vanishing_poly_base_batch, const let l1_x = z_h_on_coset.eval_l1(index, x); for i in 0..num_challenges { let z_x = local_zs[i]; - let z_gz = next_zs[i]; + let z_gx = next_zs[i]; vanishing_z_1_terms.push(l1_x * z_x.sub_one()); numerator_values.extend((0..num_routed_wires).map(|j| { @@ -168,7 +168,7 @@ pub(crate) fn eval_vanishing_poly_base_batch, const &denominator_values, current_partial_products, z_x, - z_gz, + z_gx, max_degree, ); vanishing_partial_products_terms.extend(partial_product_checks); @@ -320,7 +320,7 @@ pub(crate) fn eval_vanishing_poly_recursively, cons for i in 0..common_data.config.num_challenges { let z_x = local_zs[i]; - let z_gz = next_zs[i]; + let z_gx = next_zs[i]; // L_1(x) Z(x) = 0. vanishing_z_1_terms.push(builder.mul_sub_extension(l1_x, z_x, l1_x)); @@ -352,7 +352,7 @@ pub(crate) fn eval_vanishing_poly_recursively, cons &denominator_values, current_partial_products, z_x, - z_gz, + z_gx, max_degree, ); vanishing_partial_products_terms.extend(partial_product_checks); From 239c795a9d59be4f569b5774f939c5d3a4264dc3 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Mon, 15 Nov 2021 10:10:37 -0800 Subject: [PATCH 101/202] Address some more arithmetic gates that have unique constants (#361) Saves 131 gates, though only when not using `PoseidonMdsGate`, so not relevant for the 2^12 branch. --- src/gates/poseidon.rs | 8 ++++---- src/hash/poseidon.rs | 5 +++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 59c23b44..ccec2445 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -351,10 +351,10 @@ where let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; constraints.push(builder.sub_extension(state[0], sbox_in)); state[0] = >::sbox_monomial_recursive(builder, sbox_in); - state[0] = builder.add_const_extension( - state[0], - F::from_canonical_u64(>::FAST_PARTIAL_ROUND_CONSTANTS[r]), - ); + let c = >::FAST_PARTIAL_ROUND_CONSTANTS[r]; + let c = F::Extension::from_canonical_u64(c); + let c = builder.constant_extension(c); + state[0] = builder.add_extension(state[0], c); state = >::mds_partial_layer_fast_recursive(builder, &state, r); } diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index 9e4dd7f4..4450e659 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -455,8 +455,9 @@ where ); for i in 1..WIDTH { let t = >::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]; - let t = Self::from_canonical_u64(t); - d = builder.mul_const_add_extension(t, state[i], d); + let t = Self::Extension::from_canonical_u64(t); + let t = builder.constant_extension(t); + d = builder.mul_add_extension(t, state[i], d); } let mut result = [builder.zero_extension(); WIDTH]; From 9aafa447f87d04254cc9efa0e6c00665a15dd4cd Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Mon, 15 Nov 2021 10:11:16 -0800 Subject: [PATCH 102/202] Fix stack overflows due to recursion in `Forest::find` (#358) --- src/plonk/permutation_argument.rs | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/plonk/permutation_argument.rs b/src/plonk/permutation_argument.rs index ee9474d7..28a07dff 100644 --- a/src/plonk/permutation_argument.rs +++ b/src/plonk/permutation_argument.rs @@ -45,15 +45,23 @@ impl Forest { } /// Path compression method, see https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Finding_set_representatives. - pub fn find(&mut self, x_index: usize) -> usize { - let x_parent = self.parents[x_index]; - if x_parent != x_index { - let root_index = self.find(x_parent); - self.parents[x_index] = root_index; - root_index - } else { - x_index + pub fn find(&mut self, mut x_index: usize) -> usize { + // Note: We avoid recursion here since the chains can be long, causing stack overflows. + + // First, find the representative of the set containing `x_index`. + let mut representative = x_index; + while self.parents[representative] != representative { + representative = self.parents[representative]; } + + // Then, update each node in this chain to point directly to the representative. + while self.parents[x_index] != x_index { + let old_parent = self.parents[x_index]; + self.parents[x_index] = representative; + x_index = old_parent; + } + + representative } /// Merge two sets. From 8ea6c4d3926654e53e1354960905c4ff54e91edc Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Mon, 15 Nov 2021 10:15:55 -0800 Subject: [PATCH 103/202] Different implementation of `RandomAccessGate` (#360) The previous code used an equality test for each index. This variant uses a "MUX tree" instead. If we imagine the items as being the leaves of a binary tree, we can compute the `i`th item by splitting `i` into bits, then performing a "select" operation for each node. The bit used in each select is based on the height of the associated node. This uses fewer wires and is cheaper to evaluate, saving 31 wires in the recursion circuit. A potential disadvantage is that this uses higher-degree constraints (degree 4 with our params), but I don't think this is much of a concern for us since we use a degree-9 constraint system. --- src/fri/recursive_verifier.rs | 18 +-- src/gadgets/random_access.rs | 6 +- src/gates/random_access.rs | 259 ++++++++++++++++++---------------- src/plonk/circuit_builder.rs | 32 ++--- 4 files changed, 162 insertions(+), 153 deletions(-) diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index fd065288..a08dd99e 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -49,12 +49,12 @@ impl, const D: usize> CircuitBuilder { /// Make sure we have enough wires and routed wires to do the FRI checks efficiently. This check /// isn't required -- without it we'd get errors elsewhere in the stack -- but just gives more /// helpful errors. - fn check_recursion_config(&self, max_fri_arity: usize) { + fn check_recursion_config(&self, max_fri_arity_bits: usize) { let random_access = RandomAccessGate::::new_from_config( &self.config, - max_fri_arity.max(1 << self.config.cap_height), + max_fri_arity_bits.max(self.config.cap_height), ); - let interpolation_gate = InterpolationGate::::new(log2_strict(max_fri_arity)); + let interpolation_gate = InterpolationGate::::new(max_fri_arity_bits); let min_wires = random_access .num_wires() @@ -65,15 +65,15 @@ impl, const D: usize> CircuitBuilder { assert!( self.config.num_wires >= min_wires, - "To efficiently perform FRI checks with an arity of {}, at least {} wires are needed. Consider reducing arity.", - max_fri_arity, + "To efficiently perform FRI checks with an arity of 2^{}, at least {} wires are needed. Consider reducing arity.", + max_fri_arity_bits, min_wires ); assert!( self.config.num_routed_wires >= min_routed_wires, - "To efficiently perform FRI checks with an arity of {}, at least {} routed wires are needed. Consider reducing arity.", - max_fri_arity, + "To efficiently perform FRI checks with an arity of 2^{}, at least {} routed wires are needed. Consider reducing arity.", + max_fri_arity_bits, min_routed_wires ); } @@ -107,8 +107,8 @@ impl, const D: usize> CircuitBuilder { ) { let config = &common_data.config; - if let Some(max_arity) = common_data.fri_params.max_arity() { - self.check_recursion_config(max_arity); + if let Some(max_arity_bits) = common_data.fri_params.max_arity_bits() { + self.check_recursion_config(max_arity_bits); } debug_assert_eq!( diff --git a/src/gadgets/random_access.rs b/src/gadgets/random_access.rs index 58c827c1..c028fa68 100644 --- a/src/gadgets/random_access.rs +++ b/src/gadgets/random_access.rs @@ -4,18 +4,20 @@ use crate::field::field_types::RichField; use crate::gates::random_access::RandomAccessGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::util::log2_strict; impl, const D: usize> CircuitBuilder { /// Checks that a `Target` matches a vector at a non-deterministic index. /// Note: `access_index` is not range-checked. pub fn random_access(&mut self, access_index: Target, claimed_element: Target, v: Vec) { let vec_size = v.len(); + let bits = log2_strict(vec_size); debug_assert!(vec_size > 0); if vec_size == 1 { return self.connect(claimed_element, v[0]); } - let (gate_index, copy) = self.find_random_access_gate(vec_size); - let dummy_gate = RandomAccessGate::::new_from_config(&self.config, vec_size); + let (gate_index, copy) = self.find_random_access_gate(bits); + let dummy_gate = RandomAccessGate::::new_from_config(&self.config, bits); v.iter().enumerate().for_each(|(i, &val)| { self.connect( diff --git a/src/gates/random_access.rs b/src/gates/random_access.rs index bdbff667..cea4b079 100644 --- a/src/gates/random_access.rs +++ b/src/gates/random_access.rs @@ -1,5 +1,7 @@ use std::marker::PhantomData; +use itertools::Itertools; + use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; @@ -15,75 +17,64 @@ use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; /// A gate for checking that a particular element of a list matches a given value. #[derive(Copy, Clone, Debug)] pub(crate) struct RandomAccessGate, const D: usize> { - pub vec_size: usize, + pub bits: usize, pub num_copies: usize, _phantom: PhantomData, } impl, const D: usize> RandomAccessGate { - pub fn new(num_copies: usize, vec_size: usize) -> Self { + fn new(num_copies: usize, bits: usize) -> Self { Self { - vec_size, + bits, num_copies, _phantom: PhantomData, } } - pub fn new_from_config(config: &CircuitConfig, vec_size: usize) -> Self { - let num_copies = Self::max_num_copies(config.num_routed_wires, config.num_wires, vec_size); - Self::new(num_copies, vec_size) + pub fn new_from_config(config: &CircuitConfig, bits: usize) -> Self { + let vec_size = 1 << bits; + // Need `(2 + vec_size) * num_copies` routed wires + let max_copies = (config.num_routed_wires / (2 + vec_size)).min( + // Need `(2 + vec_size + bits) * num_copies` wires + config.num_wires / (2 + vec_size + bits), + ); + Self::new(max_copies, bits) } - pub fn max_num_copies(num_routed_wires: usize, num_wires: usize, vec_size: usize) -> usize { - // Need `(2 + vec_size) * num_copies` routed wires - (num_routed_wires / (2 + vec_size)).min( - // Need `(2 + 3*vec_size) * num_copies` wires - num_wires / (2 + 3 * vec_size), - ) + fn vec_size(&self) -> usize { + 1 << self.bits } pub fn wire_access_index(&self, copy: usize) -> usize { debug_assert!(copy < self.num_copies); - (2 + self.vec_size) * copy + (2 + self.vec_size()) * copy } pub fn wire_claimed_element(&self, copy: usize) -> usize { debug_assert!(copy < self.num_copies); - (2 + self.vec_size) * copy + 1 + (2 + self.vec_size()) * copy + 1 } pub fn wire_list_item(&self, i: usize, copy: usize) -> usize { - debug_assert!(i < self.vec_size); + debug_assert!(i < self.vec_size()); debug_assert!(copy < self.num_copies); - (2 + self.vec_size) * copy + 2 + i + (2 + self.vec_size()) * copy + 2 + i } fn start_of_intermediate_wires(&self) -> usize { - (2 + self.vec_size) * self.num_copies + (2 + self.vec_size()) * self.num_copies } pub(crate) fn num_routed_wires(&self) -> usize { self.start_of_intermediate_wires() } - /// An intermediate wire for a dummy variable used to show equality. - /// The prover sets this to 1/(x-y) if x != y, or to an arbitrary value if - /// x == y. - pub fn wire_equality_dummy_for_index(&self, i: usize, copy: usize) -> usize { - debug_assert!(i < self.vec_size); + /// An intermediate wire where the prover gives the (purported) binary decomposition of the + /// index. + pub fn wire_bit(&self, i: usize, copy: usize) -> usize { + debug_assert!(i < self.bits); debug_assert!(copy < self.num_copies); - self.start_of_intermediate_wires() + copy * self.vec_size + i - } - - /// An intermediate wire for the "index_matches" variable (1 if the current index is the index at - /// which to compare, 0 otherwise). - pub fn wire_index_matches_for_index(&self, i: usize, copy: usize) -> usize { - debug_assert!(i < self.vec_size); - debug_assert!(copy < self.num_copies); - self.start_of_intermediate_wires() - + self.vec_size * self.num_copies - + self.vec_size * copy - + i + self.start_of_intermediate_wires() + copy * self.bits + i } } @@ -97,23 +88,38 @@ impl, const D: usize> Gate for RandomAccessGa for copy in 0..self.num_copies { let access_index = vars.local_wires[self.wire_access_index(copy)]; - let list_items = (0..self.vec_size) + let mut list_items = (0..self.vec_size()) .map(|i| vars.local_wires[self.wire_list_item(i, copy)]) .collect::>(); let claimed_element = vars.local_wires[self.wire_claimed_element(copy)]; + let bits = (0..self.bits) + .map(|i| vars.local_wires[self.wire_bit(i, copy)]) + .collect::>(); - for i in 0..self.vec_size { - let cur_index = F::Extension::from_canonical_usize(i); - let difference = cur_index - access_index; - let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i, copy)]; - let index_matches = vars.local_wires[self.wire_index_matches_for_index(i, copy)]; - - // The two index equality constraints. - constraints.push(difference * equality_dummy - (F::Extension::ONE - index_matches)); - constraints.push(index_matches * difference); - // Value equality constraint. - constraints.push((list_items[i] - claimed_element) * index_matches); + // Assert that each bit wire value is indeed boolean. + for &b in &bits { + constraints.push(b * (b - F::Extension::ONE)); } + + // Assert that the binary decomposition was correct. + let reconstructed_index = bits + .iter() + .rev() + .fold(F::Extension::ZERO, |acc, &b| acc.double() + b); + constraints.push(reconstructed_index - access_index); + + // Repeatedly fold the list, selecting the left or right item from each pair based on + // the corresponding bit. + for b in bits { + list_items = list_items + .iter() + .tuples() + .map(|(&x, &y)| x + b * (y - x)) + .collect() + } + + debug_assert_eq!(list_items.len(), 1); + constraints.push(list_items[0] - claimed_element); } constraints @@ -124,23 +130,35 @@ impl, const D: usize> Gate for RandomAccessGa for copy in 0..self.num_copies { let access_index = vars.local_wires[self.wire_access_index(copy)]; - let list_items = (0..self.vec_size) + let mut list_items = (0..self.vec_size()) .map(|i| vars.local_wires[self.wire_list_item(i, copy)]) .collect::>(); let claimed_element = vars.local_wires[self.wire_claimed_element(copy)]; + let bits = (0..self.bits) + .map(|i| vars.local_wires[self.wire_bit(i, copy)]) + .collect::>(); - for i in 0..self.vec_size { - let cur_index = F::from_canonical_usize(i); - let difference = cur_index - access_index; - let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i, copy)]; - let index_matches = vars.local_wires[self.wire_index_matches_for_index(i, copy)]; - - // The two index equality constraints. - constraints.push(difference * equality_dummy - (F::ONE - index_matches)); - constraints.push(index_matches * difference); - // Value equality constraint. - constraints.push((list_items[i] - claimed_element) * index_matches); + // Assert that each bit wire value is indeed boolean. + for &b in &bits { + constraints.push(b * (b - F::ONE)); } + + // Assert that the binary decomposition was correct. + let reconstructed_index = bits.iter().rev().fold(F::ZERO, |acc, &b| acc.double() + b); + constraints.push(reconstructed_index - access_index); + + // Repeatedly fold the list, selecting the left or right item from each pair based on + // the corresponding bit. + for b in bits { + list_items = list_items + .iter() + .tuples() + .map(|(&x, &y)| x + b * (y - x)) + .collect() + } + + debug_assert_eq!(list_items.len(), 1); + constraints.push(list_items[0] - claimed_element); } constraints @@ -151,36 +169,44 @@ impl, const D: usize> Gate for RandomAccessGa builder: &mut CircuitBuilder, vars: EvaluationTargets, ) -> Vec> { + let zero = builder.zero_extension(); + let two = builder.two_extension(); let mut constraints = Vec::with_capacity(self.num_constraints()); for copy in 0..self.num_copies { let access_index = vars.local_wires[self.wire_access_index(copy)]; - let list_items = (0..self.vec_size) + let mut list_items = (0..self.vec_size()) .map(|i| vars.local_wires[self.wire_list_item(i, copy)]) .collect::>(); let claimed_element = vars.local_wires[self.wire_claimed_element(copy)]; + let bits = (0..self.bits) + .map(|i| vars.local_wires[self.wire_bit(i, copy)]) + .collect::>(); - for i in 0..self.vec_size { - let cur_index_ext = F::Extension::from_canonical_usize(i); - let cur_index = builder.constant_extension(cur_index_ext); - let difference = builder.sub_extension(cur_index, access_index); - let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i, copy)]; - let index_matches = vars.local_wires[self.wire_index_matches_for_index(i, copy)]; - - let one = builder.one_extension(); - let not_index_matches = builder.sub_extension(one, index_matches); - let first_equality_constraint = - builder.mul_sub_extension(difference, equality_dummy, not_index_matches); - constraints.push(first_equality_constraint); - - let second_equality_constraint = builder.mul_extension(index_matches, difference); - constraints.push(second_equality_constraint); - - // Output constraint. - let diff = builder.sub_extension(list_items[i], claimed_element); - let conditional_diff = builder.mul_extension(index_matches, diff); - constraints.push(conditional_diff); + // Assert that each bit wire value is indeed boolean. + for &b in &bits { + constraints.push(builder.mul_sub_extension(b, b, b)); } + + // Assert that the binary decomposition was correct. + let reconstructed_index = bits + .iter() + .rev() + .fold(zero, |acc, &b| builder.mul_add_extension(acc, two, b)); + constraints.push(builder.sub_extension(reconstructed_index, access_index)); + + // Repeatedly fold the list, selecting the left or right item from each pair based on + // the corresponding bit. + for b in bits { + list_items = list_items + .iter() + .tuples() + .map(|(&x, &y)| builder.select_ext_generalized(b, y, x)) + .collect() + } + + debug_assert_eq!(list_items.len(), 1); + constraints.push(builder.sub_extension(list_items[0], claimed_element)); } constraints @@ -207,7 +233,7 @@ impl, const D: usize> Gate for RandomAccessGa } fn num_wires(&self) -> usize { - self.wire_index_matches_for_index(self.vec_size - 1, self.num_copies - 1) + 1 + self.wire_bit(self.bits - 1, self.num_copies - 1) + 1 } fn num_constants(&self) -> usize { @@ -215,11 +241,12 @@ impl, const D: usize> Gate for RandomAccessGa } fn degree(&self) -> usize { - 2 + self.bits + 1 } fn num_constraints(&self) -> usize { - 3 * self.num_copies * self.vec_size + let constraints_per_copy = self.bits + 2; + self.num_copies * constraints_per_copy } } @@ -238,8 +265,7 @@ impl, const D: usize> SimpleGenerator let mut deps = Vec::new(); deps.push(local_target(self.gate.wire_access_index(self.copy))); - deps.push(local_target(self.gate.wire_claimed_element(self.copy))); - for i in 0..self.gate.vec_size { + for i in 0..self.gate.vec_size() { deps.push(local_target(self.gate.wire_list_item(i, self.copy))); } deps @@ -252,11 +278,12 @@ impl, const D: usize> SimpleGenerator }; let get_local_wire = |input| witness.get_wire(local_wire(input)); + let mut set_local_wire = |input, value| out_buffer.set_wire(local_wire(input), value); - // Compute the new vector and the values for equality_dummy and index_matches - let vec_size = self.gate.vec_size; - let access_index_f = get_local_wire(self.gate.wire_access_index(self.copy)); + let copy = self.copy; + let vec_size = self.gate.vec_size(); + let access_index_f = get_local_wire(self.gate.wire_access_index(copy)); let access_index = access_index_f.to_canonical_u64() as usize; debug_assert!( access_index < vec_size, @@ -265,22 +292,14 @@ impl, const D: usize> SimpleGenerator vec_size ); - for i in 0..vec_size { - let equality_dummy_wire = - local_wire(self.gate.wire_equality_dummy_for_index(i, self.copy)); - let index_matches_wire = - local_wire(self.gate.wire_index_matches_for_index(i, self.copy)); + set_local_wire( + self.gate.wire_claimed_element(copy), + get_local_wire(self.gate.wire_list_item(access_index, copy)), + ); - if i == access_index { - out_buffer.set_wire(equality_dummy_wire, F::ONE); - out_buffer.set_wire(index_matches_wire, F::ONE); - } else { - out_buffer.set_wire( - equality_dummy_wire, - (F::from_canonical_usize(i) - F::from_canonical_usize(access_index)).inverse(), - ); - out_buffer.set_wire(index_matches_wire, F::ZERO); - } + for i in 0..self.gate.bits { + let bit = F::from_bool(((access_index >> i) & 1) != 0); + set_local_wire(self.gate.wire_bit(i, copy), bit); } } } @@ -320,6 +339,7 @@ mod tests { /// Returns the local wires for a random access gate given the vectors, elements to compare, /// and indices. fn get_wires( + bits: usize, lists: Vec>, access_indices: Vec, claimed_elements: Vec, @@ -328,8 +348,7 @@ mod tests { let vec_size = lists[0].len(); let mut v = Vec::new(); - let mut equality_dummy_vals = Vec::new(); - let mut index_matches_vals = Vec::new(); + let mut bit_vals = Vec::new(); for copy in 0..num_copies { let access_index = access_indices[copy]; v.push(F::from_canonical_usize(access_index)); @@ -338,26 +357,17 @@ mod tests { v.push(lists[copy][j]); } - for i in 0..vec_size { - if i == access_index { - equality_dummy_vals.push(F::ONE); - index_matches_vals.push(F::ONE); - } else { - equality_dummy_vals.push( - (F::from_canonical_usize(i) - F::from_canonical_usize(access_index)) - .inverse(), - ); - index_matches_vals.push(F::ZERO); - } + for i in 0..bits { + bit_vals.push(F::from_bool(((access_index >> i) & 1) != 0)); } } - v.extend(equality_dummy_vals); - v.extend(index_matches_vals); + v.extend(bit_vals); - v.iter().map(|&x| x.into()).collect::>() + v.iter().map(|&x| x.into()).collect() } - let vec_size = 3; + let bits = 3; + let vec_size = 1 << bits; let num_copies = 4; let lists = (0..num_copies) .map(|_| F::rand_vec(vec_size)) @@ -366,7 +376,7 @@ mod tests { .map(|_| thread_rng().gen_range(0..vec_size)) .collect::>(); let gate = RandomAccessGate:: { - vec_size, + bits, num_copies, _phantom: PhantomData, }; @@ -378,13 +388,18 @@ mod tests { .collect(); let good_vars = EvaluationVars { local_constants: &[], - local_wires: &get_wires(lists.clone(), access_indices.clone(), good_claimed_elements), + local_wires: &get_wires( + bits, + lists.clone(), + access_indices.clone(), + good_claimed_elements, + ), public_inputs_hash: &HashOut::rand(), }; let bad_claimed_elements = F::rand_vec(4); let bad_vars = EvaluationVars { local_constants: &[], - local_wires: &get_wires(lists, access_indices, bad_claimed_elements), + local_wires: &get_wires(bits, lists, access_indices, bad_claimed_elements), public_inputs_hash: &HashOut::rand(), }; diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 4425f193..9cd380be 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -770,8 +770,8 @@ pub struct BatchedGates, const D: usize> { pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>, pub(crate) free_base_arithmetic: HashMap<(F, F), (usize, usize)>, - /// A map `(c0, c1) -> (g, i)` from constants `vec_size` to an available arithmetic gate using - /// these constants with gate index `g` and already using `i` random accesses. + /// A map `b -> (g, i)` from `b` bits to an available random access gate of that size with gate + /// index `g` and already using `i` random accesses. pub(crate) free_random_access: HashMap, /// `current_switch_gates[chunk_size - 1]` contains None if we have no switch gates with the value @@ -869,32 +869,27 @@ impl, const D: usize> CircuitBuilder { /// Finds the last available random access gate with the given `vec_size` or add one if there aren't any. /// Returns `(g,i)` such that there is a random access gate with the given `vec_size` at index /// `g` and the gate's `i`-th random access is available. - pub(crate) fn find_random_access_gate(&mut self, vec_size: usize) -> (usize, usize) { + pub(crate) fn find_random_access_gate(&mut self, bits: usize) -> (usize, usize) { let (gate, i) = self .batched_gates .free_random_access - .get(&vec_size) + .get(&bits) .copied() .unwrap_or_else(|| { let gate = self.add_gate( - RandomAccessGate::new_from_config(&self.config, vec_size), + RandomAccessGate::new_from_config(&self.config, bits), vec![], ); (gate, 0) }); // Update `free_random_access` with new values. - if i < RandomAccessGate::::max_num_copies( - self.config.num_routed_wires, - self.config.num_wires, - vec_size, - ) - 1 - { + if i + 1 < RandomAccessGate::::new_from_config(&self.config, bits).num_copies { self.batched_gates .free_random_access - .insert(vec_size, (gate, i + 1)); + .insert(bits, (gate, i + 1)); } else { - self.batched_gates.free_random_access.remove(&vec_size); + self.batched_gates.free_random_access.remove(&bits); } (gate, i) @@ -1031,14 +1026,11 @@ impl, const D: usize> CircuitBuilder { /// `RandomAccessGenerator`s are run. fn fill_random_access_gates(&mut self) { let zero = self.zero(); - for (vec_size, (_, i)) in self.batched_gates.free_random_access.clone() { - let max_copies = RandomAccessGate::::max_num_copies( - self.config.num_routed_wires, - self.config.num_wires, - vec_size, - ); + for (bits, (_, i)) in self.batched_gates.free_random_access.clone() { + let max_copies = + RandomAccessGate::::new_from_config(&self.config, bits).num_copies; for _ in i..max_copies { - self.random_access(zero, zero, vec![zero; vec_size]); + self.random_access(zero, zero, vec![zero; 1 << bits]); } } } From 799ff26e71c03aec171fa0eb94abec8e11e3549e Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 15 Nov 2021 19:46:28 +0100 Subject: [PATCH 104/202] Avoid underflow when checking the length of `terms` --- src/util/reducing.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/util/reducing.rs b/src/util/reducing.rs index 1cddc3e7..f700a6ff 100644 --- a/src/util/reducing.rs +++ b/src/util/reducing.rs @@ -107,7 +107,7 @@ impl ReducingFactorTarget { let l = terms.len(); // For small reductions, use an arithmetic gate. - if l - 1 <= ArithmeticExtensionGate::::new_from_config(&builder.config).num_ops { + if l <= ArithmeticExtensionGate::::new_from_config(&builder.config).num_ops + 1 { let terms_ext = terms .iter() .map(|&t| builder.convert_to_ext(t)) @@ -162,7 +162,7 @@ impl ReducingFactorTarget { let l = terms.len(); // For small reductions, use an arithmetic gate. - if l - 1 <= ArithmeticExtensionGate::::new_from_config(&builder.config).num_ops { + if l <= ArithmeticExtensionGate::::new_from_config(&builder.config).num_ops + 1 { return self.reduce_arithmetic(terms, builder); } From 694b3d3dd55fb5e93cd6e607c0ed4b7470be3e72 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Mon, 15 Nov 2021 13:59:49 -0800 Subject: [PATCH 105/202] Recursion in 2^12 gates (#364) For now, we can do shrinking recursion with 93 bits of security. It's not quite as high as we want, but it's close, and I think it makes sense to merge this and treat the 2^12 circuit as our main benchmark, as we continue working to improve security. --- src/plonk/circuit_data.rs | 10 ++++----- src/plonk/recursive_verifier.rs | 37 ++++++++++++++++++++++++--------- 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index a26814c3..3197b4e0 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -60,18 +60,18 @@ impl CircuitConfig { pub(crate) fn standard_recursion_config() -> Self { Self { num_wires: 135, - num_routed_wires: 25, - constant_gate_size: 6, + num_routed_wires: 80, + constant_gate_size: 8, use_base_arithmetic_gate: true, - security_bits: 100, + security_bits: 93, rate_bits: 3, num_challenges: 2, zero_knowledge: false, cap_height: 3, fri_config: FriConfig { - proof_of_work_bits: 16, + proof_of_work_bits: 15, reduction_strategy: FriReductionStrategy::ConstantArityBits(3, 5), - num_query_rounds: 28, + num_query_rounds: 26, }, } } diff --git a/src/plonk/recursive_verifier.rs b/src/plonk/recursive_verifier.rs index ad049275..147b585e 100644 --- a/src/plonk/recursive_verifier.rs +++ b/src/plonk/recursive_verifier.rs @@ -134,6 +134,7 @@ mod tests { use crate::fri::reduction_strategies::FriReductionStrategy; use crate::fri::FriConfig; use crate::gadgets::polynomial::PolynomialCoeffsExtTarget; + use crate::gates::noop::NoopGate; use crate::hash::merkle_proofs::MerkleProofTarget; use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_data::VerifierOnlyCircuitData; @@ -368,8 +369,8 @@ mod tests { const D: usize = 2; let config = CircuitConfig::standard_recursion_config(); - let (proof, vd, cd) = dummy_proof::(&config, 8_000)?; - let (proof, _vd, cd) = recursive_proof(proof, vd, cd, &config, &config, true, true)?; + let (proof, vd, cd) = dummy_proof::(&config, 4_000)?; + let (proof, _vd, cd) = recursive_proof(proof, vd, cd, &config, &config, None, true, true)?; test_serialization(&proof, &cd)?; Ok(()) @@ -384,9 +385,14 @@ mod tests { let config = CircuitConfig::standard_recursion_config(); - let (proof, vd, cd) = dummy_proof::(&config, 8_000)?; - let (proof, vd, cd) = recursive_proof(proof, vd, cd, &config, &config, false, false)?; - let (proof, _vd, cd) = recursive_proof(proof, vd, cd, &config, &config, true, true)?; + // Start with a degree 2^14 proof, then shrink it to 2^13, then to 2^12. + let (proof, vd, cd) = dummy_proof::(&config, 16_000)?; + assert_eq!(cd.degree_bits, 14); + let (proof, vd, cd) = + recursive_proof(proof, vd, cd, &config, &config, Some(13), false, false)?; + assert_eq!(cd.degree_bits, 13); + let (proof, _vd, cd) = recursive_proof(proof, vd, cd, &config, &config, None, true, true)?; + assert_eq!(cd.degree_bits, 12); test_serialization(&proof, &cd)?; @@ -415,6 +421,7 @@ mod tests { cd, &standard_config, &standard_config, + None, false, false, )?; @@ -437,6 +444,7 @@ mod tests { cd, &standard_config, &high_rate_config, + None, true, true, )?; @@ -460,6 +468,7 @@ mod tests { cd, &high_rate_config, &higher_rate_more_routing_config, + None, true, true, )?; @@ -481,6 +490,7 @@ mod tests { cd, &higher_rate_more_routing_config, &final_config, + None, true, true, )?; @@ -501,16 +511,12 @@ mod tests { CommonCircuitData, )> { let mut builder = CircuitBuilder::::new(config.clone()); - let input = builder.add_virtual_target(); for i in 0..num_dummy_gates { - // Use unique constants to force a new `ArithmeticGate`. - let i_f = F::from_canonical_u64(i); - builder.arithmetic(i_f, i_f, input, input, input); + builder.add_gate(NoopGate, vec![]); } let data = builder.build(); let mut inputs = PartialWitness::new(); - inputs.set_target(input, F::ZERO); let proof = data.prove(inputs)?; data.verify(proof.clone())?; @@ -523,6 +529,7 @@ mod tests { inner_cd: CommonCircuitData, inner_config: &CircuitConfig, config: &CircuitConfig, + min_degree_bits: Option, print_gate_counts: bool, print_timing: bool, ) -> Result<( @@ -549,6 +556,16 @@ mod tests { builder.print_gate_counts(0); } + if let Some(min_degree_bits) = min_degree_bits { + // We don't want to pad all the way up to 2^min_degree_bits, as the builder will add a + // few special gates afterward. So just pad to 2^(min_degree_bits - 1) + 1. Then the + // builder will pad to the next power of two, 2^min_degree_bits. + let min_gates = (1 << (min_degree_bits - 1)) + 1; + for _ in builder.num_gates()..min_gates { + builder.add_gate(NoopGate, vec![]); + } + } + let data = builder.build(); let mut timing = TimingTree::new("prove", Level::Debug); From 4769efa4ddba7595e08b81d6f1d3b4c1d84b00db Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Mon, 15 Nov 2021 19:33:03 -0800 Subject: [PATCH 106/202] rename --- src/hash/merkle_proofs.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hash/merkle_proofs.rs b/src/hash/merkle_proofs.rs index 1ba93cf0..b9381ed1 100644 --- a/src/hash/merkle_proofs.rs +++ b/src/hash/merkle_proofs.rs @@ -115,7 +115,7 @@ impl, const D: usize> CircuitBuilder { } } - pub fn assert_hashes_equal(&mut self, x: HashOutTarget, y: HashOutTarget) { + pub fn connect_hashes(&mut self, x: HashOutTarget, y: HashOutTarget) { for i in 0..4 { self.connect(x.elements[i], y.elements[i]); } From 909a5c23925b9d49d245ed1ea824f397cb131c11 Mon Sep 17 00:00:00 2001 From: Hamish Ivey-Law <426294+unzvfu@users.noreply.github.com> Date: Tue, 16 Nov 2021 21:18:27 +1100 Subject: [PATCH 107/202] Fix all lint warnings (#353) * Suppress warnings about use of unstable compiler features. * Remove unused functions. * Refactor and remove PolynomialCoeffs::new_padded(); fix degree_padded. Note that this fixes a minor mistake in the FFT testing code, where `degree_padded` value was log2 of what it should have been, preventing a testing loop from executing. * Remove divide_by_z_h() and related test functions. * Only compile check_{consistency,test_vectors} when testing. * Move verify() to test module. * Remove unused functions. NB: Changed the config in the gadgets/arithmetic_extension.rs::tests module which may change the test's meaning? * Remove unused import. * Mark GMiMC option as allowed 'dead code'. * Fix missing feature. * Remove unused functions. * cargo fmt * Mark variable as unused. * Revert "Remove unused functions." This reverts commit 99d2357f1c967fd9fd6cac63e1216d929888be72. * Make config functions public. * Mark 'reduce_nonnative()' as dead code for now. * Revert "Move verify() to test module." Refactor to `verify_compressed`. This reverts commit b426e810d033c642f54e25ebc4a8114491df5076. * cargo fmt * Reinstate `verify()` fn on `CompressedProofWithPublicInputs`. --- src/field/fft.rs | 18 +++-- src/field/packed_avx2/packed_prime_field.rs | 15 ---- src/gadgets/nonnative.rs | 1 + src/gadgets/permutation.rs | 1 - src/hash/hashing.rs | 1 + src/hash/merkle_proofs.rs | 1 + src/hash/poseidon.rs | 1 + src/lib.rs | 3 + src/plonk/circuit_data.rs | 20 ++++- src/plonk/plonk_common.rs | 11 --- src/plonk/proof.rs | 10 +-- src/plonk/prover.rs | 2 +- src/polynomial/division.rs | 88 +-------------------- src/polynomial/polynomial.rs | 44 ----------- 14 files changed, 42 insertions(+), 174 deletions(-) diff --git a/src/field/fft.rs b/src/field/fft.rs index 96f19857..17c29184 100644 --- a/src/field/fft.rs +++ b/src/field/fft.rs @@ -219,13 +219,17 @@ mod tests { #[test] fn fft_and_ifft() { type F = GoldilocksField; - let degree = 200; - let degree_padded = log2_ceil(degree); - let mut coefficients = Vec::new(); - for i in 0..degree { - coefficients.push(F::from_canonical_usize(i * 1337 % 100)); - } - let coefficients = PolynomialCoeffs::new_padded(coefficients); + let degree = 200usize; + let degree_padded = degree.next_power_of_two(); + + // Create a vector of coeffs; the first degree of them are + // "random", the last degree_padded-degree of them are zero. + let coeffs = (0..degree) + .map(|i| F::from_canonical_usize(i * 1337 % 100)) + .chain(std::iter::repeat(F::ZERO).take(degree_padded - degree)) + .collect::>(); + assert_eq!(coeffs.len(), degree_padded); + let coefficients = PolynomialCoeffs { coeffs }; let points = fft(&coefficients); assert_eq!(points, evaluate_naive(&coefficients)); diff --git a/src/field/packed_avx2/packed_prime_field.rs b/src/field/packed_avx2/packed_prime_field.rs index ed87f347..b892da4a 100644 --- a/src/field/packed_avx2/packed_prime_field.rs +++ b/src/field/packed_avx2/packed_prime_field.rs @@ -43,13 +43,6 @@ impl PackedPrimeField { let ptr = (&self.0).as_ptr().cast::<__m256i>(); unsafe { _mm256_loadu_si256(ptr) } } - - /// Addition that assumes x + y < 2^64 + F::ORDER. May return incorrect results if this - /// condition is not met, hence it is marked unsafe. - #[inline] - pub unsafe fn add_canonical_u64(&self, rhs: __m256i) -> Self { - Self::new(add_canonical_u64::(self.get(), rhs)) - } } impl Add for PackedPrimeField { @@ -293,14 +286,6 @@ unsafe fn canonicalize_s(x_s: __m256i) -> __m256i { _mm256_add_epi64(x_s, wrapback_amt) } -/// Addition that assumes x + y < 2^64 + F::ORDER. -#[inline] -unsafe fn add_canonical_u64(x: __m256i, y: __m256i) -> __m256i { - let y_s = shift(y); - let res_s = add_no_canonicalize_64_64s_s::(x, y_s); - shift(res_s) -} - #[inline] unsafe fn add(x: __m256i, y: __m256i) -> __m256i { let y_s = shift(y); diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index fd883e5d..84691421 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -96,6 +96,7 @@ impl, const D: usize> CircuitBuilder { } } + #[allow(dead_code)] fn reduce_nonnative( &mut self, x: &ForeignFieldTarget, diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs index c60eda7d..37169514 100644 --- a/src/gadgets/permutation.rs +++ b/src/gadgets/permutation.rs @@ -3,7 +3,6 @@ use std::marker::PhantomData; use crate::field::field_types::RichField; use crate::field::{extension_field::Extendable, field_types::Field}; -use crate::gates::switch::SwitchGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::Target; use crate::iop::witness::{PartitionWitness, Witness}; diff --git a/src/hash/hashing.rs b/src/hash/hashing.rs index d031ebbb..39b3f51e 100644 --- a/src/hash/hashing.rs +++ b/src/hash/hashing.rs @@ -15,6 +15,7 @@ pub const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; pub(crate) const HASH_FAMILY: HashFamily = HashFamily::Poseidon; pub(crate) enum HashFamily { + #[allow(dead_code)] GMiMC, Poseidon, } diff --git a/src/hash/merkle_proofs.rs b/src/hash/merkle_proofs.rs index b9381ed1..7b766895 100644 --- a/src/hash/merkle_proofs.rs +++ b/src/hash/merkle_proofs.rs @@ -54,6 +54,7 @@ pub(crate) fn verify_merkle_proof( impl, const D: usize> CircuitBuilder { /// Verifies that the given leaf data is present at the given index in the Merkle tree with the /// given cap. The index is given by it's little-endian bits. + #[cfg(test)] pub(crate) fn verify_merkle_proof( &mut self, leaf_data: Vec, diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index 4450e659..03dce1f3 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -627,6 +627,7 @@ where } } +#[cfg(test)] pub(crate) mod test_helpers { use crate::field::field_types::Field; use crate::hash::poseidon::Poseidon; diff --git a/src/lib.rs b/src/lib.rs index 3ed9f747..46db2cf5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,7 @@ +#![allow(incomplete_features)] +#![allow(const_evaluatable_unchecked)] #![feature(asm)] +#![feature(asm_sym)] #![feature(destructuring_assignment)] #![feature(generic_const_exprs)] #![feature(specialization)] diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 3197b4e0..e24a5c6f 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -15,7 +15,7 @@ use crate::hash::merkle_tree::MerkleCap; use crate::iop::generator::WitnessGenerator; use crate::iop::target::Target; use crate::iop::witness::PartialWitness; -use crate::plonk::proof::ProofWithPublicInputs; +use crate::plonk::proof::{CompressedProofWithPublicInputs, ProofWithPublicInputs}; use crate::plonk::prover::prove; use crate::plonk::verifier::verify; use crate::util::marking::MarkedTargets; @@ -57,7 +57,7 @@ impl CircuitConfig { } /// A typical recursion config, without zero-knowledge, targeting ~100 bit security. - pub(crate) fn standard_recursion_config() -> Self { + pub fn standard_recursion_config() -> Self { Self { num_wires: 135, num_routed_wires: 80, @@ -76,7 +76,7 @@ impl CircuitConfig { } } - pub(crate) fn standard_recursion_zk_config() -> Self { + pub fn standard_recursion_zk_config() -> Self { CircuitConfig { zero_knowledge: true, ..Self::standard_recursion_config() @@ -104,6 +104,13 @@ impl, const D: usize> CircuitData { pub fn verify(&self, proof_with_pis: ProofWithPublicInputs) -> Result<()> { verify(proof_with_pis, &self.verifier_only, &self.common) } + + pub fn verify_compressed( + &self, + compressed_proof_with_pis: CompressedProofWithPublicInputs, + ) -> Result<()> { + compressed_proof_with_pis.verify(&self.verifier_only, &self.common) + } } /// Circuit data required by the prover. This may be thought of as a proving key, although it @@ -140,6 +147,13 @@ impl, const D: usize> VerifierCircuitData { pub fn verify(&self, proof_with_pis: ProofWithPublicInputs) -> Result<()> { verify(proof_with_pis, &self.verifier_only, &self.common) } + + pub fn verify_compressed( + &self, + compressed_proof_with_pis: CompressedProofWithPublicInputs, + ) -> Result<()> { + compressed_proof_with_pis.verify(&self.verifier_only, &self.common) + } } /// Circuit data required by the prover, but not the verifier. diff --git a/src/plonk/plonk_common.rs b/src/plonk/plonk_common.rs index 6b84886d..5be13740 100644 --- a/src/plonk/plonk_common.rs +++ b/src/plonk/plonk_common.rs @@ -42,17 +42,6 @@ impl PlonkPolynomials { index: 3, blinding: true, }; - - #[cfg(test)] - pub fn polynomials(i: usize) -> PolynomialsIndexBlinding { - match i { - 0 => Self::CONSTANTS_SIGMAS, - 1 => Self::WIRES, - 2 => Self::ZS_PARTIAL_PRODUCTS, - 3 => Self::QUOTIENT, - _ => panic!("There are only 4 sets of polynomials in Plonk."), - } - } } /// Evaluate the polynomial which vanishes on any multiplicative subgroup of a given order `n`. diff --git a/src/plonk/proof.rs b/src/plonk/proof.rs index ce1207cd..815f807d 100644 --- a/src/plonk/proof.rs +++ b/src/plonk/proof.rs @@ -161,12 +161,12 @@ impl, const D: usize> CompressedProofWithPublicInpu ) -> anyhow::Result> { let challenges = self.get_challenges(common_data)?; let fri_inferred_elements = self.get_inferred_elements(&challenges, common_data); - let compressed_proof = + let decompressed_proof = self.proof .decompress(&challenges, fri_inferred_elements, common_data); Ok(ProofWithPublicInputs { public_inputs: self.public_inputs, - proof: compressed_proof, + proof: decompressed_proof, }) } @@ -177,13 +177,13 @@ impl, const D: usize> CompressedProofWithPublicInpu ) -> anyhow::Result<()> { let challenges = self.get_challenges(common_data)?; let fri_inferred_elements = self.get_inferred_elements(&challenges, common_data); - let compressed_proof = + let decompressed_proof = self.proof .decompress(&challenges, fri_inferred_elements, common_data); verify_with_challenges( ProofWithPublicInputs { public_inputs: self.public_inputs, - proof: compressed_proof, + proof: decompressed_proof, }, challenges, verifier_data, @@ -346,6 +346,6 @@ mod tests { assert_eq!(proof, decompressed_compressed_proof); verify(proof, &data.verifier_only, &data.common)?; - compressed_proof.verify(&data.verifier_only, &data.common) + data.verify_compressed(compressed_proof) } } diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 1dd17cb8..3f8e607d 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -237,7 +237,7 @@ fn wires_permutation_partial_products_and_zs, const let degree = common_data.quotient_degree_factor; let subgroup = &prover_data.subgroup; let k_is = &common_data.k_is; - let (num_prods, final_num_prod) = common_data.num_partial_products; + let (num_prods, _final_num_prod) = common_data.num_partial_products; let all_quotient_chunk_products = subgroup .par_iter() .enumerate() diff --git a/src/polynomial/division.rs b/src/polynomial/division.rs index 6ac38676..b5dad629 100644 --- a/src/polynomial/division.rs +++ b/src/polynomial/division.rs @@ -1,7 +1,6 @@ -use crate::field::fft::{fft, ifft}; use crate::field::field_types::Field; use crate::polynomial::polynomial::PolynomialCoeffs; -use crate::util::{log2_ceil, log2_strict}; +use crate::util::log2_ceil; impl PolynomialCoeffs { /// Polynomial division. @@ -67,63 +66,6 @@ impl PolynomialCoeffs { } } - /// Takes a polynomial `a` in coefficient form, and divides it by `Z_H = X^n - 1`. - /// - /// This assumes `Z_H | a`, otherwise result is meaningless. - pub(crate) fn divide_by_z_h(&self, n: usize) -> PolynomialCoeffs { - let mut a = self.clone(); - - // TODO: Is this special case needed? - if a.coeffs.iter().all(|p| *p == F::ZERO) { - return a; - } - - let g = F::MULTIPLICATIVE_GROUP_GENERATOR; - let mut g_pow = F::ONE; - // Multiply the i-th coefficient of `a` by `g^i`. Then `new_a(w^j) = old_a(g.w^j)`. - a.coeffs.iter_mut().for_each(|x| { - *x *= g_pow; - g_pow *= g; - }); - - let root = F::primitive_root_of_unity(log2_strict(a.len())); - // Equals to the evaluation of `a` on `{g.w^i}`. - let mut a_eval = fft(&a); - // Compute the denominators `1/(g^n.w^(n*i) - 1)` using batch inversion. - let denominator_g = g.exp_u64(n as u64); - let root_n = root.exp_u64(n as u64); - let mut root_pow = F::ONE; - let denominators = (0..a_eval.len()) - .map(|i| { - if i != 0 { - root_pow *= root_n; - } - denominator_g * root_pow - F::ONE - }) - .collect::>(); - let denominators_inv = F::batch_multiplicative_inverse(&denominators); - // Divide every element of `a_eval` by the corresponding denominator. - // Then, `a_eval` is the evaluation of `a/Z_H` on `{g.w^i}`. - a_eval - .values - .iter_mut() - .zip(denominators_inv.iter()) - .for_each(|(x, &d)| { - *x *= d; - }); - // `p` is the interpolating polynomial of `a_eval` on `{w^i}`. - let mut p = ifft(&a_eval); - // We need to scale it by `g^(-i)` to get the interpolating polynomial of `a_eval` on `{g.w^i}`, - // a.k.a `a/Z_H`. - let g_inv = g.inverse(); - let mut g_inv_pow = F::ONE; - p.coeffs.iter_mut().for_each(|x| { - *x *= g_inv_pow; - g_inv_pow *= g_inv; - }); - p - } - /// Let `self=p(X)`, this returns `(p(X)-p(z))/(X-z)` and `p(z)`. /// See https://en.wikipedia.org/wiki/Horner%27s_method pub(crate) fn divide_by_linear(&self, z: F) -> (PolynomialCoeffs, F) { @@ -189,34 +131,6 @@ mod tests { use crate::field::goldilocks_field::GoldilocksField; use crate::polynomial::polynomial::PolynomialCoeffs; - #[test] - fn zero_div_z_h() { - type F = GoldilocksField; - let zero = PolynomialCoeffs::::zero(16); - let quotient = zero.divide_by_z_h(4); - assert_eq!(quotient, zero); - } - - #[test] - fn division_by_z_h() { - type F = GoldilocksField; - let zero = F::ZERO; - let three = F::from_canonical_u64(3); - let four = F::from_canonical_u64(4); - let five = F::from_canonical_u64(5); - let six = F::from_canonical_u64(6); - - // a(x) = Z_4(x) q(x), where - // a(x) = 3 x^7 + 4 x^6 + 5 x^5 + 6 x^4 - 3 x^3 - 4 x^2 - 5 x - 6 - // Z_4(x) = x^4 - 1 - // q(x) = 3 x^3 + 4 x^2 + 5 x + 6 - let a = PolynomialCoeffs::new(vec![-six, -five, -four, -three, six, five, four, three]); - let q = PolynomialCoeffs::new(vec![six, five, four, three, zero, zero, zero, zero]); - - let computed_q = a.divide_by_z_h(4); - assert_eq!(computed_q, q); - } - #[test] #[ignore] fn test_division_by_linear() { diff --git a/src/polynomial/polynomial.rs b/src/polynomial/polynomial.rs index 107d7a7b..a021ecd2 100644 --- a/src/polynomial/polynomial.rs +++ b/src/polynomial/polynomial.rs @@ -24,10 +24,6 @@ impl PolynomialValues { PolynomialValues { values } } - pub(crate) fn zero(len: usize) -> Self { - Self::new(vec![F::ZERO; len]) - } - /// The number of values stored. pub(crate) fn len(&self) -> usize { self.values.len() @@ -88,14 +84,6 @@ impl PolynomialCoeffs { PolynomialCoeffs { coeffs } } - /// Create a new polynomial with its coefficient list padded to the next power of two. - pub(crate) fn new_padded(mut coeffs: Vec) -> Self { - while !coeffs.len().is_power_of_two() { - coeffs.push(F::ZERO); - } - PolynomialCoeffs { coeffs } - } - pub(crate) fn empty() -> Self { Self::new(Vec::new()) } @@ -104,10 +92,6 @@ impl PolynomialCoeffs { Self::new(vec![F::ZERO; len]) } - pub(crate) fn one() -> Self { - Self::new(vec![F::ONE]) - } - pub(crate) fn is_zero(&self) -> bool { self.coeffs.iter().all(|x| x.is_zero()) } @@ -538,34 +522,6 @@ mod tests { } } - #[test] - fn test_division_by_z_h() { - type F = GoldilocksField; - let mut rng = thread_rng(); - let a_deg = rng.gen_range(1..10_000); - let n = rng.gen_range(1..a_deg); - let mut a = PolynomialCoeffs::new(F::rand_vec(a_deg)); - a.trim(); - let z_h = { - let mut z_h_vec = vec![F::ZERO; n + 1]; - z_h_vec[n] = F::ONE; - z_h_vec[0] = F::NEG_ONE; - PolynomialCoeffs::new(z_h_vec) - }; - let m = &a * &z_h; - let now = Instant::now(); - let mut a_test = m.divide_by_z_h(n); - a_test.trim(); - println!("Division time: {:?}", now.elapsed()); - assert_eq!(a, a_test); - } - - #[test] - fn divide_zero_poly_by_z_h() { - let zero_poly = PolynomialCoeffs::::empty(); - zero_poly.divide_by_z_h(16); - } - // Test to see which polynomial division method is faster for divisions of the type // `(X^n - 1)/(X - a) #[test] From 1e66cb9aeec1a1e7d67702566771417118e7798e Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Tue, 16 Nov 2021 09:28:58 -0800 Subject: [PATCH 108/202] Route in constants from a ConstantGate (#367) Rather than creating arithmetic gates with potentially unique constants. Should be strictly cheaper, though it only seems to save one gate in practice. --- src/gadgets/arithmetic.rs | 12 ++++++------ src/gadgets/arithmetic_extension.rs | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index a053b761..00dbbe21 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -153,20 +153,20 @@ impl, const D: usize> CircuitBuilder { /// Computes `x + C`. pub fn add_const(&mut self, x: Target, c: F) -> Target { - let one = self.one(); - self.arithmetic(F::ONE, c, one, x, one) + let c = self.constant(c); + self.add(x, c) } /// Computes `C * x`. pub fn mul_const(&mut self, c: F, x: Target) -> Target { - let zero = self.zero(); - self.mul_const_add(c, x, zero) + let c = self.constant(c); + self.mul(c, x) } /// Computes `C * x + y`. pub fn mul_const_add(&mut self, c: F, x: Target, y: Target) -> Target { - let one = self.one(); - self.arithmetic(c, F::ONE, x, one, y) + let c = self.constant(c); + self.mul_add(c, x, y) } /// Computes `x * y - z`. diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 9fbffad3..5716ddc3 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -294,14 +294,14 @@ impl, const D: usize> CircuitBuilder { /// Like `add_const`, but for `ExtensionTarget`s. pub fn add_const_extension(&mut self, x: ExtensionTarget, c: F) -> ExtensionTarget { - let one = self.one_extension(); - self.arithmetic_extension(F::ONE, c, one, x, one) + let c = self.constant_extension(c.into()); + self.add_extension(x, c) } /// Like `mul_const`, but for `ExtensionTarget`s. pub fn mul_const_extension(&mut self, c: F, x: ExtensionTarget) -> ExtensionTarget { - let zero = self.zero_extension(); - self.mul_const_add_extension(c, x, zero) + let c = self.constant_extension(c.into()); + self.mul_extension(c, x) } /// Like `mul_const_add`, but for `ExtensionTarget`s. @@ -311,8 +311,8 @@ impl, const D: usize> CircuitBuilder { x: ExtensionTarget, y: ExtensionTarget, ) -> ExtensionTarget { - let one = self.one_extension(); - self.arithmetic_extension(c, F::ONE, x, one, y) + let c = self.constant_extension(c.into()); + self.mul_add_extension(c, x, y) } /// Like `mul_add`, but for `ExtensionTarget`s. From eb5a60bef110d1a0b2fd731ea9a4dac6e780ac63 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Tue, 16 Nov 2021 09:29:14 -0800 Subject: [PATCH 109/202] Allow one BaseSumGate to handle 64 bits (#365) --- src/gadgets/split_join.rs | 23 +++++++++++++++-------- src/gates/base_sum.rs | 3 +-- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/gadgets/split_join.rs b/src/gadgets/split_join.rs index 39527c6a..72786bd8 100644 --- a/src/gadgets/split_join.rs +++ b/src/gadgets/split_join.rs @@ -24,8 +24,7 @@ impl, const D: usize> CircuitBuilder { let mut bits = Vec::with_capacity(num_bits); for &gate in &gates { - let start_limbs = BaseSumGate::<2>::START_LIMBS; - for limb_input in start_limbs..start_limbs + gate_type.num_limbs { + for limb_input in gate_type.limbs() { // `new_unsafe` is safe here because BaseSumGate::<2> forces it to be in `{0, 1}`. bits.push(BoolTarget::new_unsafe(Target::wire(gate, limb_input))); } @@ -35,10 +34,11 @@ impl, const D: usize> CircuitBuilder { } let zero = self.zero(); + let base = F::TWO.exp_u64(gate_type.num_limbs as u64); let mut acc = zero; for &gate in gates.iter().rev() { let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM); - acc = self.mul_const_add(F::from_canonical_usize(1 << gate_type.num_limbs), acc, sum); + acc = self.mul_const_add(base, acc, sum); } self.connect(acc, integer); @@ -96,11 +96,18 @@ impl SimpleGenerator for WireSplitGenerator { for &gate in &self.gates { let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM); - out_buffer.set_target( - sum, - F::from_canonical_u64(integer_value & ((1 << self.num_limbs) - 1)), - ); - integer_value >>= self.num_limbs; + + // If num_limbs >= 64, we don't need to truncate since `integer_value` is already + // limited to 64 bits, and trying to do so would cause overflow. Hence the conditional. + let mut truncated_value = integer_value; + if self.num_limbs < 64 { + truncated_value = integer_value & ((1 << self.num_limbs) - 1); + integer_value >>= self.num_limbs; + } else { + integer_value = 0; + }; + + out_buffer.set_target(sum, F::from_canonical_u64(truncated_value)); } debug_assert_eq!( diff --git a/src/gates/base_sum.rs b/src/gates/base_sum.rs index 99ee05eb..2ab5345b 100644 --- a/src/gates/base_sum.rs +++ b/src/gates/base_sum.rs @@ -24,8 +24,7 @@ impl BaseSumGate { } pub fn new_from_config(config: &CircuitConfig) -> Self { - let num_limbs = ((F::ORDER as f64).log(B as f64).floor() as usize) - .min(config.num_routed_wires - Self::START_LIMBS); + let num_limbs = F::bits().min(config.num_routed_wires - Self::START_LIMBS); Self::new(num_limbs) } From 8b710751541380ed64fbfd8b2bc5a3ffaab41a9b Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Tue, 16 Nov 2021 09:29:22 -0800 Subject: [PATCH 110/202] Reduce constant_gate_size to 5 (#366) This results in 8 constant polynomials, which means our Merkle tree containing preprocessed polynomials has leaves of size 80 + 8 = 88. A multiple of 8 is efficient in terms of how many gates it takes to hash a leaf. Saves 17 gates. --- src/hash/merkle_proofs.rs | 2 +- src/plonk/circuit_data.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hash/merkle_proofs.rs b/src/hash/merkle_proofs.rs index 7b766895..20e2271a 100644 --- a/src/hash/merkle_proofs.rs +++ b/src/hash/merkle_proofs.rs @@ -94,7 +94,7 @@ impl, const D: usize> CircuitBuilder { proof: &MerkleProofTarget, ) { let zero = self.zero(); - let mut state: HashOutTarget = self.hash_or_noop(leaf_data); + let mut state = self.hash_or_noop(leaf_data); for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) { let mut perm_inputs = [zero; SPONGE_WIDTH]; diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index e24a5c6f..6721552c 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -61,7 +61,7 @@ impl CircuitConfig { Self { num_wires: 135, num_routed_wires: 80, - constant_gate_size: 8, + constant_gate_size: 5, use_base_arithmetic_gate: true, security_bits: 93, rate_bits: 3, From eb27a2d2b2937c550a33bd89da13e95c0c734dce Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Tue, 16 Nov 2021 22:51:38 -0800 Subject: [PATCH 111/202] warnings --- src/fri/mod.rs | 4 ---- src/gadgets/split_base.rs | 5 +---- src/plonk/recursive_verifier.rs | 4 ++-- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/fri/mod.rs b/src/fri/mod.rs index 2419b06b..bfb2ebfd 100644 --- a/src/fri/mod.rs +++ b/src/fri/mod.rs @@ -36,8 +36,4 @@ impl FriParams { pub(crate) fn max_arity_bits(&self) -> Option { self.reduction_arity_bits.iter().copied().max() } - - pub(crate) fn max_arity(&self) -> Option { - self.max_arity_bits().map(|bits| 1 << bits) - } } diff --git a/src/gadgets/split_base.rs b/src/gadgets/split_base.rs index d60324ce..ade2ab0c 100644 --- a/src/gadgets/split_base.rs +++ b/src/gadgets/split_base.rs @@ -29,10 +29,7 @@ impl, const D: usize> CircuitBuilder { /// Takes an iterator of bits `(b_i)` and returns `sum b_i * 2^i`, i.e., /// the number with little-endian bit representation given by `bits`. - pub(crate) fn le_sum( - &mut self, - mut bits: impl Iterator>, - ) -> Target { + pub(crate) fn le_sum(&mut self, bits: impl Iterator>) -> Target { let bits = bits.map(|b| *b.borrow()).collect_vec(); let num_bits = bits.len(); if num_bits == 0 { diff --git a/src/plonk/recursive_verifier.rs b/src/plonk/recursive_verifier.rs index 147b585e..7e5a91a3 100644 --- a/src/plonk/recursive_verifier.rs +++ b/src/plonk/recursive_verifier.rs @@ -511,12 +511,12 @@ mod tests { CommonCircuitData, )> { let mut builder = CircuitBuilder::::new(config.clone()); - for i in 0..num_dummy_gates { + for _ in 0..num_dummy_gates { builder.add_gate(NoopGate, vec![]); } let data = builder.build(); - let mut inputs = PartialWitness::new(); + let inputs = PartialWitness::new(); let proof = data.prove(inputs)?; data.verify(proof.clone())?; From eb15837acb15afd4a19dd39c791698e67a7d07fe Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Tue, 16 Nov 2021 22:53:08 -0800 Subject: [PATCH 112/202] tweak logs --- src/gates/gate_tree.rs | 6 ++++-- src/plonk/circuit_builder.rs | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/gates/gate_tree.rs b/src/gates/gate_tree.rs index aaba41c7..704b410a 100644 --- a/src/gates/gate_tree.rs +++ b/src/gates/gate_tree.rs @@ -1,4 +1,4 @@ -use log::info; +use log::debug; use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; @@ -86,7 +86,7 @@ impl, const D: usize> Tree> { } } } - info!( + debug!( "Found tree with max degree {} and {} constants wires in {:.4}s.", best_degree, best_num_constants, @@ -221,6 +221,8 @@ impl, const D: usize> Tree> { #[cfg(test)] mod tests { + use log::info; + use super::*; use crate::field::goldilocks_field::GoldilocksField; use crate::gates::arithmetic_extension::ArithmeticExtensionGate; diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 9cd380be..fc9eef76 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -622,7 +622,7 @@ impl, const D: usize> CircuitBuilder { ..=1 << self.config.rate_bits) .min_by_key(|&q| num_partial_products(self.config.num_routed_wires, q).0 + q) .unwrap(); - info!("Quotient degree factor set to: {}.", quotient_degree_factor); + debug!("Quotient degree factor set to: {}.", quotient_degree_factor); let prefixed_gates = PrefixedGate::from_tree(gate_tree); let subgroup = F::two_adic_subgroup(degree_bits); @@ -725,7 +725,7 @@ impl, const D: usize> CircuitBuilder { circuit_digest, }; - info!("Building circuit took {}s", start.elapsed().as_secs_f32()); + debug!("Building circuit took {}s", start.elapsed().as_secs_f32()); CircuitData { prover_only, verifier_only, From 8772073b36a4f5773fdd5aa24c886b7dafa34578 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Wed, 17 Nov 2021 08:13:20 -0800 Subject: [PATCH 113/202] Update size-optimized proof test (#368) The 2^12 change made this outdated. We no longer need to shrink degree (since normal recursive proofs are 2^12), so we can simplify a bit. We just boost the rate, then do a size-optimized proof. (Without doing the rate boost first, the final proof would be over 2^12.) Configured for 93 bits security for now, but the PoW settings are low so that'll be easy to increase. ~45kb with current settings. --- src/plonk/recursive_verifier.rs | 55 ++++++++++----------------------- 1 file changed, 16 insertions(+), 39 deletions(-) diff --git a/src/plonk/recursive_verifier.rs b/src/plonk/recursive_verifier.rs index 7e5a91a3..2dc0223b 100644 --- a/src/plonk/recursive_verifier.rs +++ b/src/plonk/recursive_verifier.rs @@ -410,11 +410,11 @@ mod tests { let standard_config = CircuitConfig::standard_recursion_config(); - // A dummy proof with degree 2^13. - let (proof, vd, cd) = dummy_proof::(&standard_config, 8_000)?; - assert_eq!(cd.degree_bits, 13); + // An initial dummy proof. + let (proof, vd, cd) = dummy_proof::(&standard_config, 4_000)?; + assert_eq!(cd.degree_bits, 12); - // A standard recursive proof with degree 2^13. + // A standard recursive proof. let (proof, vd, cd) = recursive_proof( proof, vd, @@ -425,15 +425,14 @@ mod tests { false, false, )?; - assert_eq!(cd.degree_bits, 13); + assert_eq!(cd.degree_bits, 12); - // A high-rate recursive proof with degree 2^13, designed to be verifiable with 2^12 - // gates and 48 routed wires. + // A high-rate recursive proof, designed to be verifiable with fewer routed wires. let high_rate_config = CircuitConfig { - rate_bits: 5, + rate_bits: 7, fri_config: FriConfig { - proof_of_work_bits: 20, - num_query_rounds: 16, + proof_of_work_bits: 16, + num_query_rounds: 11, ..standard_config.fri_config.clone() }, ..standard_config @@ -448,47 +447,25 @@ mod tests { true, true, )?; - assert_eq!(cd.degree_bits, 13); - - // A higher-rate recursive proof with degree 2^12, designed to be verifiable with 2^12 - // gates and 28 routed wires. - let higher_rate_more_routing_config = CircuitConfig { - rate_bits: 7, - num_routed_wires: 48, - fri_config: FriConfig { - proof_of_work_bits: 23, - num_query_rounds: 11, - ..standard_config.fri_config.clone() - }, - ..high_rate_config.clone() - }; - let (proof, vd, cd) = recursive_proof( - proof, - vd, - cd, - &high_rate_config, - &higher_rate_more_routing_config, - None, - true, - true, - )?; assert_eq!(cd.degree_bits, 12); - // A final proof of degree 2^12, optimized for size. + // A final proof, optimized for size. let final_config = CircuitConfig { cap_height: 0, - num_routed_wires: 32, + rate_bits: 8, + num_routed_wires: 25, fri_config: FriConfig { + proof_of_work_bits: 21, reduction_strategy: FriReductionStrategy::MinSize(None), - ..higher_rate_more_routing_config.fri_config.clone() + num_query_rounds: 9, }, - ..higher_rate_more_routing_config + ..high_rate_config }; let (proof, _vd, cd) = recursive_proof( proof, vd, cd, - &higher_rate_more_routing_config, + &high_rate_config, &final_config, None, true, From 9b55ff9e816ede79a3630a84086feee5a5e5990a Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Wed, 17 Nov 2021 14:43:54 -0800 Subject: [PATCH 114/202] edition = 2021 (#370) * edition = 2021 Doesn't affect anything for us as far as I've noticed. * imports --- Cargo.toml | 2 +- src/field/extension_field/mod.rs | 2 -- src/field/extension_field/target.rs | 1 - src/field/field_types.rs | 1 - src/field/secp256k1.rs | 1 - src/gadgets/arithmetic_extension.rs | 4 ---- src/gadgets/hash.rs | 2 -- src/gates/gmimc.rs | 2 -- src/gates/insertion.rs | 1 - src/gates/interpolation.rs | 1 - src/gates/poseidon.rs | 3 --- src/gates/poseidon_mds.rs | 1 - src/hash/arch/aarch64/poseidon_goldilocks_neon.rs | 1 - src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs | 1 - src/hash/hash_types.rs | 2 -- src/hash/hashing.rs | 2 -- src/hash/merkle_proofs.rs | 2 -- src/hash/poseidon.rs | 2 -- src/iop/challenger.rs | 2 -- src/iop/witness.rs | 1 - src/plonk/circuit_builder.rs | 1 - src/plonk/vars.rs | 1 - src/util/serialization.rs | 2 -- 23 files changed, 1 insertion(+), 37 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4c182fa2..d3f46e35 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ license = "MIT OR Apache-2.0" repository = "https://github.com/mir-protocol/plonky2" keywords = ["cryptography", "SNARK", "FRI"] categories = ["cryptography"] -edition = "2018" +edition = "2021" default-run = "bench_recursion" [dependencies] diff --git a/src/field/extension_field/mod.rs b/src/field/extension_field/mod.rs index 08443386..611c5671 100644 --- a/src/field/extension_field/mod.rs +++ b/src/field/extension_field/mod.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - use crate::field::field_types::{Field, PrimeField}; pub mod algebra; diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index ab3c3619..3f9b6684 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -1,4 +1,3 @@ -use std::convert::{TryFrom, TryInto}; use std::ops::Range; use crate::field::extension_field::algebra::ExtensionAlgebra; diff --git a/src/field/field_types.rs b/src/field/field_types.rs index aed654d5..bbb1604e 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::fmt::{Debug, Display}; use std::hash::Hash; use std::iter::{Product, Sum}; diff --git a/src/field/secp256k1.rs b/src/field/secp256k1.rs index 5f8e1b4e..75221a1f 100644 --- a/src/field/secp256k1.rs +++ b/src/field/secp256k1.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::fmt; use std::fmt::{Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 5716ddc3..7b6535f9 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::field::extension_field::FieldExtension; use crate::field::extension_field::{Extendable, OEF}; @@ -531,8 +529,6 @@ pub(crate) struct ExtensionArithmeticOperation, co #[cfg(test)] mod tests { - use std::convert::TryInto; - use anyhow::Result; use crate::field::extension_field::algebra::ExtensionAlgebra; diff --git a/src/gadgets/hash.rs b/src/gadgets/hash.rs index db4cb1e8..2eb93da7 100644 --- a/src/gadgets/hash.rs +++ b/src/gadgets/hash.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; use crate::gates::gmimc::GMiMCGate; diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 8a12df54..1947c851 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -328,8 +328,6 @@ impl + GMiMC, const D: usize, const WIDTH: u #[cfg(test)] mod tests { - use std::convert::TryInto; - use anyhow::Result; use crate::field::field_types::Field; diff --git a/src/gates/insertion.rs b/src/gates/insertion.rs index dcc79f05..c55f53a9 100644 --- a/src/gates/insertion.rs +++ b/src/gates/insertion.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::marker::PhantomData; use std::ops::Range; diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index 52dca440..d97eb009 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::marker::PhantomData; use std::ops::Range; diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index ccec2445..ef212a2b 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::marker::PhantomData; use crate::field::extension_field::target::ExtensionTarget; @@ -538,8 +537,6 @@ where #[cfg(test)] mod tests { - use std::convert::TryInto; - use anyhow::Result; use crate::field::field_types::Field; diff --git a/src/gates/poseidon_mds.rs b/src/gates/poseidon_mds.rs index a127df68..1abbe71f 100644 --- a/src/gates/poseidon_mds.rs +++ b/src/gates/poseidon_mds.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::marker::PhantomData; use std::ops::Range; diff --git a/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs b/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs index f122e0ef..5d9d9fba 100644 --- a/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs +++ b/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs @@ -1,5 +1,4 @@ use std::arch::aarch64::*; -use std::convert::TryInto; use static_assertions::const_assert; use unroll::unroll_for_loops; diff --git a/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs b/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs index 1df21550..5497ab6c 100644 --- a/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs +++ b/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs @@ -1,5 +1,4 @@ use core::arch::x86_64::*; -use std::convert::TryInto; use std::mem::size_of; use static_assertions::const_assert; diff --git a/src/hash/hash_types.rs b/src/hash/hash_types.rs index eb2f16b0..0fec294b 100644 --- a/src/hash/hash_types.rs +++ b/src/hash/hash_types.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - use rand::Rng; use serde::{Deserialize, Serialize}; diff --git a/src/hash/hashing.rs b/src/hash/hashing.rs index 39b3f51e..a4610495 100644 --- a/src/hash/hashing.rs +++ b/src/hash/hashing.rs @@ -1,7 +1,5 @@ //! Concrete instantiation of a hash function. -use std::convert::TryInto; - use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; use crate::hash::hash_types::{HashOut, HashOutTarget}; diff --git a/src/hash/merkle_proofs.rs b/src/hash/merkle_proofs.rs index 20e2271a..b5cb3b20 100644 --- a/src/hash/merkle_proofs.rs +++ b/src/hash/merkle_proofs.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - use anyhow::{ensure, Result}; use serde::{Deserialize, Serialize}; diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index 03dce1f3..8572143a 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -1,8 +1,6 @@ //! Implementation of the Poseidon hash function, as described in //! https://eprint.iacr.org/2019/458.pdf -use std::convert::TryInto; - use unroll::unroll_for_loops; use crate::field::extension_field::target::ExtensionTarget; diff --git a/src/iop/challenger.rs b/src/iop/challenger.rs index 87b0512d..bd990c6d 100644 --- a/src/iop/challenger.rs +++ b/src/iop/challenger.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::RichField; diff --git a/src/iop/witness.rs b/src/iop/witness.rs index 0388a6cb..a773c1a9 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::convert::TryInto; use num::{BigUint, FromPrimitive, Zero}; diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index fc9eef76..32ea59b6 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -1,6 +1,5 @@ use std::cmp::max; use std::collections::{BTreeMap, HashMap, HashSet}; -use std::convert::TryInto; use std::time::Instant; use log::{debug, info, Level}; diff --git a/src/plonk/vars.rs b/src/plonk/vars.rs index 110aa689..b643b7b7 100644 --- a/src/plonk/vars.rs +++ b/src/plonk/vars.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::ops::Range; use crate::field::extension_field::algebra::ExtensionAlgebra; diff --git a/src/util/serialization.rs b/src/util/serialization.rs index 172b4d67..5ca4f691 100644 --- a/src/util/serialization.rs +++ b/src/util/serialization.rs @@ -1,8 +1,6 @@ use std::collections::HashMap; -use std::convert::TryInto; use std::io::Cursor; use std::io::{Read, Result, Write}; -use std::iter::FromIterator; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::{PrimeField, RichField}; From 2b4bb13ab0a6bfd3e2ed72fe058303ea0cb9564a Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Thu, 18 Nov 2021 23:00:56 -0800 Subject: [PATCH 115/202] Remove total_constraints (#372) It's out of date, and unused now anyway. --- src/plonk/circuit_data.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 6721552c..c2a8d6d0 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -246,11 +246,6 @@ impl, const D: usize> CommonCircuitData { self.quotient_degree_factor * self.degree() } - pub fn total_constraints(&self) -> usize { - // 2 constraints for each Z check. - self.config.num_challenges * 2 + self.num_gate_constraints - } - /// Range of the constants polynomials in the `constants_sigmas_commitment`. pub fn constants_range(&self) -> Range { 0..self.num_constants From 0de408c40fb44a33e2bd046f742c9d6db46a9dc6 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 19 Nov 2021 09:31:06 +0100 Subject: [PATCH 116/202] MulExtensionGate --- src/gadgets/arithmetic_extension.rs | 54 ++++++- src/gates/mod.rs | 1 + src/gates/multiplication_extension.rs | 204 ++++++++++++++++++++++++++ src/plonk/circuit_builder.rs | 50 ++++++- 4 files changed, 305 insertions(+), 4 deletions(-) create mode 100644 src/gates/multiplication_extension.rs diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 7b6535f9..7489ed4c 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -3,6 +3,7 @@ use crate::field::extension_field::FieldExtension; use crate::field::extension_field::{Extendable, OEF}; use crate::field::field_types::{Field, PrimeField, RichField}; use crate::gates::arithmetic_extension::ArithmeticExtensionGate; +use crate::gates::multiplication_extension::MulExtensionGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::Target; use crate::iop::witness::{PartitionWitness, Witness}; @@ -41,8 +42,12 @@ impl, const D: usize> CircuitBuilder { return result; } + let result = if self.target_as_constant_ext(addend) == Some(F::Extension::ZERO) { + self.add_mul_extension_operation(operation) + } else { + self.add_arithmetic_extension_operation(operation) + }; // Otherwise, we must actually perform the operation using an ArithmeticExtensionGate slot. - let result = self.add_arithmetic_extension_operation(operation); self.arithmetic_results.insert(operation, result); result } @@ -70,6 +75,22 @@ impl, const D: usize> CircuitBuilder { ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_ith_output(i)) } + fn add_mul_extension_operation( + &mut self, + operation: ExtensionArithmeticOperation, + ) -> ExtensionTarget { + let (gate, i) = self.find_mul_gate(operation.const_0); + let wires_multiplicand_0 = + ExtensionTarget::from_range(gate, MulExtensionGate::::wires_ith_multiplicand_0(i)); + let wires_multiplicand_1 = + ExtensionTarget::from_range(gate, MulExtensionGate::::wires_ith_multiplicand_1(i)); + + self.connect_extension(operation.multiplicand_0, wires_multiplicand_0); + self.connect_extension(operation.multiplicand_1, wires_multiplicand_1); + + ExtensionTarget::from_range(gate, MulExtensionGate::::wires_ith_output(i)) + } + /// Checks for special cases where the value of /// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend` /// can be determined without adding an `ArithmeticGate`. @@ -556,7 +577,7 @@ mod tests { for (&v, &t) in vs.iter().zip(&ts) { pw.set_extension_target(t, v); } - let mul0 = builder.mul_many_extension(&ts); + // let mul0 = builder.mul_many_extension(&ts); let mul1 = { let mut acc = builder.one_extension(); for &t in &ts { @@ -566,7 +587,7 @@ mod tests { }; let mul2 = builder.constant_extension(vs.into_iter().product()); - builder.connect_extension(mul0, mul1); + // builder.connect_extension(mul0, mul1); builder.connect_extension(mul1, mul2); let data = builder.build(); @@ -633,4 +654,31 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + + #[test] + fn test_mul_ext() -> Result<()> { + type F = GoldilocksField; + type FF = QuarticExtension; + const D: usize = 4; + + let config = CircuitConfig::standard_recursion_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = FF::rand(); + let y = FF::rand(); + let z = x * y; + + let xt = builder.constant_extension(x); + let yt = builder.constant_extension(y); + let zt = builder.constant_extension(z); + let comp_zt = builder.mul_extension(xt, yt); + builder.connect_extension(zt, comp_zt); + + let data = builder.build(); + let proof = data.prove(pw)?; + + verify(proof, &data.verifier_only, &data.common) + } } diff --git a/src/gates/mod.rs b/src/gates/mod.rs index 93de5e97..369c9ea5 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -14,6 +14,7 @@ pub mod gate_tree; pub mod gmimc; pub mod insertion; pub mod interpolation; +pub mod multiplication_extension; pub mod noop; pub mod poseidon; pub(crate) mod poseidon_mds; diff --git a/src/gates/multiplication_extension.rs b/src/gates/multiplication_extension.rs new file mode 100644 index 00000000..16cc6315 --- /dev/null +++ b/src/gates/multiplication_extension.rs @@ -0,0 +1,204 @@ +use std::ops::Range; + +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::extension_field::FieldExtension; +use crate::field::field_types::RichField; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::CircuitConfig; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// A gate which can perform a weighted multiply-add, i.e. `result = c0 x y + c1 z`. If the config +/// supports enough routed wires, it can support several such operations in one gate. +#[derive(Debug)] +pub struct MulExtensionGate { + /// Number of arithmetic operations performed by an arithmetic gate. + pub num_ops: usize, +} + +impl MulExtensionGate { + pub fn new_from_config(config: &CircuitConfig) -> Self { + Self { + num_ops: Self::num_ops(config), + } + } + + /// Determine the maximum number of operations that can fit in one gate for the given config. + pub(crate) fn num_ops(config: &CircuitConfig) -> usize { + let wires_per_op = 3 * D; + config.num_routed_wires / wires_per_op + } + + pub fn wires_ith_multiplicand_0(i: usize) -> Range { + 3 * D * i..3 * D * i + D + } + pub fn wires_ith_multiplicand_1(i: usize) -> Range { + 3 * D * i + D..3 * D * i + 2 * D + } + pub fn wires_ith_output(i: usize) -> Range { + 3 * D * i + 2 * D..3 * D * i + 3 * D + } +} + +impl, const D: usize> Gate for MulExtensionGate { + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let const_0 = vars.local_constants[0]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i)); + let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i)); + let output = vars.get_local_ext_algebra(Self::wires_ith_output(i)); + let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(const_0); + + constraints.extend((output - computed_output).to_basefield_array()); + } + + constraints + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let const_0 = vars.local_constants[0]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.get_local_ext(Self::wires_ith_multiplicand_0(i)); + let multiplicand_1 = vars.get_local_ext(Self::wires_ith_multiplicand_1(i)); + let output = vars.get_local_ext(Self::wires_ith_output(i)); + let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(const_0); + + constraints.extend((output - computed_output).to_basefield_array()); + } + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let const_0 = vars.local_constants[0]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i)); + let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i)); + let output = vars.get_local_ext_algebra(Self::wires_ith_output(i)); + let computed_output = { + let mul = builder.mul_ext_algebra(multiplicand_0, multiplicand_1); + builder.scalar_mul_ext_algebra(const_0, mul) + }; + + let diff = builder.sub_ext_algebra(output, computed_output); + constraints.extend(diff.to_ext_target_array()); + } + + constraints + } + + fn generators( + &self, + gate_index: usize, + local_constants: &[F], + ) -> Vec>> { + (0..self.num_ops) + .map(|i| { + let g: Box> = Box::new( + MulExtensionGenerator { + gate_index, + const_0: local_constants[0], + i, + } + .adapter(), + ); + g + }) + .collect::>() + } + + fn num_wires(&self) -> usize { + self.num_ops * 3 * D + } + + fn num_constants(&self) -> usize { + 1 + } + + fn degree(&self) -> usize { + 3 + } + + fn num_constraints(&self) -> usize { + self.num_ops * D + } +} + +#[derive(Clone, Debug)] +struct MulExtensionGenerator, const D: usize> { + gate_index: usize, + const_0: F, + i: usize, +} + +impl, const D: usize> SimpleGenerator + for MulExtensionGenerator +{ + fn dependencies(&self) -> Vec { + MulExtensionGate::::wires_ith_multiplicand_0(self.i) + .chain(MulExtensionGate::::wires_ith_multiplicand_1(self.i)) + .map(|i| Target::wire(self.gate_index, i)) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let extract_extension = |range: Range| -> F::Extension { + let t = ExtensionTarget::from_range(self.gate_index, range); + witness.get_extension_target(t) + }; + + let multiplicand_0 = + extract_extension(MulExtensionGate::::wires_ith_multiplicand_0(self.i)); + let multiplicand_1 = + extract_extension(MulExtensionGate::::wires_ith_multiplicand_1(self.i)); + + let output_target = ExtensionTarget::from_range( + self.gate_index, + MulExtensionGate::::wires_ith_output(self.i), + ); + + let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(self.const_0); + + out_buffer.set_extension_target(output_target, computed_output) + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::multiplication_extension::MulExtensionGate; + use crate::plonk::circuit_data::CircuitConfig; + + #[test] + fn low_degree() { + let gate = MulExtensionGate::new_from_config(&CircuitConfig::standard_recursion_config()); + test_low_degree::(gate); + } + + #[test] + fn eval_fns() -> Result<()> { + let gate = MulExtensionGate::new_from_config(&CircuitConfig::standard_recursion_config()); + test_eval_fns::(gate) + } +} diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 32ea59b6..eda31ece 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -20,6 +20,7 @@ use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; use crate::gates::constant::ConstantGate; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; use crate::gates::gate_tree::Tree; +use crate::gates::multiplication_extension::MulExtensionGate; use crate::gates::noop::NoopGate; use crate::gates::public_input::PublicInputGate; use crate::gates::random_access::RandomAccessGate; @@ -70,7 +71,7 @@ pub struct CircuitBuilder, const D: usize> { marked_targets: Vec>, /// Generators used to generate the witness. - generators: Vec>>, + pub generators: Vec>>, constants_to_targets: HashMap, targets_to_constants: HashMap, @@ -769,6 +770,8 @@ pub struct BatchedGates, const D: usize> { pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>, pub(crate) free_base_arithmetic: HashMap<(F, F), (usize, usize)>, + pub(crate) free_mul: HashMap, + /// A map `b -> (g, i)` from `b` bits to an available random access gate of that size with gate /// index `g` and already using `i` random accesses. pub(crate) free_random_access: HashMap, @@ -793,6 +796,7 @@ impl, const D: usize> BatchedGates { Self { free_arithmetic: HashMap::new(), free_base_arithmetic: HashMap::new(), + free_mul: HashMap::new(), free_random_access: HashMap::new(), current_switch_gates: Vec::new(), current_u32_arithmetic_gate: None, @@ -865,6 +869,33 @@ impl, const D: usize> CircuitBuilder { (gate, i) } + /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. + /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index + /// `g` and the gate's `i`-th operation is available. + pub(crate) fn find_mul_gate(&mut self, const_0: F) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_mul + .get(&const_0) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + MulExtensionGate::new_from_config(&self.config), + vec![const_0], + ); + (gate, 0) + }); + + // Update `free_arithmetic` with new values. + if i < MulExtensionGate::::num_ops(&self.config) - 1 { + self.batched_gates.free_mul.insert(const_0, (gate, i + 1)); + } else { + self.batched_gates.free_mul.remove(&const_0); + } + + (gate, i) + } + /// Finds the last available random access gate with the given `vec_size` or add one if there aren't any. /// Returns `(g,i)` such that there is a random access gate with the given `vec_size` at index /// `g` and the gate's `i`-th random access is available. @@ -1021,6 +1052,22 @@ impl, const D: usize> CircuitBuilder { assert!(self.batched_gates.free_arithmetic.is_empty()); } + /// Fill the remaining unused arithmetic operations with zeros, so that all + /// `ArithmeticExtensionGenerator`s are run. + fn fill_mul_gates(&mut self) { + let zero = self.zero_extension(); + for (c0, (_gate, i)) in self.batched_gates.free_mul.clone() { + for _ in i..MulExtensionGate::::num_ops(&self.config) { + // If we directly wire in zero, an optimization will skip doing anything and return + // zero. So we pass in a virtual target and connect it to zero afterward. + let dummy = self.add_virtual_extension_target(); + self.arithmetic_extension(c0, F::ZERO, dummy, dummy, zero); + self.connect_extension(dummy, zero); + } + } + assert!(self.batched_gates.free_mul.is_empty()); + } + /// Fill the remaining unused random access operations with zeros, so that all /// `RandomAccessGenerator`s are run. fn fill_random_access_gates(&mut self) { @@ -1110,6 +1157,7 @@ impl, const D: usize> CircuitBuilder { fn fill_batched_gates(&mut self) { self.fill_arithmetic_gates(); self.fill_base_arithmetic_gates(); + self.fill_mul_gates(); self.fill_random_access_gates(); self.fill_switch_gates(); self.fill_u32_arithmetic_gates(); From 939acfed96b72fd485cd5713a958ed6a95107448 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 19 Nov 2021 11:14:03 +0100 Subject: [PATCH 117/202] Fix mul_many --- src/gadgets/arithmetic.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 00dbbe21..d8dce53d 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -206,11 +206,11 @@ impl, const D: usize> CircuitBuilder { /// Multiply `n` `Target`s. pub fn mul_many(&mut self, terms: &[Target]) -> Target { - let terms_ext = terms - .iter() - .map(|&t| self.convert_to_ext(t)) - .collect::>(); - self.mul_many_extension(&terms_ext).to_target_array()[0] + let mut product = self.one(); + for &term in terms { + product = self.mul(product, term); + } + product } /// Exponentiate `base` to the power of `2^power_log`. From 90a6ffd77503c3397a1be087e799b11e492dd726 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 19 Nov 2021 11:24:43 +0100 Subject: [PATCH 118/202] Use fold1 in mul_many --- src/gadgets/arithmetic.rs | 12 +++++++----- src/gadgets/arithmetic_extension.rs | 12 +++++++----- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index d8dce53d..0931cc88 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -1,5 +1,7 @@ use std::borrow::Borrow; +use itertools::Itertools; + use crate::field::extension_field::Extendable; use crate::field::field_types::{PrimeField, RichField}; use crate::gates::arithmetic_base::ArithmeticGate; @@ -206,11 +208,11 @@ impl, const D: usize> CircuitBuilder { /// Multiply `n` `Target`s. pub fn mul_many(&mut self, terms: &[Target]) -> Target { - let mut product = self.one(); - for &term in terms { - product = self.mul(product, term); - } - product + terms + .iter() + .copied() + .fold1(|acc, t| self.mul(acc, t)) + .unwrap_or_else(|| self.one()) } /// Exponentiate `base` to the power of `2^power_log`. diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 7489ed4c..e4f05c0f 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -1,3 +1,5 @@ +use itertools::Itertools; + use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::field::extension_field::FieldExtension; use crate::field::extension_field::{Extendable, OEF}; @@ -294,11 +296,11 @@ impl, const D: usize> CircuitBuilder { /// Multiply `n` `ExtensionTarget`s. pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { - let mut product = self.one_extension(); - for &term in terms { - product = self.mul_extension(product, term); - } - product + terms + .iter() + .copied() + .fold1(|acc, t| self.mul_extension(acc, t)) + .unwrap_or_else(|| self.one_extension()) } /// Like `mul_add`, but for `ExtensionTarget`s. From 4f11713c498ad66cb8fb0b2cc77fcdc6ff83fee7 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 19 Nov 2021 11:29:51 +0100 Subject: [PATCH 119/202] Remove useless test --- src/gadgets/arithmetic_extension.rs | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index e4f05c0f..30c49bf4 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -656,31 +656,4 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } - - #[test] - fn test_mul_ext() -> Result<()> { - type F = GoldilocksField; - type FF = QuarticExtension; - const D: usize = 4; - - let config = CircuitConfig::standard_recursion_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = FF::rand(); - let y = FF::rand(); - let z = x * y; - - let xt = builder.constant_extension(x); - let yt = builder.constant_extension(y); - let zt = builder.constant_extension(z); - let comp_zt = builder.mul_extension(xt, yt); - builder.connect_extension(zt, comp_zt); - - let data = builder.build(); - let proof = data.prove(pw)?; - - verify(proof, &data.verifier_only, &data.common) - } } From 22f4c18083b86b1609a5204d5d084a0e9a54fb8a Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 19 Nov 2021 11:48:42 +0100 Subject: [PATCH 120/202] Comments --- src/gadgets/arithmetic_extension.rs | 6 ++++-- src/gates/multiplication_extension.rs | 4 ++-- src/plonk/circuit_builder.rs | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 30c49bf4..cc82be68 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -45,8 +45,10 @@ impl, const D: usize> CircuitBuilder { } let result = if self.target_as_constant_ext(addend) == Some(F::Extension::ZERO) { + // If the addend is zero, we use a multiplication gate. self.add_mul_extension_operation(operation) } else { + // Otherwise, we use an arithmetic gate. self.add_arithmetic_extension_operation(operation) }; // Otherwise, we must actually perform the operation using an ArithmeticExtensionGate slot. @@ -579,7 +581,7 @@ mod tests { for (&v, &t) in vs.iter().zip(&ts) { pw.set_extension_target(t, v); } - // let mul0 = builder.mul_many_extension(&ts); + let mul0 = builder.mul_many_extension(&ts); let mul1 = { let mut acc = builder.one_extension(); for &t in &ts { @@ -589,7 +591,7 @@ mod tests { }; let mul2 = builder.constant_extension(vs.into_iter().product()); - // builder.connect_extension(mul0, mul1); + builder.connect_extension(mul0, mul1); builder.connect_extension(mul1, mul2); let data = builder.build(); diff --git a/src/gates/multiplication_extension.rs b/src/gates/multiplication_extension.rs index 16cc6315..4c385b79 100644 --- a/src/gates/multiplication_extension.rs +++ b/src/gates/multiplication_extension.rs @@ -12,11 +12,11 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -/// A gate which can perform a weighted multiply-add, i.e. `result = c0 x y + c1 z`. If the config +/// A gate which can perform a weighted multiplication, i.e. `result = c0 x y`. If the config /// supports enough routed wires, it can support several such operations in one gate. #[derive(Debug)] pub struct MulExtensionGate { - /// Number of arithmetic operations performed by an arithmetic gate. + /// Number of multiplications performed by the gate. pub num_ops: usize, } diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index eda31ece..730699b9 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -71,7 +71,7 @@ pub struct CircuitBuilder, const D: usize> { marked_targets: Vec>, /// Generators used to generate the witness. - pub generators: Vec>>, + generators: Vec>>, constants_to_targets: HashMap, targets_to_constants: HashMap, From aec88a852877a2773343851399aa07fb1efab61d Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 19 Nov 2021 18:11:14 +0100 Subject: [PATCH 121/202] First try --- src/field/extension_field/algebra.rs | 18 +++++ src/gadgets/polynomial.rs | 20 +++++ src/gates/gate_testing.rs | 3 + src/gates/interpolation.rs | 113 +++++++++++++++++++++------ src/plonk/circuit_builder.rs | 1 + src/plonk/circuit_data.rs | 7 +- src/plonk/prover.rs | 11 +++ src/polynomial/polynomial.rs | 21 +++++ 8 files changed, 167 insertions(+), 27 deletions(-) diff --git a/src/field/extension_field/algebra.rs b/src/field/extension_field/algebra.rs index 21438262..37f70f52 100644 --- a/src/field/extension_field/algebra.rs +++ b/src/field/extension_field/algebra.rs @@ -160,12 +160,30 @@ impl, const D: usize> PolynomialCoeffsAlgebra { .fold(ExtensionAlgebra::ZERO, |acc, &c| acc * x + c) } + pub fn eval_with_powers(&self, powers: &[ExtensionAlgebra]) -> ExtensionAlgebra { + debug_assert_eq!(self.coeffs.len(), powers.len() + 1); + let acc = self.coeffs[0]; + self.coeffs[1..] + .iter() + .zip(powers) + .fold(acc, |acc, (&x, &c)| acc + c * x) + } + pub fn eval_base(&self, x: F) -> ExtensionAlgebra { self.coeffs .iter() .rev() .fold(ExtensionAlgebra::ZERO, |acc, &c| acc.scalar_mul(x) + c) } + + pub fn eval_base_with_powers(&self, powers: &[F]) -> ExtensionAlgebra { + debug_assert_eq!(self.coeffs.len(), powers.len() + 1); + let acc = self.coeffs[0]; + self.coeffs[1..] + .iter() + .zip(powers) + .fold(acc, |acc, (&x, &c)| acc + x.scalar_mul(c)) + } } #[cfg(test)] diff --git a/src/gadgets/polynomial.rs b/src/gadgets/polynomial.rs index 9f631c10..48a5f8b7 100644 --- a/src/gadgets/polynomial.rs +++ b/src/gadgets/polynomial.rs @@ -64,4 +64,24 @@ impl PolynomialCoeffsExtAlgebraTarget { } acc } + pub fn eval_with_powers( + &self, + builder: &mut CircuitBuilder, + powers: &[ExtensionAlgebraTarget], + ) -> ExtensionAlgebraTarget + where + F: RichField + Extendable, + { + debug_assert_eq!(self.0.len(), powers.len() + 1); + let acc = self.0[0]; + self.0[1..] + .iter() + .zip(powers) + .fold(acc, |acc, (&x, &c)| builder.mul_add_ext_algebra(c, x, acc)) + // let mut acc = builder.zero_ext_algebra(); + // for &c in self.0.iter().rev() { + // acc = builder.mul_add_ext_algebra(point, acc, c); + // } + // acc + } } diff --git a/src/gates/gate_testing.rs b/src/gates/gate_testing.rs index 9fe8e835..15c17ba7 100644 --- a/src/gates/gate_testing.rs +++ b/src/gates/gate_testing.rs @@ -51,7 +51,9 @@ pub(crate) fn test_low_degree, G: Gate, const ); let expected_eval_degree = WITNESS_DEGREE * gate.degree(); + dbg!(WITNESS_DEGREE, gate.degree()); + dbg!(&constraint_eval_degrees); assert!( constraint_eval_degrees .iter() @@ -151,6 +153,7 @@ pub(crate) fn test_eval_fns, G: Gate, const D let evals_t = gate.eval_unfiltered_recursively(&mut builder, vars_t); pw.set_extension_targets(&evals_t, &evals); + dbg!(builder.num_gates()); let data = builder.build(); let proof = data.prove(pw)?; verify(proof, &data.verifier_only, &data.common) diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index d97eb009..536c87bb 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -1,10 +1,10 @@ use std::marker::PhantomData; use std::ops::Range; -use crate::field::extension_field::algebra::PolynomialCoeffsAlgebra; +use crate::field::extension_field::algebra::{ExtensionAlgebra, PolynomialCoeffsAlgebra}; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::field::field_types::RichField; +use crate::field::field_types::{Field, RichField}; use crate::field::interpolation::interpolant; use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; use crate::gates::gate::Gate; @@ -90,9 +90,26 @@ impl, const D: usize> InterpolationGate { start..start + D } + pub fn powers_init(&self, i: usize) -> usize { + debug_assert!(0 < i && i < self.num_points()); + if i == 1 { + return self.wire_shift(); + } + self.start_coeffs() + self.num_points() * D + i + } + + pub fn powers_eval(&self, i: usize) -> Range { + debug_assert!(0 < i && i < self.num_points()); + if i == 1 { + return self.wires_evaluation_point(); + } + let start = self.start_coeffs() + self.num_points() * D + self.num_points() - 1 + i * D; + start..start + D + } + /// End of wire indices, exclusive. fn end(&self) -> usize { - self.start_coeffs() + self.num_points() * D + self.powers_eval(self.num_points() - 1).end } /// The domain of the points we're interpolating. @@ -138,19 +155,34 @@ impl, const D: usize> Gate for InterpolationG let coeffs = (0..self.num_points()) .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) - .collect(); - let interpolant = PolynomialCoeffsAlgebra::new(coeffs); + .collect::>(); - let coset = self.coset_ext(vars.local_wires[self.wire_shift()]); - for (i, point) in coset.into_iter().enumerate() { + let mut powers_init = (1..self.num_points()) + .map(|i| vars.local_wires[self.powers_init(i)]) + .collect::>(); + powers_init.insert(0, F::Extension::ONE); + let ocoeffs = coeffs + .iter() + .zip(powers_init) + .map(|(&c, p)| c.scalar_mul(p)) + .collect::>(); + let interpolant = PolynomialCoeffsAlgebra::new(coeffs); + let ointerpolant = PolynomialCoeffsAlgebra::new(ocoeffs); + + for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits) + .into_iter() + .enumerate() + { let value = vars.get_local_ext_algebra(self.wires_value(i)); - let computed_value = interpolant.eval_base(point); + let computed_value = ointerpolant.eval_base(point); constraints.extend(&(value - computed_value).to_basefield_array()); } - let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); + let mut evaluation_point_powers = (1..self.num_points()) + .map(|i| vars.get_local_ext_algebra(self.powers_eval(i))) + .collect::>(); let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); - let computed_evaluation_value = interpolant.eval(evaluation_point); + let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers); constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array()); constraints @@ -161,19 +193,33 @@ impl, const D: usize> Gate for InterpolationG let coeffs = (0..self.num_points()) .map(|i| vars.get_local_ext(self.wires_coeff(i))) - .collect(); + .collect::>(); + let mut powers_init = (1..self.num_points()) + .map(|i| vars.local_wires[self.powers_init(i)]) + .collect::>(); + powers_init.insert(0, F::ONE); + let ocoeffs = coeffs + .iter() + .zip(powers_init) + .map(|(&c, p)| c.scalar_mul(p)) + .collect::>(); let interpolant = PolynomialCoeffs::new(coeffs); + let ointerpolant = PolynomialCoeffs::new(ocoeffs); - let coset = self.coset(vars.local_wires[self.wire_shift()]); - for (i, point) in coset.into_iter().enumerate() { + for (i, point) in F::two_adic_subgroup(self.subgroup_bits) + .into_iter() + .enumerate() + { let value = vars.get_local_ext(self.wires_value(i)); - let computed_value = interpolant.eval_base(point); + let computed_value = ointerpolant.eval_base(point); constraints.extend(&(value - computed_value).to_basefield_array()); } - let evaluation_point = vars.get_local_ext(self.wires_evaluation_point()); + let evaluation_point_powers = (1..self.num_points()) + .map(|i| vars.get_local_ext(self.powers_eval(i))) + .collect::>(); let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); - let computed_evaluation_value = interpolant.eval(evaluation_point); + let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers); constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array()); constraints @@ -188,13 +234,26 @@ impl, const D: usize> Gate for InterpolationG let coeffs = (0..self.num_points()) .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) - .collect(); + .collect::>(); + let mut powers_init = (1..self.num_points()) + .map(|i| vars.local_wires[self.powers_init(i)]) + .collect::>(); + powers_init.insert(0, builder.one_extension()); + let ocoeffs = coeffs + .iter() + .zip(powers_init) + .map(|(&c, p)| builder.scalar_mul_ext_algebra(p, c)) + .collect::>(); let interpolant = PolynomialCoeffsExtAlgebraTarget(coeffs); + let ointerpolant = PolynomialCoeffsExtAlgebraTarget(ocoeffs); - let coset = self.coset_ext_recursive(builder, vars.local_wires[self.wire_shift()]); - for (i, point) in coset.into_iter().enumerate() { + for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits) + .into_iter() + .enumerate() + { let value = vars.get_local_ext_algebra(self.wires_value(i)); - let computed_value = interpolant.eval_scalar(builder, point); + let point = builder.constant_extension(point); + let computed_value = ointerpolant.eval_scalar(builder, point); constraints.extend( &builder .sub_ext_algebra(value, computed_value) @@ -202,9 +261,15 @@ impl, const D: usize> Gate for InterpolationG ); } - let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); + let evaluation_point_powers = (1..self.num_points()) + .map(|i| vars.get_local_ext_algebra(self.powers_eval(i))) + .collect::>(); let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); - let computed_evaluation_value = interpolant.eval(builder, evaluation_point); + let computed_evaluation_value = + interpolant.eval_with_powers(builder, &evaluation_point_powers); + // let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); + // let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); + // let computed_evaluation_value = interpolant.eval(builder, evaluation_point); constraints.extend( &builder .sub_ext_algebra(evaluation_value, computed_evaluation_value) @@ -238,7 +303,7 @@ impl, const D: usize> Gate for InterpolationG fn degree(&self) -> usize { // The highest power of x is `num_points - 1`, and then multiplication by the coefficient // adds 1. - self.num_points() + 2 } fn num_constraints(&self) -> usize { @@ -357,7 +422,7 @@ mod tests { #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(InterpolationGate::new(2)) + test_eval_fns::(InterpolationGate::new(3)) } #[test] diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 32ea59b6..f19d6ae9 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -687,6 +687,7 @@ impl, const D: usize> CircuitBuilder { marked_targets: self.marked_targets, representative_map: forest.parents, fft_root_table: Some(fft_root_table), + instances: self.gate_instances, }; // The HashSet of gates will have a non-deterministic order. When converting to a Vec, we diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index c2a8d6d0..c0ccbc3c 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -9,7 +9,7 @@ use crate::field::field_types::{Field, RichField}; use crate::fri::commitment::PolynomialBatchCommitment; use crate::fri::reduction_strategies::FriReductionStrategy; use crate::fri::{FriConfig, FriParams}; -use crate::gates::gate::PrefixedGate; +use crate::gates::gate::{GateInstance, PrefixedGate}; use crate::hash::hash_types::{HashOut, MerkleCapTarget}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::generator::WitnessGenerator; @@ -67,10 +67,10 @@ impl CircuitConfig { rate_bits: 3, num_challenges: 2, zero_knowledge: false, - cap_height: 3, + cap_height: 4, fri_config: FriConfig { proof_of_work_bits: 15, - reduction_strategy: FriReductionStrategy::ConstantArityBits(3, 5), + reduction_strategy: FriReductionStrategy::ConstantArityBits(4, 5), num_query_rounds: 26, }, } @@ -177,6 +177,7 @@ pub(crate) struct ProverOnlyCircuitData, const D: u pub representative_map: Vec, /// Pre-computed roots for faster FFT. pub fft_root_table: Option>, + pub instances: Vec>, } /// Circuit data required by the verifier, but not the prover. diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 3f8e607d..b5175f02 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -66,6 +66,17 @@ pub(crate) fn prove, const D: usize>( .collect() ); + // let rows = (0..degree) + // .map(|i| wires_values.iter().map(|w| w.values[i]).collect::>()) + // .collect::>(); + // for (i, r) in rows.iter().enumerate() { + // let c = rows.iter().filter(|&x| x == r).count(); + // let s = prover_data.instances[i].gate_ref.0.id(); + // if c > 1 && !s.starts_with("Noop") { + // println!("{} {} {}", prover_data.instances[i].gate_ref.0.id(), i, c); + // } + // } + let wires_commitment = timed!( timing, "compute wires commitment", diff --git a/src/polynomial/polynomial.rs b/src/polynomial/polynomial.rs index a021ecd2..3c732208 100644 --- a/src/polynomial/polynomial.rs +++ b/src/polynomial/polynomial.rs @@ -120,6 +120,15 @@ impl PolynomialCoeffs { .fold(F::ZERO, |acc, &c| acc * x + c) } + pub fn eval_with_powers(&self, powers: &[F]) -> F { + debug_assert_eq!(self.coeffs.len(), powers.len() + 1); + let acc = self.coeffs[0]; + self.coeffs[1..] + .iter() + .zip(powers) + .fold(acc, |acc, (&x, &c)| acc + c * x) + } + pub fn eval_base(&self, x: F::BaseField) -> F where F: FieldExtension, @@ -130,6 +139,18 @@ impl PolynomialCoeffs { .fold(F::ZERO, |acc, &c| acc.scalar_mul(x) + c) } + pub fn eval_base_with_powers(&self, powers: &[F::BaseField]) -> F + where + F: FieldExtension, + { + debug_assert_eq!(self.coeffs.len(), powers.len() + 1); + let acc = self.coeffs[0]; + self.coeffs[1..] + .iter() + .zip(powers) + .fold(acc, |acc, (&x, &c)| acc + x.scalar_mul(c)) + } + pub fn lde_multiple(polys: Vec<&Self>, rate_bits: usize) -> Vec { polys.into_iter().map(|p| p.lde(rate_bits)).collect() } From 0d5ba7e755d2aa57ccea019c36edd933871f5478 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 22 Nov 2021 10:44:19 +0100 Subject: [PATCH 122/202] Working recursive test --- src/gates/interpolation.rs | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index 536c87bb..4d110522 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -359,8 +359,17 @@ impl, const D: usize> SimpleGenerator F::Extension::from_basefield_array(arr) }; + let wire_shift = get_local_wire(self.gate.wire_shift()); + + for i in 2..self.gate.num_points() { + out_buffer.set_wire( + local_wire(self.gate.powers_init(i)), + wire_shift.exp_u64(i as u64), + ); + } + // Compute the interpolant. - let points = self.gate.coset(get_local_wire(self.gate.wire_shift())); + let points = self.gate.coset(wire_shift); let points = points .into_iter() .enumerate() @@ -374,6 +383,12 @@ impl, const D: usize> SimpleGenerator } let evaluation_point = get_local_ext(self.gate.wires_evaluation_point()); + for i in 2..self.gate.num_points() { + out_buffer.set_extension_target( + ExtensionTarget::from_range(self.gate_index, self.gate.powers_eval(i)), + evaluation_point.exp_u64(i as u64), + ); + } let evaluation_value = interpolant.eval(evaluation_point); let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire); out_buffer.set_ext_wires(evaluation_value_wires, evaluation_value); From 442c8560b00e3d66c572660d90f896f56a491b96 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 22 Nov 2021 11:16:58 +0100 Subject: [PATCH 123/202] Under 2^12 with 27 query rounds --- src/gates/interpolation.rs | 49 +++++++++++++++++++++++++++++++++++--- src/plonk/circuit_data.rs | 4 ++-- 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index 4d110522..9f7fa216 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -161,6 +161,10 @@ impl, const D: usize> Gate for InterpolationG .map(|i| vars.local_wires[self.powers_init(i)]) .collect::>(); powers_init.insert(0, F::Extension::ONE); + let wire_shift = powers_init[1]; + for i in 2..self.num_points() { + constraints.push(powers_init[i - 1] * wire_shift - powers_init[i]); + } let ocoeffs = coeffs .iter() .zip(powers_init) @@ -181,6 +185,13 @@ impl, const D: usize> Gate for InterpolationG let mut evaluation_point_powers = (1..self.num_points()) .map(|i| vars.get_local_ext_algebra(self.powers_eval(i))) .collect::>(); + let evaluation_point = evaluation_point_powers[0]; + for i in 1..self.num_points() - 1 { + constraints.extend( + (evaluation_point_powers[i - 1] * evaluation_point - evaluation_point_powers[i]) + .to_basefield_array(), + ); + } let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers); constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array()); @@ -198,6 +209,10 @@ impl, const D: usize> Gate for InterpolationG .map(|i| vars.local_wires[self.powers_init(i)]) .collect::>(); powers_init.insert(0, F::ONE); + let wire_shift = powers_init[1]; + for i in 2..self.num_points() { + constraints.push(powers_init[i - 1] * wire_shift - powers_init[i]); + } let ocoeffs = coeffs .iter() .zip(powers_init) @@ -218,6 +233,13 @@ impl, const D: usize> Gate for InterpolationG let evaluation_point_powers = (1..self.num_points()) .map(|i| vars.get_local_ext(self.powers_eval(i))) .collect::>(); + let evaluation_point = evaluation_point_powers[0]; + for i in 1..self.num_points() - 1 { + constraints.extend( + (evaluation_point_powers[i - 1] * evaluation_point - evaluation_point_powers[i]) + .to_basefield_array(), + ); + } let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers); constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array()); @@ -239,6 +261,14 @@ impl, const D: usize> Gate for InterpolationG .map(|i| vars.local_wires[self.powers_init(i)]) .collect::>(); powers_init.insert(0, builder.one_extension()); + let wire_shift = powers_init[1]; + for i in 2..self.num_points() { + constraints.push(builder.mul_sub_extension( + powers_init[i - 1], + wire_shift, + powers_init[i], + )); + } let ocoeffs = coeffs .iter() .zip(powers_init) @@ -264,6 +294,18 @@ impl, const D: usize> Gate for InterpolationG let evaluation_point_powers = (1..self.num_points()) .map(|i| vars.get_local_ext_algebra(self.powers_eval(i))) .collect::>(); + let evaluation_point = evaluation_point_powers[0]; + for i in 1..self.num_points() - 1 { + let neg_one_ext = builder.neg_one_extension(); + let neg_new_power = + builder.scalar_mul_ext_algebra(neg_one_ext, evaluation_point_powers[i]); + let constraint = builder.mul_add_ext_algebra( + evaluation_point, + evaluation_point_powers[i - 1], + neg_new_power, + ); + constraints.extend(constraint.to_ext_target_array()); + } let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); let computed_evaluation_value = interpolant.eval_with_powers(builder, &evaluation_point_powers); @@ -307,9 +349,10 @@ impl, const D: usize> Gate for InterpolationG } fn num_constraints(&self) -> usize { - // num_points * D constraints to check for consistency between the coefficients and the - // point-value pairs, plus D constraints for the evaluation value. - self.num_points() * D + D + // `num_points * D` constraints to check for consistency between the coefficients and the + // point-value pairs, plus `D` constraints for the evaluation value, plus `(D+1)*(num_points-2)` + // to check power constraints for evaluation point and wire shift. + self.num_points() * D + D + (D + 1) * (self.num_points() - 2) } } diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index c0ccbc3c..81d8b09f 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -63,7 +63,7 @@ impl CircuitConfig { num_routed_wires: 80, constant_gate_size: 5, use_base_arithmetic_gate: true, - security_bits: 93, + security_bits: 96, rate_bits: 3, num_challenges: 2, zero_knowledge: false, @@ -71,7 +71,7 @@ impl CircuitConfig { fri_config: FriConfig { proof_of_work_bits: 15, reduction_strategy: FriReductionStrategy::ConstantArityBits(4, 5), - num_query_rounds: 26, + num_query_rounds: 27, }, } } From 8522026c36810ac3a242873a2877130a0ef8d7e6 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 22 Nov 2021 11:39:56 +0100 Subject: [PATCH 124/202] Change file structure --- src/gates/interpolation.rs | 177 ++------- src/gates/low_degree_interpolation.rs | 525 ++++++++++++++++++++++++++ src/gates/mod.rs | 1 + 3 files changed, 553 insertions(+), 150 deletions(-) create mode 100644 src/gates/low_degree_interpolation.rs diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index 9f7fa216..d97eb009 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -1,10 +1,10 @@ use std::marker::PhantomData; use std::ops::Range; -use crate::field::extension_field::algebra::{ExtensionAlgebra, PolynomialCoeffsAlgebra}; +use crate::field::extension_field::algebra::PolynomialCoeffsAlgebra; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::field::field_types::{Field, RichField}; +use crate::field::field_types::RichField; use crate::field::interpolation::interpolant; use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; use crate::gates::gate::Gate; @@ -90,26 +90,9 @@ impl, const D: usize> InterpolationGate { start..start + D } - pub fn powers_init(&self, i: usize) -> usize { - debug_assert!(0 < i && i < self.num_points()); - if i == 1 { - return self.wire_shift(); - } - self.start_coeffs() + self.num_points() * D + i - } - - pub fn powers_eval(&self, i: usize) -> Range { - debug_assert!(0 < i && i < self.num_points()); - if i == 1 { - return self.wires_evaluation_point(); - } - let start = self.start_coeffs() + self.num_points() * D + self.num_points() - 1 + i * D; - start..start + D - } - /// End of wire indices, exclusive. fn end(&self) -> usize { - self.powers_eval(self.num_points() - 1).end + self.start_coeffs() + self.num_points() * D } /// The domain of the points we're interpolating. @@ -155,45 +138,19 @@ impl, const D: usize> Gate for InterpolationG let coeffs = (0..self.num_points()) .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) - .collect::>(); - - let mut powers_init = (1..self.num_points()) - .map(|i| vars.local_wires[self.powers_init(i)]) - .collect::>(); - powers_init.insert(0, F::Extension::ONE); - let wire_shift = powers_init[1]; - for i in 2..self.num_points() { - constraints.push(powers_init[i - 1] * wire_shift - powers_init[i]); - } - let ocoeffs = coeffs - .iter() - .zip(powers_init) - .map(|(&c, p)| c.scalar_mul(p)) - .collect::>(); + .collect(); let interpolant = PolynomialCoeffsAlgebra::new(coeffs); - let ointerpolant = PolynomialCoeffsAlgebra::new(ocoeffs); - for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits) - .into_iter() - .enumerate() - { + let coset = self.coset_ext(vars.local_wires[self.wire_shift()]); + for (i, point) in coset.into_iter().enumerate() { let value = vars.get_local_ext_algebra(self.wires_value(i)); - let computed_value = ointerpolant.eval_base(point); + let computed_value = interpolant.eval_base(point); constraints.extend(&(value - computed_value).to_basefield_array()); } - let mut evaluation_point_powers = (1..self.num_points()) - .map(|i| vars.get_local_ext_algebra(self.powers_eval(i))) - .collect::>(); - let evaluation_point = evaluation_point_powers[0]; - for i in 1..self.num_points() - 1 { - constraints.extend( - (evaluation_point_powers[i - 1] * evaluation_point - evaluation_point_powers[i]) - .to_basefield_array(), - ); - } + let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); - let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers); + let computed_evaluation_value = interpolant.eval(evaluation_point); constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array()); constraints @@ -204,44 +161,19 @@ impl, const D: usize> Gate for InterpolationG let coeffs = (0..self.num_points()) .map(|i| vars.get_local_ext(self.wires_coeff(i))) - .collect::>(); - let mut powers_init = (1..self.num_points()) - .map(|i| vars.local_wires[self.powers_init(i)]) - .collect::>(); - powers_init.insert(0, F::ONE); - let wire_shift = powers_init[1]; - for i in 2..self.num_points() { - constraints.push(powers_init[i - 1] * wire_shift - powers_init[i]); - } - let ocoeffs = coeffs - .iter() - .zip(powers_init) - .map(|(&c, p)| c.scalar_mul(p)) - .collect::>(); + .collect(); let interpolant = PolynomialCoeffs::new(coeffs); - let ointerpolant = PolynomialCoeffs::new(ocoeffs); - for (i, point) in F::two_adic_subgroup(self.subgroup_bits) - .into_iter() - .enumerate() - { + let coset = self.coset(vars.local_wires[self.wire_shift()]); + for (i, point) in coset.into_iter().enumerate() { let value = vars.get_local_ext(self.wires_value(i)); - let computed_value = ointerpolant.eval_base(point); + let computed_value = interpolant.eval_base(point); constraints.extend(&(value - computed_value).to_basefield_array()); } - let evaluation_point_powers = (1..self.num_points()) - .map(|i| vars.get_local_ext(self.powers_eval(i))) - .collect::>(); - let evaluation_point = evaluation_point_powers[0]; - for i in 1..self.num_points() - 1 { - constraints.extend( - (evaluation_point_powers[i - 1] * evaluation_point - evaluation_point_powers[i]) - .to_basefield_array(), - ); - } + let evaluation_point = vars.get_local_ext(self.wires_evaluation_point()); let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); - let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers); + let computed_evaluation_value = interpolant.eval(evaluation_point); constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array()); constraints @@ -256,34 +188,13 @@ impl, const D: usize> Gate for InterpolationG let coeffs = (0..self.num_points()) .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) - .collect::>(); - let mut powers_init = (1..self.num_points()) - .map(|i| vars.local_wires[self.powers_init(i)]) - .collect::>(); - powers_init.insert(0, builder.one_extension()); - let wire_shift = powers_init[1]; - for i in 2..self.num_points() { - constraints.push(builder.mul_sub_extension( - powers_init[i - 1], - wire_shift, - powers_init[i], - )); - } - let ocoeffs = coeffs - .iter() - .zip(powers_init) - .map(|(&c, p)| builder.scalar_mul_ext_algebra(p, c)) - .collect::>(); + .collect(); let interpolant = PolynomialCoeffsExtAlgebraTarget(coeffs); - let ointerpolant = PolynomialCoeffsExtAlgebraTarget(ocoeffs); - for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits) - .into_iter() - .enumerate() - { + let coset = self.coset_ext_recursive(builder, vars.local_wires[self.wire_shift()]); + for (i, point) in coset.into_iter().enumerate() { let value = vars.get_local_ext_algebra(self.wires_value(i)); - let point = builder.constant_extension(point); - let computed_value = ointerpolant.eval_scalar(builder, point); + let computed_value = interpolant.eval_scalar(builder, point); constraints.extend( &builder .sub_ext_algebra(value, computed_value) @@ -291,27 +202,9 @@ impl, const D: usize> Gate for InterpolationG ); } - let evaluation_point_powers = (1..self.num_points()) - .map(|i| vars.get_local_ext_algebra(self.powers_eval(i))) - .collect::>(); - let evaluation_point = evaluation_point_powers[0]; - for i in 1..self.num_points() - 1 { - let neg_one_ext = builder.neg_one_extension(); - let neg_new_power = - builder.scalar_mul_ext_algebra(neg_one_ext, evaluation_point_powers[i]); - let constraint = builder.mul_add_ext_algebra( - evaluation_point, - evaluation_point_powers[i - 1], - neg_new_power, - ); - constraints.extend(constraint.to_ext_target_array()); - } + let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); - let computed_evaluation_value = - interpolant.eval_with_powers(builder, &evaluation_point_powers); - // let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); - // let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); - // let computed_evaluation_value = interpolant.eval(builder, evaluation_point); + let computed_evaluation_value = interpolant.eval(builder, evaluation_point); constraints.extend( &builder .sub_ext_algebra(evaluation_value, computed_evaluation_value) @@ -345,14 +238,13 @@ impl, const D: usize> Gate for InterpolationG fn degree(&self) -> usize { // The highest power of x is `num_points - 1`, and then multiplication by the coefficient // adds 1. - 2 + self.num_points() } fn num_constraints(&self) -> usize { - // `num_points * D` constraints to check for consistency between the coefficients and the - // point-value pairs, plus `D` constraints for the evaluation value, plus `(D+1)*(num_points-2)` - // to check power constraints for evaluation point and wire shift. - self.num_points() * D + D + (D + 1) * (self.num_points() - 2) + // num_points * D constraints to check for consistency between the coefficients and the + // point-value pairs, plus D constraints for the evaluation value. + self.num_points() * D + D } } @@ -402,17 +294,8 @@ impl, const D: usize> SimpleGenerator F::Extension::from_basefield_array(arr) }; - let wire_shift = get_local_wire(self.gate.wire_shift()); - - for i in 2..self.gate.num_points() { - out_buffer.set_wire( - local_wire(self.gate.powers_init(i)), - wire_shift.exp_u64(i as u64), - ); - } - // Compute the interpolant. - let points = self.gate.coset(wire_shift); + let points = self.gate.coset(get_local_wire(self.gate.wire_shift())); let points = points .into_iter() .enumerate() @@ -426,12 +309,6 @@ impl, const D: usize> SimpleGenerator } let evaluation_point = get_local_ext(self.gate.wires_evaluation_point()); - for i in 2..self.gate.num_points() { - out_buffer.set_extension_target( - ExtensionTarget::from_range(self.gate_index, self.gate.powers_eval(i)), - evaluation_point.exp_u64(i as u64), - ); - } let evaluation_value = interpolant.eval(evaluation_point); let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire); out_buffer.set_ext_wires(evaluation_value_wires, evaluation_value); @@ -480,7 +357,7 @@ mod tests { #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(InterpolationGate::new(3)) + test_eval_fns::(InterpolationGate::new(2)) } #[test] diff --git a/src/gates/low_degree_interpolation.rs b/src/gates/low_degree_interpolation.rs new file mode 100644 index 00000000..95b168f7 --- /dev/null +++ b/src/gates/low_degree_interpolation.rs @@ -0,0 +1,525 @@ +use std::marker::PhantomData; +use std::ops::Range; + +use crate::field::extension_field::algebra::{ExtensionAlgebra, PolynomialCoeffsAlgebra}; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::{Extendable, FieldExtension}; +use crate::field::field_types::{Field, RichField}; +use crate::field::interpolation::interpolant; +use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::wire::Wire; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::polynomial::polynomial::PolynomialCoeffs; + +/// Interpolates a polynomial, whose points are a (base field) coset of the multiplicative subgroup +/// with the given size, and whose values are extension field elements, given by input wires. +/// Outputs the evaluation of the interpolant at a given (extension field) evaluation point. +#[derive(Clone, Debug)] +pub(crate) struct LowDegreeInterpolationGate, const D: usize> { + pub subgroup_bits: usize, + _phantom: PhantomData, +} + +impl, const D: usize> LowDegreeInterpolationGate { + pub fn new(subgroup_bits: usize) -> Self { + Self { + subgroup_bits, + _phantom: PhantomData, + } + } + + fn num_points(&self) -> usize { + 1 << self.subgroup_bits + } + + /// Wire index of the coset shift. + pub fn wire_shift(&self) -> usize { + 0 + } + + fn start_values(&self) -> usize { + 1 + } + + /// Wire indices of the `i`th interpolant value. + pub fn wires_value(&self, i: usize) -> Range { + debug_assert!(i < self.num_points()); + let start = self.start_values() + i * D; + start..start + D + } + + fn start_evaluation_point(&self) -> usize { + self.start_values() + self.num_points() * D + } + + /// Wire indices of the point to evaluate the interpolant at. + pub fn wires_evaluation_point(&self) -> Range { + let start = self.start_evaluation_point(); + start..start + D + } + + fn start_evaluation_value(&self) -> usize { + self.start_evaluation_point() + D + } + + /// Wire indices of the interpolated value. + pub fn wires_evaluation_value(&self) -> Range { + let start = self.start_evaluation_value(); + start..start + D + } + + fn start_coeffs(&self) -> usize { + self.start_evaluation_value() + D + } + + /// The number of routed wires required in the typical usage of this gate, where the points to + /// interpolate, the evaluation point, and the corresponding value are all routed. + pub(crate) fn num_routed_wires(&self) -> usize { + self.start_coeffs() + } + + /// Wire indices of the interpolant's `i`th coefficient. + pub fn wires_coeff(&self, i: usize) -> Range { + debug_assert!(i < self.num_points()); + let start = self.start_coeffs() + i * D; + start..start + D + } + + fn end_coeffs(&self) -> usize { + self.start_coeffs() + D * self.num_points() + } + + pub fn powers_init(&self, i: usize) -> usize { + debug_assert!(0 < i && i < self.num_points()); + if i == 1 { + return self.wire_shift(); + } + self.end_coeffs() + i - 2 + } + + pub fn powers_eval(&self, i: usize) -> Range { + debug_assert!(0 < i && i < self.num_points()); + if i == 1 { + return self.wires_evaluation_point(); + } + let start = self.end_coeffs() + self.num_points() - 2 + (i - 2) * D; + start..start + D + } + + /// End of wire indices, exclusive. + fn end(&self) -> usize { + self.powers_eval(self.num_points() - 1).end + } + + /// The domain of the points we're interpolating. + fn coset(&self, shift: F) -> impl Iterator { + let g = F::primitive_root_of_unity(self.subgroup_bits); + let size = 1 << self.subgroup_bits; + // Speed matters here, so we avoid `cyclic_subgroup_coset_known_order` which allocates. + g.powers().take(size).map(move |x| x * shift) + } + + /// The domain of the points we're interpolating. + fn coset_ext(&self, shift: F::Extension) -> impl Iterator { + let g = F::primitive_root_of_unity(self.subgroup_bits); + let size = 1 << self.subgroup_bits; + g.powers().take(size).map(move |x| shift.scalar_mul(x)) + } + + /// The domain of the points we're interpolating. + fn coset_ext_recursive( + &self, + builder: &mut CircuitBuilder, + shift: ExtensionTarget, + ) -> Vec> { + let g = F::primitive_root_of_unity(self.subgroup_bits); + let size = 1 << self.subgroup_bits; + g.powers() + .take(size) + .map(move |x| { + let subgroup_element = builder.constant(x.into()); + builder.scalar_mul_ext(subgroup_element, shift) + }) + .collect() + } +} + +impl, const D: usize> Gate for LowDegreeInterpolationGate { + fn id(&self) -> String { + format!("{:?}", self, D) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let coeffs = (0..self.num_points()) + .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) + .collect::>(); + + let mut powers_init = (1..self.num_points()) + .map(|i| vars.local_wires[self.powers_init(i)]) + .collect::>(); + powers_init.insert(0, F::Extension::ONE); + let wire_shift = powers_init[1]; + for i in 2..self.num_points() { + constraints.push(powers_init[i - 1] * wire_shift - powers_init[i]); + } + let ocoeffs = coeffs + .iter() + .zip(powers_init) + .map(|(&c, p)| c.scalar_mul(p)) + .collect::>(); + let interpolant = PolynomialCoeffsAlgebra::new(coeffs); + let ointerpolant = PolynomialCoeffsAlgebra::new(ocoeffs); + + for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits) + .into_iter() + .enumerate() + { + let value = vars.get_local_ext_algebra(self.wires_value(i)); + let computed_value = ointerpolant.eval_base(point); + constraints.extend(&(value - computed_value).to_basefield_array()); + } + + let mut evaluation_point_powers = (1..self.num_points()) + .map(|i| vars.get_local_ext_algebra(self.powers_eval(i))) + .collect::>(); + let evaluation_point = evaluation_point_powers[0]; + for i in 1..self.num_points() - 1 { + constraints.extend( + (evaluation_point_powers[i - 1] * evaluation_point - evaluation_point_powers[i]) + .to_basefield_array(), + ); + } + let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); + let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers); + constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array()); + + constraints + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let coeffs = (0..self.num_points()) + .map(|i| vars.get_local_ext(self.wires_coeff(i))) + .collect::>(); + let mut powers_init = (1..self.num_points()) + .map(|i| vars.local_wires[self.powers_init(i)]) + .collect::>(); + powers_init.insert(0, F::ONE); + let wire_shift = powers_init[1]; + for i in 2..self.num_points() { + constraints.push(powers_init[i - 1] * wire_shift - powers_init[i]); + } + let ocoeffs = coeffs + .iter() + .zip(powers_init) + .map(|(&c, p)| c.scalar_mul(p)) + .collect::>(); + let interpolant = PolynomialCoeffs::new(coeffs); + let ointerpolant = PolynomialCoeffs::new(ocoeffs); + + for (i, point) in F::two_adic_subgroup(self.subgroup_bits) + .into_iter() + .enumerate() + { + let value = vars.get_local_ext(self.wires_value(i)); + let computed_value = ointerpolant.eval_base(point); + constraints.extend(&(value - computed_value).to_basefield_array()); + } + + let evaluation_point_powers = (1..self.num_points()) + .map(|i| vars.get_local_ext(self.powers_eval(i))) + .collect::>(); + let evaluation_point = evaluation_point_powers[0]; + for i in 1..self.num_points() - 1 { + constraints.extend( + (evaluation_point_powers[i - 1] * evaluation_point - evaluation_point_powers[i]) + .to_basefield_array(), + ); + } + let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); + let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers); + constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array()); + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let coeffs = (0..self.num_points()) + .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) + .collect::>(); + let mut powers_init = (1..self.num_points()) + .map(|i| vars.local_wires[self.powers_init(i)]) + .collect::>(); + powers_init.insert(0, builder.one_extension()); + let wire_shift = powers_init[1]; + for i in 2..self.num_points() { + constraints.push(builder.mul_sub_extension( + powers_init[i - 1], + wire_shift, + powers_init[i], + )); + } + let ocoeffs = coeffs + .iter() + .zip(powers_init) + .map(|(&c, p)| builder.scalar_mul_ext_algebra(p, c)) + .collect::>(); + let interpolant = PolynomialCoeffsExtAlgebraTarget(coeffs); + let ointerpolant = PolynomialCoeffsExtAlgebraTarget(ocoeffs); + + for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits) + .into_iter() + .enumerate() + { + let value = vars.get_local_ext_algebra(self.wires_value(i)); + let point = builder.constant_extension(point); + let computed_value = ointerpolant.eval_scalar(builder, point); + constraints.extend( + &builder + .sub_ext_algebra(value, computed_value) + .to_ext_target_array(), + ); + } + + let evaluation_point_powers = (1..self.num_points()) + .map(|i| vars.get_local_ext_algebra(self.powers_eval(i))) + .collect::>(); + let evaluation_point = evaluation_point_powers[0]; + for i in 1..self.num_points() - 1 { + let neg_one_ext = builder.neg_one_extension(); + let neg_new_power = + builder.scalar_mul_ext_algebra(neg_one_ext, evaluation_point_powers[i]); + let constraint = builder.mul_add_ext_algebra( + evaluation_point, + evaluation_point_powers[i - 1], + neg_new_power, + ); + constraints.extend(constraint.to_ext_target_array()); + } + let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); + let computed_evaluation_value = + interpolant.eval_with_powers(builder, &evaluation_point_powers); + // let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); + // let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); + // let computed_evaluation_value = interpolant.eval(builder, evaluation_point); + constraints.extend( + &builder + .sub_ext_algebra(evaluation_value, computed_evaluation_value) + .to_ext_target_array(), + ); + + constraints + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + let gen = InterpolationGenerator:: { + gate_index, + gate: self.clone(), + _phantom: PhantomData, + }; + vec![Box::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + self.end() + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + // The highest power of x is `num_points - 1`, and then multiplication by the coefficient + // adds 1. + 2 + } + + fn num_constraints(&self) -> usize { + // `num_points * D` constraints to check for consistency between the coefficients and the + // point-value pairs, plus `D` constraints for the evaluation value, plus `(D+1)*(num_points-2)` + // to check power constraints for evaluation point and wire shift. + self.num_points() * D + D + (D + 1) * (self.num_points() - 2) + } +} + +#[derive(Debug)] +struct InterpolationGenerator, const D: usize> { + gate_index: usize, + gate: LowDegreeInterpolationGate, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for InterpolationGenerator +{ + fn dependencies(&self) -> Vec { + let local_target = |input| { + Target::Wire(Wire { + gate: self.gate_index, + input, + }) + }; + + let local_targets = |inputs: Range| inputs.map(local_target); + + let num_points = self.gate.num_points(); + let mut deps = Vec::with_capacity(1 + D + num_points * D); + + deps.push(local_target(self.gate.wire_shift())); + deps.extend(local_targets(self.gate.wires_evaluation_point())); + for i in 0..num_points { + deps.extend(local_targets(self.gate.wires_value(i))); + } + deps + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |input| Wire { + gate: self.gate_index, + input, + }; + + let get_local_wire = |input| witness.get_wire(local_wire(input)); + + let get_local_ext = |wire_range: Range| { + debug_assert_eq!(wire_range.len(), D); + let values = wire_range.map(get_local_wire).collect::>(); + let arr = values.try_into().unwrap(); + F::Extension::from_basefield_array(arr) + }; + + let wire_shift = get_local_wire(self.gate.wire_shift()); + + for i in 2..self.gate.num_points() { + out_buffer.set_wire( + local_wire(self.gate.powers_init(i)), + wire_shift.exp_u64(i as u64), + ); + } + + // Compute the interpolant. + let points = self.gate.coset(wire_shift); + let points = points + .into_iter() + .enumerate() + .map(|(i, point)| (point.into(), get_local_ext(self.gate.wires_value(i)))) + .collect::>(); + let interpolant = interpolant(&points); + + for (i, &coeff) in interpolant.coeffs.iter().enumerate() { + let wires = self.gate.wires_coeff(i).map(local_wire); + out_buffer.set_ext_wires(wires, coeff); + } + + let evaluation_point = get_local_ext(self.gate.wires_evaluation_point()); + for i in 2..self.gate.num_points() { + out_buffer.set_extension_target( + ExtensionTarget::from_range(self.gate_index, self.gate.powers_eval(i)), + evaluation_point.exp_u64(i as u64), + ); + } + let evaluation_value = interpolant.eval(evaluation_point); + let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire); + out_buffer.set_ext_wires(evaluation_value_wires, evaluation_value); + } +} + +#[cfg(test)] +mod tests { + use std::marker::PhantomData; + + use anyhow::Result; + + use crate::field::extension_field::quadratic::QuadraticExtension; + use crate::field::extension_field::quartic::QuarticExtension; + use crate::field::field_types::Field; + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::gate::Gate; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; + use crate::hash::hash_types::HashOut; + use crate::plonk::vars::EvaluationVars; + use crate::polynomial::polynomial::PolynomialCoeffs; + + #[test] + fn low_degree() { + test_low_degree::(LowDegreeInterpolationGate::new(4)); + } + + #[test] + fn eval_fns() -> Result<()> { + test_eval_fns::(LowDegreeInterpolationGate::new(4)) + } + + #[test] + fn test_gate_constraint() { + type F = GoldilocksField; + type FF = QuadraticExtension; + const D: usize = 2; + + /// Returns the local wires for an interpolation gate for given coeffs, points and eval point. + fn get_wires( + gate: &LowDegreeInterpolationGate, + shift: F, + coeffs: PolynomialCoeffs, + eval_point: FF, + ) -> Vec { + let points = gate.coset(shift); + let mut v = vec![shift]; + for x in points { + v.extend(coeffs.eval(x.into()).0); + } + v.extend(eval_point.0); + v.extend(coeffs.eval(eval_point).0); + for i in 0..coeffs.len() { + v.extend(coeffs.coeffs[i].0); + } + v.extend(shift.powers().skip(2).take(gate.num_points() - 2)); + v.extend( + eval_point + .powers() + .skip(2) + .take(gate.num_points() - 2) + .flat_map(|ff| ff.0), + ); + v.iter().map(|&x| x.into()).collect::>() + } + + // Get a working row for LowDegreeInterpolationGate. + let subgroup_bits = 4; + let shift = F::rand(); + let coeffs = PolynomialCoeffs::new(FF::rand_vec(1 << subgroup_bits)); + let eval_point = FF::rand(); + let gate = LowDegreeInterpolationGate::::new(subgroup_bits); + dbg!(gate.end_coeffs()); + dbg!(gate.powers_eval(15)); + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(&gate, shift, coeffs, eval_point), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } +} diff --git a/src/gates/mod.rs b/src/gates/mod.rs index 93de5e97..1663fdd0 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -14,6 +14,7 @@ pub mod gate_tree; pub mod gmimc; pub mod insertion; pub mod interpolation; +pub mod low_degree_interpolation; pub mod noop; pub mod poseidon; pub(crate) mod poseidon_mds; From 6aaea002edf5607c2d380fe9ab48c52d575730a6 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 22 Nov 2021 16:10:14 +0100 Subject: [PATCH 125/202] Choose between high- and low-degree interpolation gate depending on the arity --- src/fri/recursive_verifier.rs | 24 +++- src/gadgets/interpolation.rs | 101 +++++++++++++-- src/gates/gate_tree.rs | 5 +- src/gates/interpolation.rs | 158 ++++++++++++----------- src/gates/low_degree_interpolation.rs | 176 ++++++++++++-------------- src/plonk/circuit_builder.rs | 1 - src/plonk/circuit_data.rs | 3 +- 7 files changed, 280 insertions(+), 188 deletions(-) diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index a08dd99e..f543b12b 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -3,8 +3,10 @@ use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::fri::proof::{FriInitialTreeProofTarget, FriProofTarget, FriQueryRoundTarget}; use crate::fri::FriConfig; +use crate::gadgets::interpolation::InterpolationGate; use crate::gates::gate::Gate; -use crate::gates::interpolation::InterpolationGate; +use crate::gates::interpolation::HighDegreeInterpolationGate; +use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; use crate::gates::random_access::RandomAccessGate; use crate::hash::hash_types::MerkleCapTarget; use crate::iop::challenger::RecursiveChallenger; @@ -27,6 +29,7 @@ impl, const D: usize> CircuitBuilder { arity_bits: usize, evals: &[ExtensionTarget], beta: ExtensionTarget, + common_data: &CommonCircuitData, ) -> ExtensionTarget { let arity = 1 << arity_bits; debug_assert_eq!(evals.len(), arity); @@ -43,7 +46,21 @@ impl, const D: usize> CircuitBuilder { let coset_start = self.mul(start, x); // The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta. - self.interpolate_coset(arity_bits, coset_start, &evals, beta) + if 1 << arity_bits > common_data.quotient_degree_factor { + self.interpolate_coset::>( + arity_bits, + coset_start, + &evals, + beta, + ) + } else { + self.interpolate_coset::>( + arity_bits, + coset_start, + &evals, + beta, + ) + } } /// Make sure we have enough wires and routed wires to do the FRI checks efficiently. This check @@ -54,7 +71,7 @@ impl, const D: usize> CircuitBuilder { &self.config, max_fri_arity_bits.max(self.config.cap_height), ); - let interpolation_gate = InterpolationGate::::new(max_fri_arity_bits); + let interpolation_gate = HighDegreeInterpolationGate::::new(max_fri_arity_bits); let min_wires = random_access .num_wires() @@ -379,6 +396,7 @@ impl, const D: usize> CircuitBuilder { arity_bits, evals, betas[i], + &common_data ) ); diff --git a/src/gadgets/interpolation.rs b/src/gadgets/interpolation.rs index 09b6329b..5d23b65b 100644 --- a/src/gadgets/interpolation.rs +++ b/src/gadgets/interpolation.rs @@ -1,23 +1,90 @@ +use std::ops::Range; + use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; -use crate::gates::interpolation::InterpolationGate; +use crate::gates::gate::Gate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; +pub(crate) trait InterpolationGate, const D: usize>: + Gate + Copy +{ + fn new(subgroup_bits: usize) -> Self; + + fn num_points(&self) -> usize; + + /// Wire index of the coset shift. + fn wire_shift(&self) -> usize { + 0 + } + + fn start_values(&self) -> usize { + 1 + } + + /// Wire indices of the `i`th interpolant value. + fn wires_value(&self, i: usize) -> Range { + debug_assert!(i < self.num_points()); + let start = self.start_values() + i * D; + start..start + D + } + + fn start_evaluation_point(&self) -> usize { + self.start_values() + self.num_points() * D + } + + /// Wire indices of the point to evaluate the interpolant at. + fn wires_evaluation_point(&self) -> Range { + let start = self.start_evaluation_point(); + start..start + D + } + + fn start_evaluation_value(&self) -> usize { + self.start_evaluation_point() + D + } + + /// Wire indices of the interpolated value. + fn wires_evaluation_value(&self) -> Range { + let start = self.start_evaluation_value(); + start..start + D + } + + fn start_coeffs(&self) -> usize { + self.start_evaluation_value() + D + } + + /// The number of routed wires required in the typical usage of this gate, where the points to + /// interpolate, the evaluation point, and the corresponding value are all routed. + fn num_routed_wires(&self) -> usize { + self.start_coeffs() + } + + /// Wire indices of the interpolant's `i`th coefficient. + fn wires_coeff(&self, i: usize) -> Range { + debug_assert!(i < self.num_points()); + let start = self.start_coeffs() + i * D; + start..start + D + } + + fn end_coeffs(&self) -> usize { + self.start_coeffs() + D * self.num_points() + } +} + impl, const D: usize> CircuitBuilder { /// Interpolates a polynomial, whose points are a coset of the multiplicative subgroup with the /// given size, and whose values are given. Returns the evaluation of the interpolant at /// `evaluation_point`. - pub fn interpolate_coset( + pub(crate) fn interpolate_coset>( &mut self, subgroup_bits: usize, coset_shift: Target, values: &[ExtensionTarget], evaluation_point: ExtensionTarget, ) -> ExtensionTarget { - let gate = InterpolationGate::new(subgroup_bits); - let gate_index = self.add_gate(gate.clone(), vec![]); + let gate = G::new(subgroup_bits); + let gate_index = self.add_gate(gate, vec![]); self.connect(coset_shift, Target::wire(gate_index, gate.wire_shift())); for (i, &v) in values.iter().enumerate() { self.connect_extension( @@ -38,11 +105,14 @@ impl, const D: usize> CircuitBuilder { mod tests { use anyhow::Result; + use crate::field::extension_field::quadratic::QuadraticExtension; use crate::field::extension_field::quartic::QuarticExtension; use crate::field::extension_field::FieldExtension; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; use crate::field::interpolation::interpolant; + use crate::gates::interpolation::HighDegreeInterpolationGate; + use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; @@ -51,10 +121,11 @@ mod tests { #[test] fn test_interpolate() -> Result<()> { type F = GoldilocksField; - type FF = QuarticExtension; + const D: usize = 2; + type FF = QuadraticExtension; let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); + let mut builder = CircuitBuilder::::new(config); let subgroup_bits = 2; let len = 1 << subgroup_bits; @@ -66,7 +137,7 @@ mod tests { let homogeneous_points = points .iter() .zip(values.iter()) - .map(|(&a, &b)| (>::from_basefield(a), b)) + .map(|(&a, &b)| (>::from_basefield(a), b)) .collect::>(); let true_interpolant = interpolant(&homogeneous_points); @@ -83,9 +154,21 @@ mod tests { let zt = builder.constant_extension(z); - let eval = builder.interpolate_coset(subgroup_bits, coset_shift_target, &value_targets, zt); + let eval_hd = builder.interpolate_coset::>( + subgroup_bits, + coset_shift_target, + &value_targets, + zt, + ); + let eval_ld = builder.interpolate_coset::>( + subgroup_bits, + coset_shift_target, + &value_targets, + zt, + ); let true_eval_target = builder.constant_extension(true_eval); - builder.connect_extension(eval, true_eval_target); + builder.connect_extension(eval_hd, true_eval_target); + builder.connect_extension(eval_ld, true_eval_target); let data = builder.build(); let proof = data.prove(pw)?; diff --git a/src/gates/gate_tree.rs b/src/gates/gate_tree.rs index 704b410a..c2a517d0 100644 --- a/src/gates/gate_tree.rs +++ b/src/gates/gate_tree.rs @@ -225,11 +225,12 @@ mod tests { use super::*; use crate::field::goldilocks_field::GoldilocksField; + use crate::gadgets::interpolation::InterpolationGate; use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::gates::base_sum::BaseSumGate; use crate::gates::constant::ConstantGate; use crate::gates::gmimc::GMiMCGate; - use crate::gates::interpolation::InterpolationGate; + use crate::gates::interpolation::HighDegreeInterpolationGate; use crate::gates::noop::NoopGate; #[test] @@ -244,7 +245,7 @@ mod tests { GateRef::new(ArithmeticExtensionGate { num_ops: 4 }), GateRef::new(BaseSumGate::<4>::new(4)), GateRef::new(GMiMCGate::::new()), - GateRef::new(InterpolationGate::new(2)), + GateRef::new(HighDegreeInterpolationGate::new(2)), ]; let (tree, _, _) = Tree::from_gates(gates.clone()); diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index d97eb009..58288bba 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -6,6 +6,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::RichField; use crate::field::interpolation::interpolant; +use crate::gadgets::interpolation::InterpolationGate; use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; use crate::gates::gate::Gate; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; @@ -19,77 +20,13 @@ use crate::polynomial::polynomial::PolynomialCoeffs; /// Interpolates a polynomial, whose points are a (base field) coset of the multiplicative subgroup /// with the given size, and whose values are extension field elements, given by input wires. /// Outputs the evaluation of the interpolant at a given (extension field) evaluation point. -#[derive(Clone, Debug)] -pub(crate) struct InterpolationGate, const D: usize> { +#[derive(Copy, Clone, Debug)] +pub(crate) struct HighDegreeInterpolationGate, const D: usize> { pub subgroup_bits: usize, _phantom: PhantomData, } -impl, const D: usize> InterpolationGate { - pub fn new(subgroup_bits: usize) -> Self { - Self { - subgroup_bits, - _phantom: PhantomData, - } - } - - fn num_points(&self) -> usize { - 1 << self.subgroup_bits - } - - /// Wire index of the coset shift. - pub fn wire_shift(&self) -> usize { - 0 - } - - fn start_values(&self) -> usize { - 1 - } - - /// Wire indices of the `i`th interpolant value. - pub fn wires_value(&self, i: usize) -> Range { - debug_assert!(i < self.num_points()); - let start = self.start_values() + i * D; - start..start + D - } - - fn start_evaluation_point(&self) -> usize { - self.start_values() + self.num_points() * D - } - - /// Wire indices of the point to evaluate the interpolant at. - pub fn wires_evaluation_point(&self) -> Range { - let start = self.start_evaluation_point(); - start..start + D - } - - fn start_evaluation_value(&self) -> usize { - self.start_evaluation_point() + D - } - - /// Wire indices of the interpolated value. - pub fn wires_evaluation_value(&self) -> Range { - let start = self.start_evaluation_value(); - start..start + D - } - - fn start_coeffs(&self) -> usize { - self.start_evaluation_value() + D - } - - /// The number of routed wires required in the typical usage of this gate, where the points to - /// interpolate, the evaluation point, and the corresponding value are all routed. - pub(crate) fn num_routed_wires(&self) -> usize { - self.start_coeffs() - } - - /// Wire indices of the interpolant's `i`th coefficient. - pub fn wires_coeff(&self, i: usize) -> Range { - debug_assert!(i < self.num_points()); - let start = self.start_coeffs() + i * D; - start..start + D - } - +impl, const D: usize> HighDegreeInterpolationGate { /// End of wire indices, exclusive. fn end(&self) -> usize { self.start_coeffs() + self.num_points() * D @@ -128,7 +65,77 @@ impl, const D: usize> InterpolationGate { } } -impl, const D: usize> Gate for InterpolationGate { +impl, const D: usize> InterpolationGate + for HighDegreeInterpolationGate +{ + fn new(subgroup_bits: usize) -> Self { + Self { + subgroup_bits, + _phantom: PhantomData, + } + } + + fn num_points(&self) -> usize { + 1 << self.subgroup_bits + } + + /// Wire index of the coset shift. + fn wire_shift(&self) -> usize { + 0 + } + + fn start_values(&self) -> usize { + 1 + } + + /// Wire indices of the `i`th interpolant value. + fn wires_value(&self, i: usize) -> Range { + debug_assert!(i < self.num_points()); + let start = self.start_values() + i * D; + start..start + D + } + + fn start_evaluation_point(&self) -> usize { + self.start_values() + self.num_points() * D + } + + /// Wire indices of the point to evaluate the interpolant at. + fn wires_evaluation_point(&self) -> Range { + let start = self.start_evaluation_point(); + start..start + D + } + + fn start_evaluation_value(&self) -> usize { + self.start_evaluation_point() + D + } + + /// Wire indices of the interpolated value. + fn wires_evaluation_value(&self) -> Range { + let start = self.start_evaluation_value(); + start..start + D + } + + fn start_coeffs(&self) -> usize { + self.start_evaluation_value() + D + } + + /// The number of routed wires required in the typical usage of this gate, where the points to + /// interpolate, the evaluation point, and the corresponding value are all routed. + fn num_routed_wires(&self) -> usize { + self.start_coeffs() + } + + /// Wire indices of the interpolant's `i`th coefficient. + fn wires_coeff(&self, i: usize) -> Range { + debug_assert!(i < self.num_points()); + let start = self.start_coeffs() + i * D; + start..start + D + } +} + +impl, const D: usize> Gate + for HighDegreeInterpolationGate +{ fn id(&self) -> String { format!("{:?}", self, D) } @@ -251,7 +258,7 @@ impl, const D: usize> Gate for InterpolationG #[derive(Debug)] struct InterpolationGenerator, const D: usize> { gate_index: usize, - gate: InterpolationGate, + gate: HighDegreeInterpolationGate, _phantom: PhantomData, } @@ -324,16 +331,17 @@ mod tests { use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; + use crate::gadgets::interpolation::InterpolationGate; use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; - use crate::gates::interpolation::InterpolationGate; + use crate::gates::interpolation::HighDegreeInterpolationGate; use crate::hash::hash_types::HashOut; use crate::plonk::vars::EvaluationVars; use crate::polynomial::polynomial::PolynomialCoeffs; #[test] fn wire_indices() { - let gate = InterpolationGate:: { + let gate = HighDegreeInterpolationGate:: { subgroup_bits: 1, _phantom: PhantomData, }; @@ -352,12 +360,12 @@ mod tests { #[test] fn low_degree() { - test_low_degree::(InterpolationGate::new(2)); + test_low_degree::(HighDegreeInterpolationGate::new(2)); } #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(InterpolationGate::new(2)) + test_eval_fns::(HighDegreeInterpolationGate::new(2)) } #[test] @@ -368,7 +376,7 @@ mod tests { /// Returns the local wires for an interpolation gate for given coeffs, points and eval point. fn get_wires( - gate: &InterpolationGate, + gate: &HighDegreeInterpolationGate, shift: F, coeffs: PolynomialCoeffs, eval_point: FF, @@ -390,7 +398,7 @@ mod tests { let shift = F::rand(); let coeffs = PolynomialCoeffs::new(vec![FF::rand(), FF::rand()]); let eval_point = FF::rand(); - let gate = InterpolationGate::::new(1); + let gate = HighDegreeInterpolationGate::::new(1); let vars = EvaluationVars { local_constants: &[], local_wires: &get_wires(&gate, shift, coeffs, eval_point), diff --git a/src/gates/low_degree_interpolation.rs b/src/gates/low_degree_interpolation.rs index 95b168f7..d38bd773 100644 --- a/src/gates/low_degree_interpolation.rs +++ b/src/gates/low_degree_interpolation.rs @@ -1,11 +1,12 @@ use std::marker::PhantomData; use std::ops::Range; -use crate::field::extension_field::algebra::{ExtensionAlgebra, PolynomialCoeffsAlgebra}; +use crate::field::extension_field::algebra::PolynomialCoeffsAlgebra; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::{Field, RichField}; use crate::field::interpolation::interpolant; +use crate::gadgets::interpolation::InterpolationGate; use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; use crate::gates::gate::Gate; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; @@ -19,81 +20,13 @@ use crate::polynomial::polynomial::PolynomialCoeffs; /// Interpolates a polynomial, whose points are a (base field) coset of the multiplicative subgroup /// with the given size, and whose values are extension field elements, given by input wires. /// Outputs the evaluation of the interpolant at a given (extension field) evaluation point. -#[derive(Clone, Debug)] +#[derive(Copy, Clone, Debug)] pub(crate) struct LowDegreeInterpolationGate, const D: usize> { pub subgroup_bits: usize, _phantom: PhantomData, } impl, const D: usize> LowDegreeInterpolationGate { - pub fn new(subgroup_bits: usize) -> Self { - Self { - subgroup_bits, - _phantom: PhantomData, - } - } - - fn num_points(&self) -> usize { - 1 << self.subgroup_bits - } - - /// Wire index of the coset shift. - pub fn wire_shift(&self) -> usize { - 0 - } - - fn start_values(&self) -> usize { - 1 - } - - /// Wire indices of the `i`th interpolant value. - pub fn wires_value(&self, i: usize) -> Range { - debug_assert!(i < self.num_points()); - let start = self.start_values() + i * D; - start..start + D - } - - fn start_evaluation_point(&self) -> usize { - self.start_values() + self.num_points() * D - } - - /// Wire indices of the point to evaluate the interpolant at. - pub fn wires_evaluation_point(&self) -> Range { - let start = self.start_evaluation_point(); - start..start + D - } - - fn start_evaluation_value(&self) -> usize { - self.start_evaluation_point() + D - } - - /// Wire indices of the interpolated value. - pub fn wires_evaluation_value(&self) -> Range { - let start = self.start_evaluation_value(); - start..start + D - } - - fn start_coeffs(&self) -> usize { - self.start_evaluation_value() + D - } - - /// The number of routed wires required in the typical usage of this gate, where the points to - /// interpolate, the evaluation point, and the corresponding value are all routed. - pub(crate) fn num_routed_wires(&self) -> usize { - self.start_coeffs() - } - - /// Wire indices of the interpolant's `i`th coefficient. - pub fn wires_coeff(&self, i: usize) -> Range { - debug_assert!(i < self.num_points()); - let start = self.start_coeffs() + i * D; - start..start + D - } - - fn end_coeffs(&self) -> usize { - self.start_coeffs() + D * self.num_points() - } - pub fn powers_init(&self, i: usize) -> usize { debug_assert!(0 < i && i < self.num_points()); if i == 1 { @@ -123,29 +56,77 @@ impl, const D: usize> LowDegreeInterpolationGate impl Iterator { - let g = F::primitive_root_of_unity(self.subgroup_bits); - let size = 1 << self.subgroup_bits; - g.powers().take(size).map(move |x| shift.scalar_mul(x)) +impl, const D: usize> InterpolationGate + for LowDegreeInterpolationGate +{ + fn new(subgroup_bits: usize) -> Self { + Self { + subgroup_bits, + _phantom: PhantomData, + } } - /// The domain of the points we're interpolating. - fn coset_ext_recursive( - &self, - builder: &mut CircuitBuilder, - shift: ExtensionTarget, - ) -> Vec> { - let g = F::primitive_root_of_unity(self.subgroup_bits); - let size = 1 << self.subgroup_bits; - g.powers() - .take(size) - .map(move |x| { - let subgroup_element = builder.constant(x.into()); - builder.scalar_mul_ext(subgroup_element, shift) - }) - .collect() + fn num_points(&self) -> usize { + 1 << self.subgroup_bits + } + + /// Wire index of the coset shift. + fn wire_shift(&self) -> usize { + 0 + } + + fn start_values(&self) -> usize { + 1 + } + + /// Wire indices of the `i`th interpolant value. + fn wires_value(&self, i: usize) -> Range { + debug_assert!(i < self.num_points()); + let start = self.start_values() + i * D; + start..start + D + } + + fn start_evaluation_point(&self) -> usize { + self.start_values() + self.num_points() * D + } + + /// Wire indices of the point to evaluate the interpolant at. + fn wires_evaluation_point(&self) -> Range { + let start = self.start_evaluation_point(); + start..start + D + } + + fn start_evaluation_value(&self) -> usize { + self.start_evaluation_point() + D + } + + /// Wire indices of the interpolated value. + fn wires_evaluation_value(&self) -> Range { + let start = self.start_evaluation_value(); + start..start + D + } + + fn start_coeffs(&self) -> usize { + self.start_evaluation_value() + D + } + + /// The number of routed wires required in the typical usage of this gate, where the points to + /// interpolate, the evaluation point, and the corresponding value are all routed. + fn num_routed_wires(&self) -> usize { + self.start_coeffs() + } + + /// Wire indices of the interpolant's `i`th coefficient. + fn wires_coeff(&self, i: usize) -> Range { + debug_assert!(i < self.num_points()); + let start = self.start_coeffs() + i * D; + start..start + D + } + + fn end_coeffs(&self) -> usize { + self.start_coeffs() + D * self.num_points() } } @@ -186,7 +167,7 @@ impl, const D: usize> Gate for LowDegreeInter constraints.extend(&(value - computed_value).to_basefield_array()); } - let mut evaluation_point_powers = (1..self.num_points()) + let evaluation_point_powers = (1..self.num_points()) .map(|i| vars.get_local_ext_algebra(self.powers_eval(i))) .collect::>(); let evaluation_point = evaluation_point_powers[0]; @@ -408,11 +389,13 @@ impl, const D: usize> SimpleGenerator let wire_shift = get_local_wire(self.gate.wire_shift()); - for i in 2..self.gate.num_points() { - out_buffer.set_wire( - local_wire(self.gate.powers_init(i)), - wire_shift.exp_u64(i as u64), - ); + for (i, power) in wire_shift + .powers() + .take(self.gate.num_points()) + .enumerate() + .skip(2) + { + out_buffer.set_wire(local_wire(self.gate.powers_init(i)), power); } // Compute the interpolant. @@ -452,6 +435,7 @@ mod tests { use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; + use crate::gadgets::interpolation::InterpolationGate; use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index f19d6ae9..32ea59b6 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -687,7 +687,6 @@ impl, const D: usize> CircuitBuilder { marked_targets: self.marked_targets, representative_map: forest.parents, fft_root_table: Some(fft_root_table), - instances: self.gate_instances, }; // The HashSet of gates will have a non-deterministic order. When converting to a Vec, we diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 81d8b09f..d1b03570 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -9,7 +9,7 @@ use crate::field::field_types::{Field, RichField}; use crate::fri::commitment::PolynomialBatchCommitment; use crate::fri::reduction_strategies::FriReductionStrategy; use crate::fri::{FriConfig, FriParams}; -use crate::gates::gate::{GateInstance, PrefixedGate}; +use crate::gates::gate::PrefixedGate; use crate::hash::hash_types::{HashOut, MerkleCapTarget}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::generator::WitnessGenerator; @@ -177,7 +177,6 @@ pub(crate) struct ProverOnlyCircuitData, const D: u pub representative_map: Vec, /// Pre-computed roots for faster FFT. pub fft_root_table: Option>, - pub instances: Vec>, } /// Circuit data required by the verifier, but not the prover. From e06ce5aa2f549ff90afa2fff3f9ac4b9e3929e69 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 22 Nov 2021 16:41:33 +0100 Subject: [PATCH 126/202] Fix proof compression test --- src/plonk/proof.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/plonk/proof.rs b/src/plonk/proof.rs index 815f807d..70e38588 100644 --- a/src/plonk/proof.rs +++ b/src/plonk/proof.rs @@ -310,6 +310,7 @@ mod tests { use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; use crate::fri::reduction_strategies::FriReductionStrategy; + use crate::gates::noop::NoopGate; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; @@ -336,6 +337,9 @@ mod tests { let zt = builder.constant(z); let comp_zt = builder.mul(xt, yt); builder.connect(zt, comp_zt); + for _ in 0..100 { + builder.add_gate(NoopGate, vec![]); + } let data = builder.build(); let proof = data.prove(pw)?; verify(proof.clone(), &data.verifier_only, &data.common)?; From b7cb7e234f7afda34bd801679ffbf2c1ef23d9eb Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 22 Nov 2021 17:06:40 +0100 Subject: [PATCH 127/202] Minor --- src/gadgets/interpolation.rs | 1 - src/gates/low_degree_interpolation.rs | 3 --- 2 files changed, 4 deletions(-) diff --git a/src/gadgets/interpolation.rs b/src/gadgets/interpolation.rs index 5d23b65b..9167e771 100644 --- a/src/gadgets/interpolation.rs +++ b/src/gadgets/interpolation.rs @@ -106,7 +106,6 @@ mod tests { use anyhow::Result; use crate::field::extension_field::quadratic::QuadraticExtension; - use crate::field::extension_field::quartic::QuarticExtension; use crate::field::extension_field::FieldExtension; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; diff --git a/src/gates/low_degree_interpolation.rs b/src/gates/low_degree_interpolation.rs index d38bd773..1a869941 100644 --- a/src/gates/low_degree_interpolation.rs +++ b/src/gates/low_degree_interpolation.rs @@ -427,12 +427,9 @@ impl, const D: usize> SimpleGenerator #[cfg(test)] mod tests { - use std::marker::PhantomData; - use anyhow::Result; use crate::field::extension_field::quadratic::QuadraticExtension; - use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; use crate::gadgets::interpolation::InterpolationGate; From 5ea632f2a866745440cd33323ace7f55e8fac020 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 22 Nov 2021 17:30:13 +0100 Subject: [PATCH 128/202] Fix size optimized test --- src/gates/gate_testing.rs | 3 --- src/gates/low_degree_interpolation.rs | 2 -- src/plonk/circuit_data.rs | 13 +++++++++++++ src/plonk/recursive_verifier.rs | 4 ++-- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/gates/gate_testing.rs b/src/gates/gate_testing.rs index 15c17ba7..9fe8e835 100644 --- a/src/gates/gate_testing.rs +++ b/src/gates/gate_testing.rs @@ -51,9 +51,7 @@ pub(crate) fn test_low_degree, G: Gate, const ); let expected_eval_degree = WITNESS_DEGREE * gate.degree(); - dbg!(WITNESS_DEGREE, gate.degree()); - dbg!(&constraint_eval_degrees); assert!( constraint_eval_degrees .iter() @@ -153,7 +151,6 @@ pub(crate) fn test_eval_fns, G: Gate, const D let evals_t = gate.eval_unfiltered_recursively(&mut builder, vars_t); pw.set_extension_targets(&evals_t, &evals); - dbg!(builder.num_gates()); let data = builder.build(); let proof = data.prove(pw)?; verify(proof, &data.verifier_only, &data.common) diff --git a/src/gates/low_degree_interpolation.rs b/src/gates/low_degree_interpolation.rs index 1a869941..c79a2a37 100644 --- a/src/gates/low_degree_interpolation.rs +++ b/src/gates/low_degree_interpolation.rs @@ -490,8 +490,6 @@ mod tests { let coeffs = PolynomialCoeffs::new(FF::rand_vec(1 << subgroup_bits)); let eval_point = FF::rand(); let gate = LowDegreeInterpolationGate::::new(subgroup_bits); - dbg!(gate.end_coeffs()); - dbg!(gate.powers_eval(15)); let vars = EvaluationVars { local_constants: &[], local_wires: &get_wires(&gate, shift, coeffs, eval_point), diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index d1b03570..41e89ea2 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -76,6 +76,19 @@ impl CircuitConfig { } } + pub fn size_optimized_recursion_config() -> Self { + Self { + security_bits: 93, + cap_height: 3, + fri_config: FriConfig { + reduction_strategy: FriReductionStrategy::ConstantArityBits(3, 5), + num_query_rounds: 26, + ..CircuitConfig::standard_recursion_config().fri_config + }, + ..CircuitConfig::standard_recursion_config() + } + } + pub fn standard_recursion_zk_config() -> Self { CircuitConfig { zero_knowledge: true, diff --git a/src/plonk/recursive_verifier.rs b/src/plonk/recursive_verifier.rs index 2dc0223b..98c866f8 100644 --- a/src/plonk/recursive_verifier.rs +++ b/src/plonk/recursive_verifier.rs @@ -408,7 +408,7 @@ mod tests { type F = GoldilocksField; const D: usize = 2; - let standard_config = CircuitConfig::standard_recursion_config(); + let standard_config = CircuitConfig::size_optimized_recursion_config(); // An initial dummy proof. let (proof, vd, cd) = dummy_proof::(&standard_config, 4_000)?; @@ -456,7 +456,7 @@ mod tests { num_routed_wires: 25, fri_config: FriConfig { proof_of_work_bits: 21, - reduction_strategy: FriReductionStrategy::MinSize(None), + reduction_strategy: FriReductionStrategy::MinSize(Some(3)), num_query_rounds: 9, }, ..high_rate_config From fa29db1dcb51ef711551d6837abf25b1e0b6389f Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 22 Nov 2021 17:54:58 +0100 Subject: [PATCH 129/202] Clean low-degree interpolation gate --- src/gates/low_degree_interpolation.rs | 159 ++++++++++++++------------ 1 file changed, 86 insertions(+), 73 deletions(-) diff --git a/src/gates/low_degree_interpolation.rs b/src/gates/low_degree_interpolation.rs index c79a2a37..245c6d17 100644 --- a/src/gates/low_degree_interpolation.rs +++ b/src/gates/low_degree_interpolation.rs @@ -26,38 +26,6 @@ pub(crate) struct LowDegreeInterpolationGate, const _phantom: PhantomData, } -impl, const D: usize> LowDegreeInterpolationGate { - pub fn powers_init(&self, i: usize) -> usize { - debug_assert!(0 < i && i < self.num_points()); - if i == 1 { - return self.wire_shift(); - } - self.end_coeffs() + i - 2 - } - - pub fn powers_eval(&self, i: usize) -> Range { - debug_assert!(0 < i && i < self.num_points()); - if i == 1 { - return self.wires_evaluation_point(); - } - let start = self.end_coeffs() + self.num_points() - 2 + (i - 2) * D; - start..start + D - } - - /// End of wire indices, exclusive. - fn end(&self) -> usize { - self.powers_eval(self.num_points() - 1).end - } - - /// The domain of the points we're interpolating. - fn coset(&self, shift: F) -> impl Iterator { - let g = F::primitive_root_of_unity(self.subgroup_bits); - let size = 1 << self.subgroup_bits; - // Speed matters here, so we avoid `cyclic_subgroup_coset_known_order` which allocates. - g.powers().take(size).map(move |x| x * shift) - } -} - impl, const D: usize> InterpolationGate for LowDegreeInterpolationGate { @@ -130,6 +98,40 @@ impl, const D: usize> InterpolationGate } } +impl, const D: usize> LowDegreeInterpolationGate { + /// `powers_shift(i)` is the wire index of `wire_shift^i`. + pub fn powers_shift(&self, i: usize) -> usize { + debug_assert!(0 < i && i < self.num_points()); + if i == 1 { + return self.wire_shift(); + } + self.end_coeffs() + i - 2 + } + + /// `powers_evalutation_point(i)` is the wire index of `evalutation_point^i`. + pub fn powers_evaluation_point(&self, i: usize) -> Range { + debug_assert!(0 < i && i < self.num_points()); + if i == 1 { + return self.wires_evaluation_point(); + } + let start = self.end_coeffs() + self.num_points() - 2 + (i - 2) * D; + start..start + D + } + + /// End of wire indices, exclusive. + fn end(&self) -> usize { + self.powers_evaluation_point(self.num_points() - 1).end + } + + /// The domain of the points we're interpolating. + fn coset(&self, shift: F) -> impl Iterator { + let g = F::primitive_root_of_unity(self.subgroup_bits); + let size = 1 << self.subgroup_bits; + // Speed matters here, so we avoid `cyclic_subgroup_coset_known_order` which allocates. + g.powers().take(size).map(move |x| x * shift) + } +} + impl, const D: usize> Gate for LowDegreeInterpolationGate { fn id(&self) -> String { format!("{:?}", self, D) @@ -142,33 +144,35 @@ impl, const D: usize> Gate for LowDegreeInter .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) .collect::>(); - let mut powers_init = (1..self.num_points()) - .map(|i| vars.local_wires[self.powers_init(i)]) + let mut powers_shift = (1..self.num_points()) + .map(|i| vars.local_wires[self.powers_shift(i)]) .collect::>(); - powers_init.insert(0, F::Extension::ONE); - let wire_shift = powers_init[1]; - for i in 2..self.num_points() { - constraints.push(powers_init[i - 1] * wire_shift - powers_init[i]); + let shift = powers_shift[0]; + for i in 1..self.num_points() - 1 { + constraints.push(powers_shift[i - 1] * shift - powers_shift[i]); } - let ocoeffs = coeffs + powers_shift.insert(0, F::Extension::ONE); + // `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient. + // Then, `altered(w^i) = original(shift*w^i)`. + let altered_coeffs = coeffs .iter() - .zip(powers_init) + .zip(powers_shift) .map(|(&c, p)| c.scalar_mul(p)) .collect::>(); let interpolant = PolynomialCoeffsAlgebra::new(coeffs); - let ointerpolant = PolynomialCoeffsAlgebra::new(ocoeffs); + let altered_interpolant = PolynomialCoeffsAlgebra::new(altered_coeffs); for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits) .into_iter() .enumerate() { let value = vars.get_local_ext_algebra(self.wires_value(i)); - let computed_value = ointerpolant.eval_base(point); + let computed_value = altered_interpolant.eval_base(point); constraints.extend(&(value - computed_value).to_basefield_array()); } let evaluation_point_powers = (1..self.num_points()) - .map(|i| vars.get_local_ext_algebra(self.powers_eval(i))) + .map(|i| vars.get_local_ext_algebra(self.powers_evaluation_point(i))) .collect::>(); let evaluation_point = evaluation_point_powers[0]; for i in 1..self.num_points() - 1 { @@ -190,33 +194,36 @@ impl, const D: usize> Gate for LowDegreeInter let coeffs = (0..self.num_points()) .map(|i| vars.get_local_ext(self.wires_coeff(i))) .collect::>(); - let mut powers_init = (1..self.num_points()) - .map(|i| vars.local_wires[self.powers_init(i)]) + + let mut powers_shift = (1..self.num_points()) + .map(|i| vars.local_wires[self.powers_shift(i)]) .collect::>(); - powers_init.insert(0, F::ONE); - let wire_shift = powers_init[1]; - for i in 2..self.num_points() { - constraints.push(powers_init[i - 1] * wire_shift - powers_init[i]); + let shift = powers_shift[0]; + for i in 1..self.num_points() - 1 { + constraints.push(powers_shift[i - 1] * shift - powers_shift[i]); } - let ocoeffs = coeffs + powers_shift.insert(0, F::ONE); + // `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient. + // Then, `altered(w^i) = original(shift*w^i)`. + let altered_coeffs = coeffs .iter() - .zip(powers_init) + .zip(powers_shift) .map(|(&c, p)| c.scalar_mul(p)) .collect::>(); let interpolant = PolynomialCoeffs::new(coeffs); - let ointerpolant = PolynomialCoeffs::new(ocoeffs); + let altered_interpolant = PolynomialCoeffs::new(altered_coeffs); for (i, point) in F::two_adic_subgroup(self.subgroup_bits) .into_iter() .enumerate() { let value = vars.get_local_ext(self.wires_value(i)); - let computed_value = ointerpolant.eval_base(point); + let computed_value = altered_interpolant.eval_base(point); constraints.extend(&(value - computed_value).to_basefield_array()); } let evaluation_point_powers = (1..self.num_points()) - .map(|i| vars.get_local_ext(self.powers_eval(i))) + .map(|i| vars.get_local_ext(self.powers_evaluation_point(i))) .collect::>(); let evaluation_point = evaluation_point_powers[0]; for i in 1..self.num_points() - 1 { @@ -242,25 +249,28 @@ impl, const D: usize> Gate for LowDegreeInter let coeffs = (0..self.num_points()) .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) .collect::>(); - let mut powers_init = (1..self.num_points()) - .map(|i| vars.local_wires[self.powers_init(i)]) + + let mut powers_shift = (1..self.num_points()) + .map(|i| vars.local_wires[self.powers_shift(i)]) .collect::>(); - powers_init.insert(0, builder.one_extension()); - let wire_shift = powers_init[1]; - for i in 2..self.num_points() { + let shift = powers_shift[0]; + for i in 1..self.num_points() - 1 { constraints.push(builder.mul_sub_extension( - powers_init[i - 1], - wire_shift, - powers_init[i], + powers_shift[i - 1], + shift, + powers_shift[i], )); } - let ocoeffs = coeffs + powers_shift.insert(0, builder.one_extension()); + // `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient. + // Then, `altered(w^i) = original(shift*w^i)`. + let altered_coeffs = coeffs .iter() - .zip(powers_init) + .zip(powers_shift) .map(|(&c, p)| builder.scalar_mul_ext_algebra(p, c)) .collect::>(); let interpolant = PolynomialCoeffsExtAlgebraTarget(coeffs); - let ointerpolant = PolynomialCoeffsExtAlgebraTarget(ocoeffs); + let altered_interpolant = PolynomialCoeffsExtAlgebraTarget(altered_coeffs); for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits) .into_iter() @@ -268,7 +278,7 @@ impl, const D: usize> Gate for LowDegreeInter { let value = vars.get_local_ext_algebra(self.wires_value(i)); let point = builder.constant_extension(point); - let computed_value = ointerpolant.eval_scalar(builder, point); + let computed_value = altered_interpolant.eval_scalar(builder, point); constraints.extend( &builder .sub_ext_algebra(value, computed_value) @@ -277,7 +287,7 @@ impl, const D: usize> Gate for LowDegreeInter } let evaluation_point_powers = (1..self.num_points()) - .map(|i| vars.get_local_ext_algebra(self.powers_eval(i))) + .map(|i| vars.get_local_ext_algebra(self.powers_evaluation_point(i))) .collect::>(); let evaluation_point = evaluation_point_powers[0]; for i in 1..self.num_points() - 1 { @@ -328,8 +338,6 @@ impl, const D: usize> Gate for LowDegreeInter } fn degree(&self) -> usize { - // The highest power of x is `num_points - 1`, and then multiplication by the coefficient - // adds 1. 2 } @@ -395,7 +403,7 @@ impl, const D: usize> SimpleGenerator .enumerate() .skip(2) { - out_buffer.set_wire(local_wire(self.gate.powers_init(i)), power); + out_buffer.set_wire(local_wire(self.gate.powers_shift(i)), power); } // Compute the interpolant. @@ -413,10 +421,15 @@ impl, const D: usize> SimpleGenerator } let evaluation_point = get_local_ext(self.gate.wires_evaluation_point()); - for i in 2..self.gate.num_points() { + for (i, power) in evaluation_point + .powers() + .take(self.gate.num_points()) + .enumerate() + .skip(2) + { out_buffer.set_extension_target( - ExtensionTarget::from_range(self.gate_index, self.gate.powers_eval(i)), - evaluation_point.exp_u64(i as u64), + ExtensionTarget::from_range(self.gate_index, self.gate.powers_evaluation_point(i)), + power, ); } let evaluation_value = interpolant.eval(evaluation_point); From 172fdd3d89ea32c9f3ee0b9dd4b91e7e8d753c92 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 22 Nov 2021 21:20:44 +0100 Subject: [PATCH 130/202] Comments --- src/field/extension_field/algebra.rs | 2 ++ src/fri/recursive_verifier.rs | 27 +++++++++++++++++++-------- src/gadgets/interpolation.rs | 3 +++ src/gadgets/polynomial.rs | 7 ++----- src/gates/interpolation.rs | 5 ++--- src/gates/low_degree_interpolation.rs | 7 +++---- src/plonk/prover.rs | 11 ----------- src/polynomial/polynomial.rs | 2 ++ 8 files changed, 33 insertions(+), 31 deletions(-) diff --git a/src/field/extension_field/algebra.rs b/src/field/extension_field/algebra.rs index 37f70f52..93d25de4 100644 --- a/src/field/extension_field/algebra.rs +++ b/src/field/extension_field/algebra.rs @@ -160,6 +160,7 @@ impl, const D: usize> PolynomialCoeffsAlgebra { .fold(ExtensionAlgebra::ZERO, |acc, &c| acc * x + c) } + /// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1. pub fn eval_with_powers(&self, powers: &[ExtensionAlgebra]) -> ExtensionAlgebra { debug_assert_eq!(self.coeffs.len(), powers.len() + 1); let acc = self.coeffs[0]; @@ -176,6 +177,7 @@ impl, const D: usize> PolynomialCoeffsAlgebra { .fold(ExtensionAlgebra::ZERO, |acc, &c| acc.scalar_mul(x) + c) } + /// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1. pub fn eval_base_with_powers(&self, powers: &[F]) -> ExtensionAlgebra { debug_assert_eq!(self.coeffs.len(), powers.len() + 1); let acc = self.coeffs[0]; diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index f543b12b..b31e01a8 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -46,7 +46,9 @@ impl, const D: usize> CircuitBuilder { let coset_start = self.mul(start, x); // The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta. - if 1 << arity_bits > common_data.quotient_degree_factor { + // `HighDegreeInterpolationGate` has degree `arity`, so we use the low-degree gate if + // the arity is too large. + if arity > common_data.quotient_degree_factor { self.interpolate_coset::>( arity_bits, coset_start, @@ -66,19 +68,28 @@ impl, const D: usize> CircuitBuilder { /// Make sure we have enough wires and routed wires to do the FRI checks efficiently. This check /// isn't required -- without it we'd get errors elsewhere in the stack -- but just gives more /// helpful errors. - fn check_recursion_config(&self, max_fri_arity_bits: usize) { + fn check_recursion_config( + &self, + max_fri_arity_bits: usize, + common_data: &CommonCircuitData, + ) { let random_access = RandomAccessGate::::new_from_config( &self.config, max_fri_arity_bits.max(self.config.cap_height), ); - let interpolation_gate = HighDegreeInterpolationGate::::new(max_fri_arity_bits); + let (interpolation_wires, interpolation_routed_wires) = + if 1 << max_fri_arity_bits > common_data.quotient_degree_factor { + let gate = LowDegreeInterpolationGate::::new(max_fri_arity_bits); + (gate.num_wires(), gate.num_routed_wires()) + } else { + let gate = HighDegreeInterpolationGate::::new(max_fri_arity_bits); + (gate.num_wires(), gate.num_routed_wires()) + }; - let min_wires = random_access - .num_wires() - .max(interpolation_gate.num_wires()); + let min_wires = random_access.num_wires().max(interpolation_wires); let min_routed_wires = random_access .num_routed_wires() - .max(interpolation_gate.num_routed_wires()); + .max(interpolation_routed_wires); assert!( self.config.num_wires >= min_wires, @@ -125,7 +136,7 @@ impl, const D: usize> CircuitBuilder { let config = &common_data.config; if let Some(max_arity_bits) = common_data.fri_params.max_arity_bits() { - self.check_recursion_config(max_arity_bits); + self.check_recursion_config(max_arity_bits, common_data); } debug_assert_eq!( diff --git a/src/gadgets/interpolation.rs b/src/gadgets/interpolation.rs index 9167e771..7f6f98f3 100644 --- a/src/gadgets/interpolation.rs +++ b/src/gadgets/interpolation.rs @@ -7,6 +7,9 @@ use crate::gates::gate::Gate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; +/// Trait for gates which interpolate a polynomial, whose points are a (base field) coset of the multiplicative subgroup +/// with the given size, and whose values are extension field elements, given by input wires. +/// Outputs the evaluation of the interpolant at a given (extension field) evaluation point. pub(crate) trait InterpolationGate, const D: usize>: Gate + Copy { diff --git a/src/gadgets/polynomial.rs b/src/gadgets/polynomial.rs index 48a5f8b7..ff7fffd7 100644 --- a/src/gadgets/polynomial.rs +++ b/src/gadgets/polynomial.rs @@ -64,6 +64,8 @@ impl PolynomialCoeffsExtAlgebraTarget { } acc } + + /// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1. pub fn eval_with_powers( &self, builder: &mut CircuitBuilder, @@ -78,10 +80,5 @@ impl PolynomialCoeffsExtAlgebraTarget { .iter() .zip(powers) .fold(acc, |acc, (&x, &c)| builder.mul_add_ext_algebra(c, x, acc)) - // let mut acc = builder.zero_ext_algebra(); - // for &c in self.0.iter().rev() { - // acc = builder.mul_add_ext_algebra(point, acc, c); - // } - // acc } } diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index 58288bba..f7934b71 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -17,9 +17,8 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use crate::polynomial::polynomial::PolynomialCoeffs; -/// Interpolates a polynomial, whose points are a (base field) coset of the multiplicative subgroup -/// with the given size, and whose values are extension field elements, given by input wires. -/// Outputs the evaluation of the interpolant at a given (extension field) evaluation point. +/// Interpolation gate with constraints of degree at most `1<, const D: usize> { pub subgroup_bits: usize, diff --git a/src/gates/low_degree_interpolation.rs b/src/gates/low_degree_interpolation.rs index 245c6d17..abbdb00e 100644 --- a/src/gates/low_degree_interpolation.rs +++ b/src/gates/low_degree_interpolation.rs @@ -17,9 +17,8 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use crate::polynomial::polynomial::PolynomialCoeffs; -/// Interpolates a polynomial, whose points are a (base field) coset of the multiplicative subgroup -/// with the given size, and whose values are extension field elements, given by input wires. -/// Outputs the evaluation of the interpolant at a given (extension field) evaluation point. +/// Interpolation gate with constraints of degree 2. +/// `eval_unfiltered_recursively` uses more gates than `HighDegreeInterpolationGate`. #[derive(Copy, Clone, Debug)] pub(crate) struct LowDegreeInterpolationGate, const D: usize> { pub subgroup_bits: usize, @@ -344,7 +343,7 @@ impl, const D: usize> Gate for LowDegreeInter fn num_constraints(&self) -> usize { // `num_points * D` constraints to check for consistency between the coefficients and the // point-value pairs, plus `D` constraints for the evaluation value, plus `(D+1)*(num_points-2)` - // to check power constraints for evaluation point and wire shift. + // to check power constraints for evaluation point and shift. self.num_points() * D + D + (D + 1) * (self.num_points() - 2) } } diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index b5175f02..3f8e607d 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -66,17 +66,6 @@ pub(crate) fn prove, const D: usize>( .collect() ); - // let rows = (0..degree) - // .map(|i| wires_values.iter().map(|w| w.values[i]).collect::>()) - // .collect::>(); - // for (i, r) in rows.iter().enumerate() { - // let c = rows.iter().filter(|&x| x == r).count(); - // let s = prover_data.instances[i].gate_ref.0.id(); - // if c > 1 && !s.starts_with("Noop") { - // println!("{} {} {}", prover_data.instances[i].gate_ref.0.id(), i, c); - // } - // } - let wires_commitment = timed!( timing, "compute wires commitment", diff --git a/src/polynomial/polynomial.rs b/src/polynomial/polynomial.rs index 3c732208..b5a7fabd 100644 --- a/src/polynomial/polynomial.rs +++ b/src/polynomial/polynomial.rs @@ -120,6 +120,7 @@ impl PolynomialCoeffs { .fold(F::ZERO, |acc, &c| acc * x + c) } + /// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1. pub fn eval_with_powers(&self, powers: &[F]) -> F { debug_assert_eq!(self.coeffs.len(), powers.len() + 1); let acc = self.coeffs[0]; @@ -139,6 +140,7 @@ impl PolynomialCoeffs { .fold(F::ZERO, |acc, &c| acc.scalar_mul(x) + c) } + /// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1. pub fn eval_base_with_powers(&self, powers: &[F::BaseField]) -> F where F: FieldExtension, From 15b41ea8fb11716ab9f2268aa37ce39974bac530 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 22 Nov 2021 22:13:24 +0100 Subject: [PATCH 131/202] PR feedback --- src/gadgets/arithmetic_extension.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index cc82be68..d81943ab 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -46,17 +46,17 @@ impl, const D: usize> CircuitBuilder { let result = if self.target_as_constant_ext(addend) == Some(F::Extension::ZERO) { // If the addend is zero, we use a multiplication gate. - self.add_mul_extension_operation(operation) + self.compute_mul_extension_operation(operation) } else { // Otherwise, we use an arithmetic gate. - self.add_arithmetic_extension_operation(operation) + self.compute_arithmetic_extension_operation(operation) }; // Otherwise, we must actually perform the operation using an ArithmeticExtensionGate slot. self.arithmetic_results.insert(operation, result); result } - fn add_arithmetic_extension_operation( + fn compute_arithmetic_extension_operation( &mut self, operation: ExtensionArithmeticOperation, ) -> ExtensionTarget { @@ -79,7 +79,7 @@ impl, const D: usize> CircuitBuilder { ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_ith_output(i)) } - fn add_mul_extension_operation( + fn compute_mul_extension_operation( &mut self, operation: ExtensionArithmeticOperation, ) -> ExtensionTarget { From 9cafe9773178820b2abe58a2d7f526aa7ca6bb6d Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 22 Nov 2021 22:30:32 +0100 Subject: [PATCH 132/202] Remove specific impls of `InterpolationGate` --- src/gates/interpolation.rs | 85 +++++---------------------- src/gates/low_degree_interpolation.rs | 59 +------------------ 2 files changed, 17 insertions(+), 127 deletions(-) diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index f7934b71..0225ce59 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -25,6 +25,21 @@ pub(crate) struct HighDegreeInterpolationGate, cons _phantom: PhantomData, } +impl, const D: usize> InterpolationGate + for HighDegreeInterpolationGate +{ + fn new(subgroup_bits: usize) -> Self { + Self { + subgroup_bits, + _phantom: PhantomData, + } + } + + fn num_points(&self) -> usize { + 1 << self.subgroup_bits + } +} + impl, const D: usize> HighDegreeInterpolationGate { /// End of wire indices, exclusive. fn end(&self) -> usize { @@ -64,74 +79,6 @@ impl, const D: usize> HighDegreeInterpolationGate, const D: usize> InterpolationGate - for HighDegreeInterpolationGate -{ - fn new(subgroup_bits: usize) -> Self { - Self { - subgroup_bits, - _phantom: PhantomData, - } - } - - fn num_points(&self) -> usize { - 1 << self.subgroup_bits - } - - /// Wire index of the coset shift. - fn wire_shift(&self) -> usize { - 0 - } - - fn start_values(&self) -> usize { - 1 - } - - /// Wire indices of the `i`th interpolant value. - fn wires_value(&self, i: usize) -> Range { - debug_assert!(i < self.num_points()); - let start = self.start_values() + i * D; - start..start + D - } - - fn start_evaluation_point(&self) -> usize { - self.start_values() + self.num_points() * D - } - - /// Wire indices of the point to evaluate the interpolant at. - fn wires_evaluation_point(&self) -> Range { - let start = self.start_evaluation_point(); - start..start + D - } - - fn start_evaluation_value(&self) -> usize { - self.start_evaluation_point() + D - } - - /// Wire indices of the interpolated value. - fn wires_evaluation_value(&self) -> Range { - let start = self.start_evaluation_value(); - start..start + D - } - - fn start_coeffs(&self) -> usize { - self.start_evaluation_value() + D - } - - /// The number of routed wires required in the typical usage of this gate, where the points to - /// interpolate, the evaluation point, and the corresponding value are all routed. - fn num_routed_wires(&self) -> usize { - self.start_coeffs() - } - - /// Wire indices of the interpolant's `i`th coefficient. - fn wires_coeff(&self, i: usize) -> Range { - debug_assert!(i < self.num_points()); - let start = self.start_coeffs() + i * D; - start..start + D - } -} - impl, const D: usize> Gate for HighDegreeInterpolationGate { @@ -227,7 +174,7 @@ impl, const D: usize> Gate ) -> Vec>> { let gen = InterpolationGenerator:: { gate_index, - gate: self.clone(), + gate: *self, _phantom: PhantomData, }; vec![Box::new(gen.adapter())] diff --git a/src/gates/low_degree_interpolation.rs b/src/gates/low_degree_interpolation.rs index abbdb00e..1c2b41e4 100644 --- a/src/gates/low_degree_interpolation.rs +++ b/src/gates/low_degree_interpolation.rs @@ -38,63 +38,6 @@ impl, const D: usize> InterpolationGate fn num_points(&self) -> usize { 1 << self.subgroup_bits } - - /// Wire index of the coset shift. - fn wire_shift(&self) -> usize { - 0 - } - - fn start_values(&self) -> usize { - 1 - } - - /// Wire indices of the `i`th interpolant value. - fn wires_value(&self, i: usize) -> Range { - debug_assert!(i < self.num_points()); - let start = self.start_values() + i * D; - start..start + D - } - - fn start_evaluation_point(&self) -> usize { - self.start_values() + self.num_points() * D - } - - /// Wire indices of the point to evaluate the interpolant at. - fn wires_evaluation_point(&self) -> Range { - let start = self.start_evaluation_point(); - start..start + D - } - - fn start_evaluation_value(&self) -> usize { - self.start_evaluation_point() + D - } - - /// Wire indices of the interpolated value. - fn wires_evaluation_value(&self) -> Range { - let start = self.start_evaluation_value(); - start..start + D - } - - fn start_coeffs(&self) -> usize { - self.start_evaluation_value() + D - } - - /// The number of routed wires required in the typical usage of this gate, where the points to - /// interpolate, the evaluation point, and the corresponding value are all routed. - fn num_routed_wires(&self) -> usize { - self.start_coeffs() - } - - /// Wire indices of the interpolant's `i`th coefficient. - fn wires_coeff(&self, i: usize) -> Range { - debug_assert!(i < self.num_points()); - let start = self.start_coeffs() + i * D; - start..start + D - } - - fn end_coeffs(&self) -> usize { - self.start_coeffs() + D * self.num_points() - } } impl, const D: usize> LowDegreeInterpolationGate { @@ -322,7 +265,7 @@ impl, const D: usize> Gate for LowDegreeInter ) -> Vec>> { let gen = InterpolationGenerator:: { gate_index, - gate: self.clone(), + gate: *self, _phantom: PhantomData, }; vec![Box::new(gen.adapter())] From 3235a21d2b0451bbdf361d3e7df9f7ddf30de8bc Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 22 Nov 2021 22:38:37 +0100 Subject: [PATCH 133/202] 2^12 shrinking recursion with 100 bits of security --- src/plonk/circuit_data.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 41e89ea2..bf8024df 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -63,15 +63,15 @@ impl CircuitConfig { num_routed_wires: 80, constant_gate_size: 5, use_base_arithmetic_gate: true, - security_bits: 96, + security_bits: 100, rate_bits: 3, num_challenges: 2, zero_knowledge: false, cap_height: 4, fri_config: FriConfig { - proof_of_work_bits: 15, + proof_of_work_bits: 16, reduction_strategy: FriReductionStrategy::ConstantArityBits(4, 5), - num_query_rounds: 27, + num_query_rounds: 28, }, } } From 549ce0d8e9f0a26e7ce810ba0620079a421c3177 Mon Sep 17 00:00:00 2001 From: Jakub Nabaglo Date: Tue, 23 Nov 2021 21:36:12 -0800 Subject: [PATCH 134/202] Interleaved batched multiplicative inverse (#371) * Interleaved batched multiplicative inverse * Minor: typo --- benches/field_arithmetic.rs | 60 ++++++++++++++++++++++++ src/field/field_testing.rs | 15 +++--- src/field/field_types.rs | 93 ++++++++++++++++++++++++++++++------- 3 files changed, 144 insertions(+), 24 deletions(-) diff --git a/benches/field_arithmetic.rs b/benches/field_arithmetic.rs index ebecb871..2fb4a24b 100644 --- a/benches/field_arithmetic.rs +++ b/benches/field_arithmetic.rs @@ -112,6 +112,66 @@ pub(crate) fn bench_field(c: &mut Criterion) { c.bench_function(&format!("try_inverse<{}>", type_name::()), |b| { b.iter_batched(|| F::rand(), |x| x.try_inverse(), BatchSize::SmallInput) }); + + c.bench_function( + &format!("batch_multiplicative_inverse-tiny<{}>", type_name::()), + |b| { + b.iter_batched( + || (0..2).into_iter().map(|_| F::rand()).collect::>(), + |x| F::batch_multiplicative_inverse(&x), + BatchSize::SmallInput, + ) + }, + ); + + c.bench_function( + &format!("batch_multiplicative_inverse-small<{}>", type_name::()), + |b| { + b.iter_batched( + || (0..4).into_iter().map(|_| F::rand()).collect::>(), + |x| F::batch_multiplicative_inverse(&x), + BatchSize::SmallInput, + ) + }, + ); + + c.bench_function( + &format!("batch_multiplicative_inverse-medium<{}>", type_name::()), + |b| { + b.iter_batched( + || (0..16).into_iter().map(|_| F::rand()).collect::>(), + |x| F::batch_multiplicative_inverse(&x), + BatchSize::SmallInput, + ) + }, + ); + + c.bench_function( + &format!("batch_multiplicative_inverse-large<{}>", type_name::()), + |b| { + b.iter_batched( + || (0..256).into_iter().map(|_| F::rand()).collect::>(), + |x| F::batch_multiplicative_inverse(&x), + BatchSize::LargeInput, + ) + }, + ); + + c.bench_function( + &format!("batch_multiplicative_inverse-huge<{}>", type_name::()), + |b| { + b.iter_batched( + || { + (0..65536) + .into_iter() + .map(|_| F::rand()) + .collect::>() + }, + |x| F::batch_multiplicative_inverse(&x), + BatchSize::LargeInput, + ) + }, + ); } fn criterion_benchmark(c: &mut Criterion) { diff --git a/src/field/field_testing.rs b/src/field/field_testing.rs index f422d810..d15b712a 100644 --- a/src/field/field_testing.rs +++ b/src/field/field_testing.rs @@ -13,12 +13,15 @@ macro_rules! test_field_arithmetic { #[test] fn batch_inversion() { - let xs = (1..=3) - .map(|i| <$field>::from_canonical_u64(i)) - .collect::>(); - let invs = <$field>::batch_multiplicative_inverse(&xs); - for (x, inv) in xs.into_iter().zip(invs) { - assert_eq!(x * inv, <$field>::ONE); + for n in 0..20 { + let xs = (1..=n as u64) + .map(|i| <$field>::from_canonical_u64(i)) + .collect::>(); + let invs = <$field>::batch_multiplicative_inverse(&xs); + assert_eq!(invs.len(), n); + for (x, inv) in xs.into_iter().zip(invs) { + assert_eq!(x * inv, <$field>::ONE); + } } } diff --git a/src/field/field_types.rs b/src/field/field_types.rs index bbb1604e..45839bd9 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -102,34 +102,91 @@ pub trait Field: // This is Montgomery's trick. At a high level, we invert the product of the given field // elements, then derive the individual inverses from that via multiplication. + // The usual Montgomery trick involves calculating an array of cumulative products, + // resulting in a long dependency chain. To increase instruction-level parallelism, we + // compute WIDTH separate cumulative product arrays that only meet at the end. + + // Higher WIDTH increases instruction-level parallelism, but too high a value will cause us + // to run out of registers. + const WIDTH: usize = 4; + // JN note: WIDTH is 4. The code is specialized to this value and will need + // modification if it is changed. I tried to make it more generic, but Rust's const + // generics are not yet good enough. + + // Handle special cases. Paradoxically, below is repetitive but concise. + // The branches should be very predictable. let n = x.len(); if n == 0 { return Vec::new(); - } - if n == 1 { + } else if n == 1 { return vec![x[0].inverse()]; + } else if n == 2 { + let x01 = x[0] * x[1]; + let x01inv = x01.inverse(); + return vec![x01inv * x[1], x01inv * x[0]]; + } else if n == 3 { + let x01 = x[0] * x[1]; + let x012 = x01 * x[2]; + let x012inv = x012.inverse(); + let x01inv = x012inv * x[2]; + return vec![x01inv * x[1], x01inv * x[0], x012inv * x01]; } + debug_assert!(n >= WIDTH); - // Fill buf with cumulative product of x. - let mut buf = Vec::with_capacity(n); - let mut cumul_prod = x[0]; - buf.push(cumul_prod); - for i in 1..n { - cumul_prod *= x[i]; - buf.push(cumul_prod); + // Buf is reused for a few things to save allocations. + // Fill buf with cumulative product of x, only taking every 4th value. Concretely, buf will + // be [ + // x[0], x[1], x[2], x[3], + // x[0] * x[4], x[1] * x[5], x[2] * x[6], x[3] * x[7], + // x[0] * x[4] * x[8], x[1] * x[5] * x[9], x[2] * x[6] * x[10], x[3] * x[7] * x[11], + // ... + // ]. + // If n is not a multiple of WIDTH, the result is truncated from the end. For example, + // for n == 5, we get [x[0], x[1], x[2], x[3], x[0] * x[4]]. + let mut buf: Vec = Vec::with_capacity(n); + // cumul_prod holds the last WIDTH elements of buf. This is redundant, but it's how we + // convince LLVM to keep the values in the registers. + let mut cumul_prod: [Self; WIDTH] = x[..WIDTH].try_into().unwrap(); + buf.extend(cumul_prod); + for (i, &xi) in x[WIDTH..].iter().enumerate() { + cumul_prod[i % WIDTH] *= xi; + buf.push(cumul_prod[i % WIDTH]); } + debug_assert_eq!(buf.len(), n); - // At this stage buf contains the the cumulative product of x. We reuse the buffer for - // efficiency. At the end of the loop, it is filled with inverses of x. - let mut a_inv = cumul_prod.inverse(); - buf[n - 1] = buf[n - 2] * a_inv; - for i in (1..n - 1).rev() { - a_inv = x[i + 1] * a_inv; - // buf[i - 1] has not been written to by this loop, so it equals x[0] * ... x[n - 1]. - buf[i] = buf[i - 1] * a_inv; + let mut a_inv = { + // This is where the four dependency chains meet. + // Take the last four elements of buf and invert them all. + let c01 = cumul_prod[0] * cumul_prod[1]; + let c23 = cumul_prod[2] * cumul_prod[3]; + let c0123 = c01 * c23; + let c0123inv = c0123.inverse(); + let c01inv = c0123inv * c23; + let c23inv = c0123inv * c01; + [ + c01inv * cumul_prod[1], + c01inv * cumul_prod[0], + c23inv * cumul_prod[3], + c23inv * cumul_prod[2], + ] + }; + + for i in (WIDTH..n).rev() { + // buf[i - WIDTH] has not been written to by this loop, so it equals + // x[i % WIDTH] * x[i % WIDTH + WIDTH] * ... * x[i - WIDTH]. + buf[i] = buf[i - WIDTH] * a_inv[i % WIDTH]; // buf[i] now holds the inverse of x[i]. + a_inv[i % WIDTH] *= x[i]; } - buf[0] = x[1] * a_inv; + for i in (0..WIDTH).rev() { + buf[i] = a_inv[i]; + } + + for (&bi, &xi) in buf.iter().zip(x) { + // Sanity check only. + debug_assert_eq!(bi * xi, Self::ONE); + } + buf } From 2c06309cf7153d0d3db2ca2e06e96ac11c6359db Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 30 Nov 2021 17:12:13 +0100 Subject: [PATCH 135/202] Fix all clippy lints --- src/bin/generate_constants.rs | 2 ++ src/field/extension_field/target.rs | 1 + src/field/fft.rs | 18 +++++------ src/field/field_testing.rs | 1 + src/field/field_types.rs | 4 ++- src/field/prime_field_testing.rs | 2 +- src/fri/commitment.rs | 2 +- src/fri/recursive_verifier.rs | 2 +- src/gadgets/biguint.rs | 10 +++---- src/gadgets/permutation.rs | 2 +- src/gadgets/sorting.rs | 2 +- src/gates/arithmetic_u32.rs | 30 +++++++------------ src/gates/assert_le.rs | 10 +++---- src/gates/comparison.rs | 24 +++++++-------- src/gates/exponentiation.rs | 5 ++-- src/gates/insertion.rs | 10 +++---- src/gates/interpolation.rs | 2 +- src/gates/poseidon.rs | 8 ++--- src/gates/random_access.rs | 3 +- src/gates/subtraction_u32.rs | 16 ++++------ .../arch/aarch64/poseidon_goldilocks_neon.rs | 13 ++++---- src/hash/hashing.rs | 4 +-- src/iop/target.rs | 1 + src/lib.rs | 5 ++++ src/plonk/circuit_builder.rs | 4 +-- src/plonk/proof.rs | 6 +--- src/plonk/recursive_verifier.rs | 10 +++---- src/polynomial/polynomial.rs | 2 +- src/util/partial_products.rs | 4 +-- src/util/timing.rs | 2 +- 30 files changed, 92 insertions(+), 113 deletions(-) diff --git a/src/bin/generate_constants.rs b/src/bin/generate_constants.rs index 60028741..89630fc7 100644 --- a/src/bin/generate_constants.rs +++ b/src/bin/generate_constants.rs @@ -1,5 +1,7 @@ //! Generates random constants using ChaCha20, seeded with zero. +#![allow(clippy::needless_range_loop)] + use plonky2::field::field_types::PrimeField; use plonky2::field::goldilocks_field::GoldilocksField; use rand::{Rng, SeedableRng}; diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index 3f9b6684..97e926b4 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -35,6 +35,7 @@ impl ExtensionTarget { let arr = self.to_target_array(); let k = (F::order() - 1u32) / (D as u64); let z0 = F::Extension::W.exp_biguint(&(k * count as u64)); + #[allow(clippy::needless_collect)] let zs = z0 .powers() .take(D) diff --git a/src/field/fft.rs b/src/field/fft.rs index 17c29184..89b6844d 100644 --- a/src/field/fft.rs +++ b/src/field/fft.rs @@ -38,12 +38,12 @@ fn fft_dispatch( zero_factor: Option, root_table: Option<&FftRootTable>, ) -> Vec { - let computed_root_table = if let Some(_) = root_table { + let computed_root_table = if root_table.is_some() { None } else { Some(fft_root_table(input.len())) }; - let used_root_table = root_table.or(computed_root_table.as_ref()).unwrap(); + let used_root_table = root_table.or_else(|| computed_root_table.as_ref()).unwrap(); fft_classic(input, zero_factor.unwrap_or(0), used_root_table) } @@ -122,8 +122,8 @@ fn fft_classic_simd( // Set omega to root_table[lg_half_m][0..half_m] but repeated. let mut omega_vec = P::zero().to_vec(); - for j in 0..omega_vec.len() { - omega_vec[j] = root_table[lg_half_m][j % half_m]; + for (j, omega) in omega_vec.iter_mut().enumerate() { + *omega = root_table[lg_half_m][j % half_m]; } let omega = P::from_slice(&omega_vec[..]); @@ -201,9 +201,9 @@ pub(crate) fn fft_classic(input: &[F], r: usize, root_table: &FftRootT if lg_n <= lg_packed_width { // Need the slice to be at least the width of two packed vectors for the vectorized version // to work. Do this tiny problem in scalar. - fft_classic_simd::>(&mut values[..], r, lg_n, &root_table); + fft_classic_simd::>(&mut values[..], r, lg_n, root_table); } else { - fft_classic_simd::<::PackedType>(&mut values[..], r, lg_n, &root_table); + fft_classic_simd::<::PackedType>(&mut values[..], r, lg_n, root_table); } values } @@ -267,7 +267,7 @@ mod tests { let values = subgroup .into_iter() - .map(|x| evaluate_at_naive(&coefficients, x)) + .map(|x| evaluate_at_naive(coefficients, x)) .collect(); PolynomialValues::new(values) } @@ -276,8 +276,8 @@ mod tests { let mut sum = F::ZERO; let mut point_power = F::ONE; for &c in &coefficients.coeffs { - sum = sum + c * point_power; - point_power = point_power * point; + sum += c * point_power; + point_power *= point; } sum } diff --git a/src/field/field_testing.rs b/src/field/field_testing.rs index d15b712a..54718f4a 100644 --- a/src/field/field_testing.rs +++ b/src/field/field_testing.rs @@ -1,3 +1,4 @@ +#![allow(clippy::eq_op)] use crate::field::extension_field::Extendable; use crate::field::extension_field::Frobenius; use crate::field::field_types::Field; diff --git a/src/field/field_types.rs b/src/field/field_types.rs index 45839bd9..036793cc 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -335,7 +335,7 @@ pub trait Field: } fn kth_root_u64(&self, k: u64) -> Self { - let p = Self::order().clone(); + let p = Self::order(); let p_minus_1 = &p - 1u32; debug_assert!( Self::is_monomial_permutation_u64(k), @@ -422,6 +422,7 @@ pub trait PrimeField: Field { unsafe { self.sub_canonical_u64(1) } } + /// # Safety /// Equivalent to *self + Self::from_canonical_u64(rhs), but may be cheaper. The caller must /// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this /// precondition is not met. It is marked unsafe for this reason. @@ -431,6 +432,7 @@ pub trait PrimeField: Field { *self + Self::from_canonical_u64(rhs) } + /// # Safety /// Equivalent to *self - Self::from_canonical_u64(rhs), but may be cheaper. The caller must /// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this /// precondition is not met. It is marked unsafe for this reason. diff --git a/src/field/prime_field_testing.rs b/src/field/prime_field_testing.rs index 4febc3a8..9dae4896 100644 --- a/src/field/prime_field_testing.rs +++ b/src/field/prime_field_testing.rs @@ -24,7 +24,7 @@ where ExpectedOp: Fn(u64) -> u64, { let inputs = test_inputs(F::ORDER); - let expected: Vec<_> = inputs.iter().map(|x| expected_op(x.clone())).collect(); + let expected: Vec<_> = inputs.iter().map(|&x| expected_op(x)).collect(); let output: Vec<_> = inputs .iter() .cloned() diff --git a/src/fri/commitment.rs b/src/fri/commitment.rs index c8a13cac..ee2e66eb 100644 --- a/src/fri/commitment.rs +++ b/src/fri/commitment.rs @@ -216,7 +216,7 @@ impl PolynomialBatchCommitment { lde_final_poly, lde_final_values, challenger, - &common_data, + common_data, timing, ); diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index b31e01a8..91d0580a 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -407,7 +407,7 @@ impl, const D: usize> CircuitBuilder { arity_bits, evals, betas[i], - &common_data + common_data ) ); diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index fff97a6e..fa5aa0b6 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -142,9 +142,9 @@ impl, const D: usize> CircuitBuilder { let mut combined_limbs = vec![]; let mut carry = self.zero_u32(); - for i in 0..total_limbs { - to_add[i].push(carry); - let (new_result, new_carry) = self.add_many_u32(&to_add[i].clone()); + for summands in &mut to_add { + summands.push(carry); + let (new_result, new_carry) = self.add_many_u32(summands); combined_limbs.push(new_result); carry = new_carry; } @@ -172,9 +172,9 @@ impl, const D: usize> CircuitBuilder { _phantom: PhantomData, }); - let div_b = self.mul_biguint(&div, &b); + let div_b = self.mul_biguint(&div, b); let div_b_plus_rem = self.add_biguint(&div_b, &rem); - self.connect_biguint(&a, &div_b_plus_rem); + self.connect_biguint(a, &div_b_plus_rem); let cmp_rem_b = self.cmp_biguint(&rem, b); self.assert_one(cmp_rem_b.target); diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs index 37169514..2d8c4f22 100644 --- a/src/gadgets/permutation.rs +++ b/src/gadgets/permutation.rs @@ -378,7 +378,7 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let lst: Vec = (0..size * 2).map(|n| F::from_canonical_usize(n)).collect(); + let lst: Vec = (0..size * 2).map(F::from_canonical_usize).collect(); let a: Vec> = lst[..] .chunks(2) .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 2c52db23..2059a888 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -224,7 +224,7 @@ mod tests { izip!(is_write_vals, address_vals, timestamp_vals, value_vals) .zip(combined_vals_u64) .collect::>(); - input_ops_and_keys.sort_by_key(|(_, val)| val.clone()); + input_ops_and_keys.sort_by_key(|(_, val)| *val); let input_ops_sorted: Vec<_> = input_ops_and_keys.iter().map(|(x, _)| x).collect(); let output_ops = diff --git a/src/gates/arithmetic_u32.rs b/src/gates/arithmetic_u32.rs index 2bbbda6e..a5a63047 100644 --- a/src/gates/arithmetic_u32.rs +++ b/src/gates/arithmetic_u32.rs @@ -261,17 +261,11 @@ impl, const D: usize> SimpleGenerator fn dependencies(&self) -> Vec { let local_target = |input| Target::wire(self.gate_index, input); - let mut deps = Vec::with_capacity(3); - deps.push(local_target( - U32ArithmeticGate::::wire_ith_multiplicand_0(self.i), - )); - deps.push(local_target( - U32ArithmeticGate::::wire_ith_multiplicand_1(self.i), - )); - deps.push(local_target(U32ArithmeticGate::::wire_ith_addend( - self.i, - ))); - deps + vec![ + local_target(U32ArithmeticGate::::wire_ith_multiplicand_0(self.i)), + local_target(U32ArithmeticGate::::wire_ith_multiplicand_1(self.i)), + local_target(U32ArithmeticGate::::wire_ith_addend(self.i)), + ] } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { @@ -307,23 +301,19 @@ impl, const D: usize> SimpleGenerator let num_limbs = U32ArithmeticGate::::num_limbs(); let limb_base = 1 << U32ArithmeticGate::::limb_bits(); - let output_limbs_u64: Vec<_> = unfold((), move |_| { + let output_limbs_u64 = unfold((), move |_| { let ret = output_u64 % limb_base; output_u64 /= limb_base; Some(ret) }) - .take(num_limbs) - .collect(); - let output_limbs_f: Vec<_> = output_limbs_u64 - .into_iter() - .map(F::from_canonical_u64) - .collect(); + .take(num_limbs); + let output_limbs_f = output_limbs_u64.map(F::from_canonical_u64); - for j in 0..num_limbs { + for (j, output_limb) in output_limbs_f.enumerate() { let wire = local_wire(U32ArithmeticGate::::wire_ith_output_jth_limb( self.i, j, )); - out_buffer.set_wire(wire, output_limbs_f[j]); + out_buffer.set_wire(wire, output_limb); } } } diff --git a/src/gates/assert_le.rs b/src/gates/assert_le.rs index ffbc043a..4da3c44b 100644 --- a/src/gates/assert_le.rs +++ b/src/gates/assert_le.rs @@ -340,10 +340,10 @@ impl, const D: usize> SimpleGenerator fn dependencies(&self) -> Vec { let local_target = |input| Target::wire(self.gate_index, input); - let mut deps = Vec::new(); - deps.push(local_target(self.gate.wire_first_input())); - deps.push(local_target(self.gate.wire_second_input())); - deps + vec![ + local_target(self.gate.wire_first_input()), + local_target(self.gate.wire_second_input()), + ] } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { @@ -555,7 +555,7 @@ mod tests { }; let mut rng = rand::thread_rng(); - let max: u64 = 1 << num_bits - 1; + let max: u64 = 1 << (num_bits - 1); let first_input_u64 = rng.gen_range(0..max); let second_input_u64 = { let mut val = rng.gen_range(0..max); diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index a610c5e2..614a759e 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -151,10 +151,8 @@ impl, const D: usize> Gate for ComparisonGate .collect(); // Range-check the bits. - for i in 0..most_significant_diff_bits.len() { - constraints.push( - most_significant_diff_bits[i] * (F::Extension::ONE - most_significant_diff_bits[i]), - ); + for &bit in &most_significant_diff_bits { + constraints.push(bit * (F::Extension::ONE - bit)); } let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::Extension::TWO); @@ -232,9 +230,8 @@ impl, const D: usize> Gate for ComparisonGate .collect(); // Range-check the bits. - for i in 0..most_significant_diff_bits.len() { - constraints - .push(most_significant_diff_bits[i] * (F::ONE - most_significant_diff_bits[i])); + for &bit in &most_significant_diff_bits { + constraints.push(bit * (F::ONE - bit)); } let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO); @@ -324,8 +321,7 @@ impl, const D: usize> Gate for ComparisonGate .collect(); // Range-check the bits. - for i in 0..most_significant_diff_bits.len() { - let this_bit = most_significant_diff_bits[i]; + for &this_bit in &most_significant_diff_bits { let inverse = builder.sub_extension(one, this_bit); constraints.push(builder.mul_extension(this_bit, inverse)); } @@ -388,10 +384,10 @@ impl, const D: usize> SimpleGenerator fn dependencies(&self) -> Vec { let local_target = |input| Target::wire(self.gate_index, input); - let mut deps = Vec::new(); - deps.push(local_target(self.gate.wire_first_input())); - deps.push(local_target(self.gate.wire_second_input())); - deps + vec![ + local_target(self.gate.wire_first_input()), + local_target(self.gate.wire_second_input()), + ] } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { @@ -638,7 +634,7 @@ mod tests { }; let mut rng = rand::thread_rng(); - let max: u64 = 1 << num_bits - 1; + let max: u64 = 1 << (num_bits - 1); let first_input_u64 = rng.gen_range(0..max); let second_input_u64 = { let mut val = rng.gen_range(0..max); diff --git a/src/gates/exponentiation.rs b/src/gates/exponentiation.rs index 5087cebd..47e6cce5 100644 --- a/src/gates/exponentiation.rs +++ b/src/gates/exponentiation.rs @@ -335,9 +335,8 @@ mod tests { .map(|b| F::from_canonical_u64(*b)) .collect(); - let mut v = Vec::new(); - v.push(base); - v.extend(power_bits_f.clone()); + let mut v = vec![base]; + v.extend(power_bits_f); let mut intermediate_values = Vec::new(); let mut current_intermediate_value = F::ONE; diff --git a/src/gates/insertion.rs b/src/gates/insertion.rs index c55f53a9..3d2c1aa3 100644 --- a/src/gates/insertion.rs +++ b/src/gates/insertion.rs @@ -251,8 +251,7 @@ impl, const D: usize> SimpleGenerator for Insert let local_targets = |inputs: Range| inputs.map(local_target); - let mut deps = Vec::new(); - deps.push(local_target(self.gate.wires_insertion_index())); + let mut deps = vec![local_target(self.gate.wires_insertion_index())]; deps.extend(local_targets(self.gate.wires_element_to_insert())); for i in 0..self.gate.vec_size { deps.extend(local_targets(self.gate.wires_original_list_item(i))); @@ -291,7 +290,7 @@ impl, const D: usize> SimpleGenerator for Insert vec_size ); - let mut new_vec = orig_vec.clone(); + let mut new_vec = orig_vec; new_vec.insert(insertion_index, to_insert); let mut equality_dummy_vals = Vec::new(); @@ -372,14 +371,13 @@ mod tests { fn get_wires(orig_vec: Vec, insertion_index: usize, element_to_insert: FF) -> Vec { let vec_size = orig_vec.len(); - let mut v = Vec::new(); - v.push(F::from_canonical_usize(insertion_index)); + let mut v = vec![F::from_canonical_usize(insertion_index)]; v.extend(element_to_insert.0); for j in 0..vec_size { v.extend(orig_vec[j].0); } - let mut new_vec = orig_vec.clone(); + let mut new_vec = orig_vec; new_vec.insert(insertion_index, element_to_insert); let mut equality_dummy_vals = Vec::new(); for i in 0..=vec_size { diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index 0225ce59..3ebf7259 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -72,7 +72,7 @@ impl, const D: usize> HighDegreeInterpolationGate>::partial_first_constant_layer(&mut state); - state = >::mds_partial_layer_init(&mut state); + state = >::mds_partial_layer_init(&state); for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; constraints.push(state[0] - sbox_in); @@ -243,7 +243,7 @@ where // Partial rounds. >::partial_first_constant_layer(&mut state); - state = >::mds_partial_layer_init(&mut state); + state = >::mds_partial_layer_init(&state); for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; constraints.push(state[0] - sbox_in); @@ -345,7 +345,7 @@ where } } else { >::partial_first_constant_layer_recursive(builder, &mut state); - state = >::mds_partial_layer_init_recursive(builder, &mut state); + state = >::mds_partial_layer_init_recursive(builder, &state); for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; constraints.push(builder.sub_extension(state[0], sbox_in)); @@ -489,7 +489,7 @@ where } >::partial_first_constant_layer(&mut state); - state = >::mds_partial_layer_init(&mut state); + state = >::mds_partial_layer_init(&state); for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { out_buffer.set_wire( local_wire(PoseidonGate::::wire_partial_sbox(r)), diff --git a/src/gates/random_access.rs b/src/gates/random_access.rs index cea4b079..2796cd01 100644 --- a/src/gates/random_access.rs +++ b/src/gates/random_access.rs @@ -263,8 +263,7 @@ impl, const D: usize> SimpleGenerator fn dependencies(&self) -> Vec { let local_target = |input| Target::wire(self.gate_index, input); - let mut deps = Vec::new(); - deps.push(local_target(self.gate.wire_access_index(self.copy))); + let mut deps = vec![local_target(self.gate.wire_access_index(self.copy))]; for i in 0..self.gate.vec_size() { deps.push(local_target(self.gate.wire_list_item(i, self.copy))); } diff --git a/src/gates/subtraction_u32.rs b/src/gates/subtraction_u32.rs index 225c09e4..26f6302e 100644 --- a/src/gates/subtraction_u32.rs +++ b/src/gates/subtraction_u32.rs @@ -244,17 +244,11 @@ impl, const D: usize> SimpleGenerator fn dependencies(&self) -> Vec { let local_target = |input| Target::wire(self.gate_index, input); - let mut deps = Vec::with_capacity(3); - deps.push(local_target(U32SubtractionGate::::wire_ith_input_x( - self.i, - ))); - deps.push(local_target(U32SubtractionGate::::wire_ith_input_y( - self.i, - ))); - deps.push(local_target( - U32SubtractionGate::::wire_ith_input_borrow(self.i), - )); - deps + vec![ + local_target(U32SubtractionGate::::wire_ith_input_x(self.i)), + local_target(U32SubtractionGate::::wire_ith_input_y(self.i)), + local_target(U32SubtractionGate::::wire_ith_input_borrow(self.i)), + ] } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { diff --git a/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs b/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs index 5d9d9fba..0aaa13a6 100644 --- a/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs +++ b/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs @@ -1,3 +1,5 @@ +#![allow(clippy::assertions_on_constants)] + use std::arch::aarch64::*; use static_assertions::const_assert; @@ -171,9 +173,7 @@ unsafe fn multiply(x: u64, y: u64) -> u64 { let xy_hi_lo_mul_epsilon = mul_epsilon(xy_hi); // add_with_wraparound is safe, as xy_hi_lo_mul_epsilon <= 0xfffffffe00000001 <= ORDER. - let res1 = add_with_wraparound(res0, xy_hi_lo_mul_epsilon); - - res1 + add_with_wraparound(res0, xy_hi_lo_mul_epsilon) } // ==================================== STANDALONE CONST LAYER ===================================== @@ -266,9 +266,7 @@ unsafe fn mds_reduce( // Multiply by EPSILON and accumulate. let res_unadj = vmlal_laneq_u32::<0>(res_lo, res_hi_hi, mds_consts0); let res_adj = vcgtq_u64(res_lo, res_unadj); - let res = vsraq_n_u64::<32>(res_unadj, res_adj); - - res + vsraq_n_u64::<32>(res_unadj, res_adj) } #[inline(always)] @@ -968,8 +966,7 @@ unsafe fn partial_round( #[inline(always)] unsafe fn full_round(state: [u64; 12], round_constants: &[u64; WIDTH]) -> [u64; 12] { let state = sbox_layer_full(state); - let state = mds_const_layers_full(state, round_constants); - state + mds_const_layers_full(state, round_constants) } #[inline] diff --git a/src/hash/hashing.rs b/src/hash/hashing.rs index a4610495..4fd537f1 100644 --- a/src/hash/hashing.rs +++ b/src/hash/hashing.rs @@ -110,9 +110,7 @@ pub fn hash_n_to_m(mut inputs: Vec, num_outputs: usize, pad: bo // Absorb all input chunks. for input_chunk in inputs.chunks(SPONGE_RATE) { - for i in 0..input_chunk.len() { - state[i] = input_chunk[i]; - } + state[..input_chunk.len()].copy_from_slice(input_chunk); state = permute(state); } diff --git a/src/iop/target.rs b/src/iop/target.rs index 8d4cbcfb..de3e4911 100644 --- a/src/iop/target.rs +++ b/src/iop/target.rs @@ -41,6 +41,7 @@ impl Target { /// A `Target` which has already been constrained such that it can only be 0 or 1. #[derive(Copy, Clone, Debug)] +#[allow(clippy::manual_non_exhaustive)] pub struct BoolTarget { pub target: Target, /// This private field is here to force all instantiations to go through `new_unsafe`. diff --git a/src/lib.rs b/src/lib.rs index 46db2cf5..331ee558 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,10 @@ #![allow(incomplete_features)] #![allow(const_evaluatable_unchecked)] +#![allow(clippy::new_without_default)] +#![allow(clippy::too_many_arguments)] +#![allow(clippy::len_without_is_empty)] +#![allow(clippy::module_inception)] +#![allow(clippy::needless_range_loop)] #![feature(asm)] #![feature(asm_sym)] #![feature(destructuring_assignment)] diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 730699b9..01462611 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -634,7 +634,7 @@ impl, const D: usize> CircuitBuilder { // Precompute FFT roots. let max_fft_points = - 1 << degree_bits + max(self.config.rate_bits, log2_ceil(quotient_degree_factor)); + 1 << (degree_bits + max(self.config.rate_bits, log2_ceil(quotient_degree_factor))); let fft_root_table = fft_root_table(max_fft_points); let constants_sigmas_vecs = [constant_vecs, sigma_vecs.clone()].concat(); @@ -669,7 +669,7 @@ impl, const D: usize> CircuitBuilder { let watch_rep_index = forest.parents[watch_index]; generator_indices_by_watches .entry(watch_rep_index) - .or_insert(vec![]) + .or_insert_with(Vec::new) .push(i); } } diff --git a/src/plonk/proof.rs b/src/plonk/proof.rs index 70e38588..94bc4714 100644 --- a/src/plonk/proof.rs +++ b/src/plonk/proof.rs @@ -138,11 +138,7 @@ impl, const D: usize> CompressedProof { plonk_zs_partial_products_cap, quotient_polys_cap, openings, - opening_proof: opening_proof.decompress( - &challenges, - fri_inferred_elements, - common_data, - ), + opening_proof: opening_proof.decompress(challenges, fri_inferred_elements, common_data), } } } diff --git a/src/plonk/recursive_verifier.rs b/src/plonk/recursive_verifier.rs index 98c866f8..46ca5e6a 100644 --- a/src/plonk/recursive_verifier.rs +++ b/src/plonk/recursive_verifier.rs @@ -527,7 +527,7 @@ mod tests { &inner_vd.constants_sigmas_cap, ); - builder.add_recursive_verifier(pt, &inner_config, &inner_data, &inner_cd); + builder.add_recursive_verifier(pt, inner_config, &inner_data, &inner_cd); if print_gate_counts { builder.print_gate_counts(0); @@ -563,12 +563,12 @@ mod tests { ) -> Result<()> { let proof_bytes = proof.to_bytes()?; info!("Proof length: {} bytes", proof_bytes.len()); - let proof_from_bytes = ProofWithPublicInputs::from_bytes(proof_bytes, &cd)?; + let proof_from_bytes = ProofWithPublicInputs::from_bytes(proof_bytes, cd)?; assert_eq!(proof, &proof_from_bytes); let now = std::time::Instant::now(); - let compressed_proof = proof.clone().compress(&cd)?; - let decompressed_compressed_proof = compressed_proof.clone().decompress(&cd)?; + let compressed_proof = proof.clone().compress(cd)?; + let decompressed_compressed_proof = compressed_proof.clone().decompress(cd)?; info!("{:.4}s to compress proof", now.elapsed().as_secs_f64()); assert_eq!(proof, &decompressed_compressed_proof); @@ -578,7 +578,7 @@ mod tests { compressed_proof_bytes.len() ); let compressed_proof_from_bytes = - CompressedProofWithPublicInputs::from_bytes(compressed_proof_bytes, &cd)?; + CompressedProofWithPublicInputs::from_bytes(compressed_proof_bytes, cd)?; assert_eq!(compressed_proof, compressed_proof_from_bytes); Ok(()) diff --git a/src/polynomial/polynomial.rs b/src/polynomial/polynomial.rs index b5a7fabd..43e17823 100644 --- a/src/polynomial/polynomial.rs +++ b/src/polynomial/polynomial.rs @@ -441,7 +441,7 @@ mod tests { assert_eq!(coset_evals, naive_coset_evals); let ifft_coeffs = PolynomialValues::new(coset_evals).coset_ifft(shift); - assert_eq!(poly, ifft_coeffs.into()); + assert_eq!(poly, ifft_coeffs); } #[test] diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index 0f3c9bfa..b5e805e9 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -13,7 +13,7 @@ pub(crate) fn quotient_chunk_products( max_degree: usize, ) -> Vec { debug_assert!(max_degree > 1); - assert!(quotient_values.len() > 0); + assert!(!quotient_values.is_empty()); let chunk_size = max_degree; quotient_values .chunks(chunk_size) @@ -24,7 +24,7 @@ pub(crate) fn quotient_chunk_products( /// Compute partial products of the original vector `v` such that all products consist of `max_degree` /// or less elements. This is done until we've computed the product `P` of all elements in the vector. pub(crate) fn partial_products_and_z_gx(z_x: F, quotient_chunk_products: &[F]) -> Vec { - assert!(quotient_chunk_products.len() > 0); + assert!(!quotient_chunk_products.is_empty()); let mut res = Vec::new(); let mut acc = z_x; for "ient_chunk_product in quotient_chunk_products { diff --git a/src/util/timing.rs b/src/util/timing.rs index cd9ea731..4250d688 100644 --- a/src/util/timing.rs +++ b/src/util/timing.rs @@ -92,7 +92,7 @@ impl TimingTree { fn duration(&self) -> Duration { self.exit_time - .unwrap_or(Instant::now()) + .unwrap_or_else(Instant::now) .duration_since(self.enter_time) } From 7097081e5b912c535d1653efaf16808061237d1c Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 30 Nov 2021 17:28:29 +0100 Subject: [PATCH 136/202] Add clippy to CI --- .github/workflows/continuous-integration-workflow.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/continuous-integration-workflow.yml b/.github/workflows/continuous-integration-workflow.yml index dd79f33e..6ff07104 100644 --- a/.github/workflows/continuous-integration-workflow.yml +++ b/.github/workflows/continuous-integration-workflow.yml @@ -50,3 +50,10 @@ jobs: with: command: fmt args: --all -- --check + + - name: Run cargo clippy + uses: actions-rs/cargo@v1 + with: + command: clippy + args: --all-features --all-targets -- -D warnings + From b3d246a7c5254b0b6754f97a2d84b11b81e49f6c Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 30 Nov 2021 17:55:39 +0100 Subject: [PATCH 137/202] Minor --- .github/workflows/continuous-integration-workflow.yml | 2 +- src/lib.rs | 1 - src/polynomial/mod.rs | 2 ++ 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/continuous-integration-workflow.yml b/.github/workflows/continuous-integration-workflow.yml index 6ff07104..046ade32 100644 --- a/.github/workflows/continuous-integration-workflow.yml +++ b/.github/workflows/continuous-integration-workflow.yml @@ -55,5 +55,5 @@ jobs: uses: actions-rs/cargo@v1 with: command: clippy - args: --all-features --all-targets -- -D warnings + args: --all-features --all-targets -- -D warnings -A incomplete-features diff --git a/src/lib.rs b/src/lib.rs index 331ee558..e76e312c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,6 @@ #![allow(clippy::new_without_default)] #![allow(clippy::too_many_arguments)] #![allow(clippy::len_without_is_empty)] -#![allow(clippy::module_inception)] #![allow(clippy::needless_range_loop)] #![feature(asm)] #![feature(asm_sym)] diff --git a/src/polynomial/mod.rs b/src/polynomial/mod.rs index 2c7f7076..2798b1f2 100644 --- a/src/polynomial/mod.rs +++ b/src/polynomial/mod.rs @@ -1,2 +1,4 @@ +#![allow(clippy::module_inception)] + pub(crate) mod division; pub mod polynomial; From 915f4eccc5b7adebea025287282adb4bf40553e3 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 30 Nov 2021 18:09:58 +0100 Subject: [PATCH 138/202] Fix github CI --- .github/workflows/continuous-integration-workflow.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/continuous-integration-workflow.yml b/.github/workflows/continuous-integration-workflow.yml index 046ade32..4cb70602 100644 --- a/.github/workflows/continuous-integration-workflow.yml +++ b/.github/workflows/continuous-integration-workflow.yml @@ -30,7 +30,7 @@ jobs: args: --all lints: - name: Formatting + name: Formatting and Clippy runs-on: ubuntu-latest if: "! contains(toJSON(github.event.commits.*.message), '[skip-ci]')" steps: @@ -43,7 +43,7 @@ jobs: profile: minimal toolchain: nightly override: true - components: rustfmt + components: rustfmt, clippy - name: Run cargo fmt uses: actions-rs/cargo@v1 From 301edf3ab581b04d35a0dc101ef61a9e5d0359cd Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 30 Nov 2021 18:18:56 +0100 Subject: [PATCH 139/202] Move clippy::eq_ip --- src/field/field_testing.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/field/field_testing.rs b/src/field/field_testing.rs index 54718f4a..767a3cf2 100644 --- a/src/field/field_testing.rs +++ b/src/field/field_testing.rs @@ -1,4 +1,3 @@ -#![allow(clippy::eq_op)] use crate::field::extension_field::Extendable; use crate::field::extension_field::Frobenius; use crate::field::field_types::Field; @@ -89,6 +88,7 @@ macro_rules! test_field_arithmetic { }; } +#[allow(clippy::eq_op)] pub(crate) fn test_add_neg_sub_mul, const D: usize>() { let x = BF::Extension::rand(); let y = BF::Extension::rand(); From a0b0a2d715f9a472133b5386aab908fdee9c73dd Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 30 Nov 2021 20:17:34 +0100 Subject: [PATCH 140/202] Move polynomial.rs to mod.rs --- benches/ffts.rs | 2 +- src/bin/bench_ldes.rs | 2 +- src/field/fft.rs | 4 +- src/field/interpolation.rs | 4 +- src/fri/commitment.rs | 2 +- src/fri/proof.rs | 2 +- src/fri/prover.rs | 2 +- src/gates/gate_testing.rs | 2 +- src/gates/interpolation.rs | 4 +- src/gates/low_degree_interpolation.rs | 4 +- src/plonk/circuit_builder.rs | 2 +- src/plonk/get_challenges.rs | 2 +- src/plonk/permutation_argument.rs | 2 +- src/plonk/prover.rs | 2 +- src/polynomial/division.rs | 4 +- src/polynomial/mod.rs | 618 +++++++++++++++++++++++++- src/util/mod.rs | 2 +- src/util/reducing.rs | 2 +- src/util/serialization.rs | 2 +- 19 files changed, 638 insertions(+), 26 deletions(-) diff --git a/benches/ffts.rs b/benches/ffts.rs index 8492cfe9..745d53a8 100644 --- a/benches/ffts.rs +++ b/benches/ffts.rs @@ -1,7 +1,7 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use plonky2::field::field_types::Field; use plonky2::field::goldilocks_field::GoldilocksField; -use plonky2::polynomial::polynomial::PolynomialCoeffs; +use plonky2::polynomial::PolynomialCoeffs; use tynm::type_name; pub(crate) fn bench_ffts(c: &mut Criterion) { diff --git a/src/bin/bench_ldes.rs b/src/bin/bench_ldes.rs index d121831b..dbcfa6df 100644 --- a/src/bin/bench_ldes.rs +++ b/src/bin/bench_ldes.rs @@ -2,7 +2,7 @@ use std::time::Instant; use plonky2::field::field_types::Field; use plonky2::field::goldilocks_field::GoldilocksField; -use plonky2::polynomial::polynomial::PolynomialValues; +use plonky2::polynomial::PolynomialValues; use rayon::prelude::*; type F = GoldilocksField; diff --git a/src/field/fft.rs b/src/field/fft.rs index 89b6844d..6f5155a4 100644 --- a/src/field/fft.rs +++ b/src/field/fft.rs @@ -6,7 +6,7 @@ use unroll::unroll_for_loops; use crate::field::field_types::Field; use crate::field::packable::Packable; use crate::field::packed_field::{PackedField, Singleton}; -use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; +use crate::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::util::{log2_strict, reverse_index_bits}; pub(crate) type FftRootTable = Vec>; @@ -213,7 +213,7 @@ mod tests { use crate::field::fft::{fft, fft_with_options, ifft}; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; - use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; + use crate::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::util::{log2_ceil, log2_strict}; #[test] diff --git a/src/field/interpolation.rs b/src/field/interpolation.rs index c4a49fe1..ad3ddf72 100644 --- a/src/field/interpolation.rs +++ b/src/field/interpolation.rs @@ -1,6 +1,6 @@ use crate::field::fft::ifft; use crate::field::field_types::Field; -use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; +use crate::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::util::log2_ceil; /// Computes the unique degree < n interpolant of an arbitrary list of n (point, value) pairs. @@ -80,7 +80,7 @@ mod tests { use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; - use crate::polynomial::polynomial::PolynomialCoeffs; + use crate::polynomial::PolynomialCoeffs; #[test] fn interpolant_random() { diff --git a/src/fri/commitment.rs b/src/fri/commitment.rs index ee2e66eb..c8506356 100644 --- a/src/fri/commitment.rs +++ b/src/fri/commitment.rs @@ -10,7 +10,7 @@ use crate::iop::challenger::Challenger; use crate::plonk::circuit_data::CommonCircuitData; use crate::plonk::plonk_common::PlonkPolynomials; use crate::plonk::proof::OpeningSet; -use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; +use crate::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::timed; use crate::util::reducing::ReducingFactor; use crate::util::timing::TimingTree; diff --git a/src/fri/proof.rs b/src/fri/proof.rs index f6875fcc..72cf594a 100644 --- a/src/fri/proof.rs +++ b/src/fri/proof.rs @@ -15,7 +15,7 @@ use crate::iop::target::Target; use crate::plonk::circuit_data::CommonCircuitData; use crate::plonk::plonk_common::PolynomialsIndexBlinding; use crate::plonk::proof::{FriInferredElements, ProofChallenges}; -use crate::polynomial::polynomial::PolynomialCoeffs; +use crate::polynomial::PolynomialCoeffs; /// Evaluations and Merkle proof produced by the prover in a FRI query step. #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] diff --git a/src/fri/prover.rs b/src/fri/prover.rs index ec58fc14..90ef8cfe 100644 --- a/src/fri/prover.rs +++ b/src/fri/prover.rs @@ -10,7 +10,7 @@ use crate::hash::merkle_tree::MerkleTree; use crate::iop::challenger::Challenger; use crate::plonk::circuit_data::CommonCircuitData; use crate::plonk::plonk_common::reduce_with_powers; -use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; +use crate::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::timed; use crate::util::reverse_index_bits_in_place; use crate::util::timing::TimingTree; diff --git a/src/gates/gate_testing.rs b/src/gates/gate_testing.rs index 9fe8e835..3fceded4 100644 --- a/src/gates/gate_testing.rs +++ b/src/gates/gate_testing.rs @@ -9,7 +9,7 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use crate::plonk::verifier::verify; -use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; +use crate::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::util::{log2_ceil, transpose}; const WITNESS_SIZE: usize = 1 << 5; diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index 3ebf7259..d2142ee1 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -15,7 +15,7 @@ use crate::iop::wire::Wire; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -use crate::polynomial::polynomial::PolynomialCoeffs; +use crate::polynomial::PolynomialCoeffs; /// Interpolation gate with constraints of degree at most `1<, const D: usize>( diff --git a/src/plonk/permutation_argument.rs b/src/plonk/permutation_argument.rs index 28a07dff..ca3977ce 100644 --- a/src/plonk/permutation_argument.rs +++ b/src/plonk/permutation_argument.rs @@ -5,7 +5,7 @@ use rayon::prelude::*; use crate::field::field_types::Field; use crate::iop::target::Target; use crate::iop::wire::Wire; -use crate::polynomial::polynomial::PolynomialValues; +use crate::polynomial::PolynomialValues; /// Disjoint Set Forest data-structure following https://en.wikipedia.org/wiki/Disjoint-set_data_structure. pub struct Forest { diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 3f8e607d..4f281f06 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -17,7 +17,7 @@ use crate::plonk::plonk_common::ZeroPolyOnCoset; use crate::plonk::proof::{Proof, ProofWithPublicInputs}; use crate::plonk::vanishing_poly::eval_vanishing_poly_base_batch; use crate::plonk::vars::EvaluationVarsBase; -use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; +use crate::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::timed; use crate::util::partial_products::{partial_products_and_z_gx, quotient_chunk_products}; use crate::util::timing::TimingTree; diff --git a/src/polynomial/division.rs b/src/polynomial/division.rs index b5dad629..671b7715 100644 --- a/src/polynomial/division.rs +++ b/src/polynomial/division.rs @@ -1,5 +1,5 @@ use crate::field::field_types::Field; -use crate::polynomial::polynomial::PolynomialCoeffs; +use crate::polynomial::PolynomialCoeffs; use crate::util::log2_ceil; impl PolynomialCoeffs { @@ -129,7 +129,7 @@ mod tests { use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; - use crate::polynomial::polynomial::PolynomialCoeffs; + use crate::polynomial::PolynomialCoeffs; #[test] #[ignore] diff --git a/src/polynomial/mod.rs b/src/polynomial/mod.rs index 2798b1f2..1a7b90fe 100644 --- a/src/polynomial/mod.rs +++ b/src/polynomial/mod.rs @@ -1,4 +1,616 @@ -#![allow(clippy::module_inception)] - pub(crate) mod division; -pub mod polynomial; + +use std::cmp::max; +use std::iter::Sum; +use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; + +use anyhow::{ensure, Result}; +use serde::{Deserialize, Serialize}; + +use crate::field::extension_field::{Extendable, FieldExtension}; +use crate::field::fft::{fft, fft_with_options, ifft, FftRootTable}; +use crate::field::field_types::Field; +use crate::util::log2_strict; + +/// A polynomial in point-value form. +/// +/// The points are implicitly `g^i`, where `g` generates the subgroup whose size equals the number +/// of points. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct PolynomialValues { + pub values: Vec, +} + +impl PolynomialValues { + pub fn new(values: Vec) -> Self { + PolynomialValues { values } + } + + /// The number of values stored. + pub(crate) fn len(&self) -> usize { + self.values.len() + } + + pub fn ifft(&self) -> PolynomialCoeffs { + ifft(self) + } + + /// Returns the polynomial whose evaluation on the coset `shift*H` is `self`. + pub fn coset_ifft(&self, shift: F) -> PolynomialCoeffs { + let mut shifted_coeffs = self.ifft(); + shifted_coeffs + .coeffs + .iter_mut() + .zip(shift.inverse().powers()) + .for_each(|(c, r)| { + *c *= r; + }); + shifted_coeffs + } + + pub fn lde_multiple(polys: Vec, rate_bits: usize) -> Vec { + polys.into_iter().map(|p| p.lde(rate_bits)).collect() + } + + pub fn lde(&self, rate_bits: usize) -> Self { + let coeffs = ifft(self).lde(rate_bits); + fft_with_options(&coeffs, Some(rate_bits), None) + } + + pub fn degree(&self) -> usize { + self.degree_plus_one() + .checked_sub(1) + .expect("deg(0) is undefined") + } + + pub fn degree_plus_one(&self) -> usize { + self.ifft().degree_plus_one() + } +} + +impl From> for PolynomialValues { + fn from(values: Vec) -> Self { + Self::new(values) + } +} + +/// A polynomial in coefficient form. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct PolynomialCoeffs { + pub(crate) coeffs: Vec, +} + +impl PolynomialCoeffs { + pub fn new(coeffs: Vec) -> Self { + PolynomialCoeffs { coeffs } + } + + pub(crate) fn empty() -> Self { + Self::new(Vec::new()) + } + + pub(crate) fn zero(len: usize) -> Self { + Self::new(vec![F::ZERO; len]) + } + + pub(crate) fn is_zero(&self) -> bool { + self.coeffs.iter().all(|x| x.is_zero()) + } + + /// The number of coefficients. This does not filter out any zero coefficients, so it is not + /// necessarily related to the degree. + pub fn len(&self) -> usize { + self.coeffs.len() + } + + pub fn log_len(&self) -> usize { + log2_strict(self.len()) + } + + pub(crate) fn chunks(&self, chunk_size: usize) -> Vec { + self.coeffs + .chunks(chunk_size) + .map(|chunk| PolynomialCoeffs::new(chunk.to_vec())) + .collect() + } + + pub fn eval(&self, x: F) -> F { + self.coeffs + .iter() + .rev() + .fold(F::ZERO, |acc, &c| acc * x + c) + } + + /// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1. + pub fn eval_with_powers(&self, powers: &[F]) -> F { + debug_assert_eq!(self.coeffs.len(), powers.len() + 1); + let acc = self.coeffs[0]; + self.coeffs[1..] + .iter() + .zip(powers) + .fold(acc, |acc, (&x, &c)| acc + c * x) + } + + pub fn eval_base(&self, x: F::BaseField) -> F + where + F: FieldExtension, + { + self.coeffs + .iter() + .rev() + .fold(F::ZERO, |acc, &c| acc.scalar_mul(x) + c) + } + + /// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1. + pub fn eval_base_with_powers(&self, powers: &[F::BaseField]) -> F + where + F: FieldExtension, + { + debug_assert_eq!(self.coeffs.len(), powers.len() + 1); + let acc = self.coeffs[0]; + self.coeffs[1..] + .iter() + .zip(powers) + .fold(acc, |acc, (&x, &c)| acc + x.scalar_mul(c)) + } + + pub fn lde_multiple(polys: Vec<&Self>, rate_bits: usize) -> Vec { + polys.into_iter().map(|p| p.lde(rate_bits)).collect() + } + + pub fn lde(&self, rate_bits: usize) -> Self { + self.padded(self.len() << rate_bits) + } + + pub(crate) fn pad(&mut self, new_len: usize) -> Result<()> { + ensure!( + new_len >= self.len(), + "Trying to pad a polynomial of length {} to a length of {}.", + self.len(), + new_len + ); + self.coeffs.resize(new_len, F::ZERO); + Ok(()) + } + + pub(crate) fn padded(&self, new_len: usize) -> Self { + let mut poly = self.clone(); + poly.pad(new_len).unwrap(); + poly + } + + /// Removes leading zero coefficients. + pub fn trim(&mut self) { + self.coeffs.truncate(self.degree_plus_one()); + } + + /// Removes leading zero coefficients. + pub fn trimmed(&self) -> Self { + let coeffs = self.coeffs[..self.degree_plus_one()].to_vec(); + Self { coeffs } + } + + /// Degree of the polynomial + 1, or 0 for a polynomial with no non-zero coefficients. + pub(crate) fn degree_plus_one(&self) -> usize { + (0usize..self.len()) + .rev() + .find(|&i| self.coeffs[i].is_nonzero()) + .map_or(0, |i| i + 1) + } + + /// Leading coefficient. + pub fn lead(&self) -> F { + self.coeffs + .iter() + .rev() + .find(|x| x.is_nonzero()) + .map_or(F::ZERO, |x| *x) + } + + /// Reverse the order of the coefficients, not taking into account the leading zero coefficients. + pub(crate) fn rev(&self) -> Self { + Self::new(self.trimmed().coeffs.into_iter().rev().collect()) + } + + pub fn fft(&self) -> PolynomialValues { + fft(self) + } + + pub fn fft_with_options( + &self, + zero_factor: Option, + root_table: Option<&FftRootTable>, + ) -> PolynomialValues { + fft_with_options(self, zero_factor, root_table) + } + + /// Returns the evaluation of the polynomial on the coset `shift*H`. + pub fn coset_fft(&self, shift: F) -> PolynomialValues { + self.coset_fft_with_options(shift, None, None) + } + + /// Returns the evaluation of the polynomial on the coset `shift*H`. + pub fn coset_fft_with_options( + &self, + shift: F, + zero_factor: Option, + root_table: Option<&FftRootTable>, + ) -> PolynomialValues { + let modified_poly: Self = shift + .powers() + .zip(&self.coeffs) + .map(|(r, &c)| r * c) + .collect::>() + .into(); + modified_poly.fft_with_options(zero_factor, root_table) + } + + pub fn to_extension(&self) -> PolynomialCoeffs + where + F: Extendable, + { + PolynomialCoeffs::new(self.coeffs.iter().map(|&c| c.into()).collect()) + } + + pub fn mul_extension(&self, rhs: F::Extension) -> PolynomialCoeffs + where + F: Extendable, + { + PolynomialCoeffs::new(self.coeffs.iter().map(|&c| rhs.scalar_mul(c)).collect()) + } +} + +impl PartialEq for PolynomialCoeffs { + fn eq(&self, other: &Self) -> bool { + let max_terms = self.coeffs.len().max(other.coeffs.len()); + for i in 0..max_terms { + let self_i = self.coeffs.get(i).cloned().unwrap_or(F::ZERO); + let other_i = other.coeffs.get(i).cloned().unwrap_or(F::ZERO); + if self_i != other_i { + return false; + } + } + true + } +} + +impl Eq for PolynomialCoeffs {} + +impl From> for PolynomialCoeffs { + fn from(coeffs: Vec) -> Self { + Self::new(coeffs) + } +} + +impl Add for &PolynomialCoeffs { + type Output = PolynomialCoeffs; + + fn add(self, rhs: Self) -> Self::Output { + let len = max(self.len(), rhs.len()); + let a = self.padded(len).coeffs; + let b = rhs.padded(len).coeffs; + let coeffs = a.into_iter().zip(b).map(|(x, y)| x + y).collect(); + PolynomialCoeffs::new(coeffs) + } +} + +impl Sum for PolynomialCoeffs { + fn sum>(iter: I) -> Self { + iter.fold(Self::empty(), |acc, p| &acc + &p) + } +} + +impl Sub for &PolynomialCoeffs { + type Output = PolynomialCoeffs; + + fn sub(self, rhs: Self) -> Self::Output { + let len = max(self.len(), rhs.len()); + let mut coeffs = self.padded(len).coeffs; + for (i, &c) in rhs.coeffs.iter().enumerate() { + coeffs[i] -= c; + } + PolynomialCoeffs::new(coeffs) + } +} + +impl AddAssign for PolynomialCoeffs { + fn add_assign(&mut self, rhs: Self) { + let len = max(self.len(), rhs.len()); + self.coeffs.resize(len, F::ZERO); + for (l, r) in self.coeffs.iter_mut().zip(rhs.coeffs) { + *l += r; + } + } +} + +impl AddAssign<&Self> for PolynomialCoeffs { + fn add_assign(&mut self, rhs: &Self) { + let len = max(self.len(), rhs.len()); + self.coeffs.resize(len, F::ZERO); + for (l, &r) in self.coeffs.iter_mut().zip(&rhs.coeffs) { + *l += r; + } + } +} + +impl SubAssign for PolynomialCoeffs { + fn sub_assign(&mut self, rhs: Self) { + let len = max(self.len(), rhs.len()); + self.coeffs.resize(len, F::ZERO); + for (l, r) in self.coeffs.iter_mut().zip(rhs.coeffs) { + *l -= r; + } + } +} + +impl SubAssign<&Self> for PolynomialCoeffs { + fn sub_assign(&mut self, rhs: &Self) { + let len = max(self.len(), rhs.len()); + self.coeffs.resize(len, F::ZERO); + for (l, &r) in self.coeffs.iter_mut().zip(&rhs.coeffs) { + *l -= r; + } + } +} + +impl Mul for &PolynomialCoeffs { + type Output = PolynomialCoeffs; + + fn mul(self, rhs: F) -> Self::Output { + let coeffs = self.coeffs.iter().map(|&x| rhs * x).collect(); + PolynomialCoeffs::new(coeffs) + } +} + +impl MulAssign for PolynomialCoeffs { + fn mul_assign(&mut self, rhs: F) { + self.coeffs.iter_mut().for_each(|x| *x *= rhs); + } +} + +impl Mul for &PolynomialCoeffs { + type Output = PolynomialCoeffs; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn mul(self, rhs: Self) -> Self::Output { + let new_len = (self.len() + rhs.len()).next_power_of_two(); + let a = self.padded(new_len); + let b = rhs.padded(new_len); + let a_evals = a.fft(); + let b_evals = b.fft(); + + let mul_evals: Vec = a_evals + .values + .into_iter() + .zip(b_evals.values) + .map(|(pa, pb)| pa * pb) + .collect(); + ifft(&mul_evals.into()) + } +} + +#[cfg(test)] +mod tests { + use std::time::Instant; + + use rand::{thread_rng, Rng}; + + use super::*; + use crate::field::goldilocks_field::GoldilocksField; + + #[test] + fn test_trimmed() { + type F = GoldilocksField; + + assert_eq!( + PolynomialCoeffs:: { coeffs: vec![] }.trimmed(), + PolynomialCoeffs:: { coeffs: vec![] } + ); + assert_eq!( + PolynomialCoeffs:: { + coeffs: vec![F::ZERO] + } + .trimmed(), + PolynomialCoeffs:: { coeffs: vec![] } + ); + assert_eq!( + PolynomialCoeffs:: { + coeffs: vec![F::ONE, F::TWO, F::ZERO, F::ZERO] + } + .trimmed(), + PolynomialCoeffs:: { + coeffs: vec![F::ONE, F::TWO] + } + ); + } + + #[test] + fn test_coset_fft() { + type F = GoldilocksField; + + let k = 8; + let n = 1 << k; + let poly = PolynomialCoeffs::new(F::rand_vec(n)); + let shift = F::rand(); + let coset_evals = poly.coset_fft(shift).values; + + let generator = F::primitive_root_of_unity(k); + let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n) + .into_iter() + .map(|x| poly.eval(x)) + .collect::>(); + assert_eq!(coset_evals, naive_coset_evals); + + let ifft_coeffs = PolynomialValues::new(coset_evals).coset_ifft(shift); + assert_eq!(poly, ifft_coeffs); + } + + #[test] + fn test_coset_ifft() { + type F = GoldilocksField; + + let k = 8; + let n = 1 << k; + let evals = PolynomialValues::new(F::rand_vec(n)); + let shift = F::rand(); + let coeffs = evals.coset_ifft(shift); + + let generator = F::primitive_root_of_unity(k); + let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n) + .into_iter() + .map(|x| coeffs.eval(x)) + .collect::>(); + assert_eq!(evals, naive_coset_evals.into()); + + let fft_evals = coeffs.coset_fft(shift); + assert_eq!(evals, fft_evals); + } + + #[test] + fn test_polynomial_multiplication() { + type F = GoldilocksField; + let mut rng = thread_rng(); + let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000)); + let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); + let b = PolynomialCoeffs::new(F::rand_vec(b_deg)); + let m1 = &a * &b; + let m2 = &a * &b; + for _ in 0..1000 { + let x = F::rand(); + assert_eq!(m1.eval(x), a.eval(x) * b.eval(x)); + assert_eq!(m2.eval(x), a.eval(x) * b.eval(x)); + } + } + + #[test] + fn test_inv_mod_xn() { + type F = GoldilocksField; + let mut rng = thread_rng(); + let a_deg = rng.gen_range(1..1_000); + let n = rng.gen_range(1..1_000); + let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); + let b = a.inv_mod_xn(n); + let mut m = &a * &b; + m.coeffs.drain(n..); + m.trim(); + assert_eq!( + m, + PolynomialCoeffs::new(vec![F::ONE]), + "a: {:#?}, b:{:#?}, n:{:#?}, m:{:#?}", + a, + b, + n, + m + ); + } + + #[test] + fn test_polynomial_long_division() { + type F = GoldilocksField; + let mut rng = thread_rng(); + let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000)); + let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); + let b = PolynomialCoeffs::new(F::rand_vec(b_deg)); + let (q, r) = a.div_rem_long_division(&b); + for _ in 0..1000 { + let x = F::rand(); + assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x)); + } + } + + #[test] + fn test_polynomial_division() { + type F = GoldilocksField; + let mut rng = thread_rng(); + let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000)); + let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); + let b = PolynomialCoeffs::new(F::rand_vec(b_deg)); + let (q, r) = a.div_rem(&b); + for _ in 0..1000 { + let x = F::rand(); + assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x)); + } + } + + #[test] + fn test_polynomial_division_by_constant() { + type F = GoldilocksField; + let mut rng = thread_rng(); + let a_deg = rng.gen_range(1..10_000); + let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); + let b = PolynomialCoeffs::from(vec![F::rand()]); + let (q, r) = a.div_rem(&b); + for _ in 0..1000 { + let x = F::rand(); + assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x)); + } + } + + // Test to see which polynomial division method is faster for divisions of the type + // `(X^n - 1)/(X - a) + #[test] + fn test_division_linear() { + type F = GoldilocksField; + let mut rng = thread_rng(); + let l = 14; + let n = 1 << l; + let g = F::primitive_root_of_unity(l); + let xn_minus_one = { + let mut xn_min_one_vec = vec![F::ZERO; n + 1]; + xn_min_one_vec[n] = F::ONE; + xn_min_one_vec[0] = F::NEG_ONE; + PolynomialCoeffs::new(xn_min_one_vec) + }; + + let a = g.exp_u64(rng.gen_range(0..(n as u64))); + let denom = PolynomialCoeffs::new(vec![-a, F::ONE]); + let now = Instant::now(); + xn_minus_one.div_rem(&denom); + println!("Division time: {:?}", now.elapsed()); + let now = Instant::now(); + xn_minus_one.div_rem_long_division(&denom); + println!("Division time: {:?}", now.elapsed()); + } + + #[test] + fn eq() { + type F = GoldilocksField; + assert_eq!( + PolynomialCoeffs::::new(vec![]), + PolynomialCoeffs::new(vec![]) + ); + assert_eq!( + PolynomialCoeffs::::new(vec![F::ZERO]), + PolynomialCoeffs::new(vec![F::ZERO]) + ); + assert_eq!( + PolynomialCoeffs::::new(vec![]), + PolynomialCoeffs::new(vec![F::ZERO]) + ); + assert_eq!( + PolynomialCoeffs::::new(vec![F::ZERO]), + PolynomialCoeffs::new(vec![]) + ); + assert_eq!( + PolynomialCoeffs::::new(vec![F::ZERO]), + PolynomialCoeffs::new(vec![F::ZERO, F::ZERO]) + ); + assert_eq!( + PolynomialCoeffs::::new(vec![F::ONE]), + PolynomialCoeffs::new(vec![F::ONE, F::ZERO]) + ); + assert_ne!( + PolynomialCoeffs::::new(vec![]), + PolynomialCoeffs::new(vec![F::ONE]) + ); + assert_ne!( + PolynomialCoeffs::::new(vec![F::ZERO]), + PolynomialCoeffs::new(vec![F::ZERO, F::ONE]) + ); + assert_ne!( + PolynomialCoeffs::::new(vec![F::ZERO]), + PolynomialCoeffs::new(vec![F::ONE, F::ZERO]) + ); + } +} diff --git a/src/util/mod.rs b/src/util/mod.rs index 586033be..fca6b728 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1,7 +1,7 @@ use core::hint::unreachable_unchecked; use crate::field::field_types::Field; -use crate::polynomial::polynomial::PolynomialValues; +use crate::polynomial::PolynomialValues; pub(crate) mod bimap; pub(crate) mod context_tree; diff --git a/src/util/reducing.rs b/src/util/reducing.rs index f700a6ff..4fc15690 100644 --- a/src/util/reducing.rs +++ b/src/util/reducing.rs @@ -8,7 +8,7 @@ use crate::gates::reducing::ReducingGate; use crate::gates::reducing_extension::ReducingExtensionGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::polynomial::polynomial::PolynomialCoeffs; +use crate::polynomial::PolynomialCoeffs; /// When verifying the composition polynomial in FRI we have to compute sums of the form /// `(sum_0^k a^i * x_i)/d_0 + (sum_k^r a^i * y_i)/d_1` diff --git a/src/util/serialization.rs b/src/util/serialization.rs index 5ca4f691..b3a51b5f 100644 --- a/src/util/serialization.rs +++ b/src/util/serialization.rs @@ -15,7 +15,7 @@ use crate::plonk::circuit_data::CommonCircuitData; use crate::plonk::proof::{ CompressedProof, CompressedProofWithPublicInputs, OpeningSet, Proof, ProofWithPublicInputs, }; -use crate::polynomial::polynomial::PolynomialCoeffs; +use crate::polynomial::PolynomialCoeffs; #[derive(Debug)] pub struct Buffer(Cursor>); From f29b591d49ba693078016392572cbf1e1808d88a Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 11:42:06 -0800 Subject: [PATCH 141/202] merge --- src/gadgets/nonnative.rs | 52 ++++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 84691421..bc090cd5 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -7,54 +7,54 @@ use crate::field::{extension_field::Extendable, field_types::Field}; use crate::gadgets::biguint::BigUintTarget; use crate::plonk::circuit_builder::CircuitBuilder; -pub struct ForeignFieldTarget { +pub struct NonNativeTarget { value: BigUintTarget, _phantom: PhantomData, } impl, const D: usize> CircuitBuilder { - pub fn biguint_to_nonnative(&mut self, x: &BigUintTarget) -> ForeignFieldTarget { - ForeignFieldTarget { + pub fn biguint_to_nonnative(&mut self, x: &BigUintTarget) -> NonNativeTarget { + NonNativeTarget { value: x.clone(), _phantom: PhantomData, } } - pub fn nonnative_to_biguint(&mut self, x: &ForeignFieldTarget) -> BigUintTarget { + pub fn nonnative_to_biguint(&mut self, x: &NonNativeTarget) -> BigUintTarget { x.value.clone() } - pub fn constant_nonnative(&mut self, x: FF) -> ForeignFieldTarget { + pub fn constant_nonnative(&mut self, x: FF) -> NonNativeTarget { let x_biguint = self.constant_biguint(&x.to_biguint()); self.biguint_to_nonnative(&x_biguint) } - // Assert that two ForeignFieldTarget's, both assumed to be in reduced form, are equal. + // Assert that two NonNativeTarget's, both assumed to be in reduced form, are equal. pub fn connect_nonnative( &mut self, - lhs: &ForeignFieldTarget, - rhs: &ForeignFieldTarget, + lhs: &NonNativeTarget, + rhs: &NonNativeTarget, ) { self.connect_biguint(&lhs.value, &rhs.value); } - // Add two `ForeignFieldTarget`s. + // Add two `NonNativeTarget`s. pub fn add_nonnative( &mut self, - a: &ForeignFieldTarget, - b: &ForeignFieldTarget, - ) -> ForeignFieldTarget { + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget { let result = self.add_biguint(&a.value, &b.value); self.reduce(&result) } - // Subtract two `ForeignFieldTarget`s. + // Subtract two `NonNativeTarget`s. pub fn sub_nonnative( &mut self, - a: &ForeignFieldTarget, - b: &ForeignFieldTarget, - ) -> ForeignFieldTarget { + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget { let order = self.constant_biguint(&FF::order()); let a_plus_order = self.add_biguint(&order, &a.value); let result = self.sub_biguint(&a_plus_order, &b.value); @@ -65,9 +65,9 @@ impl, const D: usize> CircuitBuilder { pub fn mul_nonnative( &mut self, - a: &ForeignFieldTarget, - b: &ForeignFieldTarget, - ) -> ForeignFieldTarget { + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget { let result = self.mul_biguint(&a.value, &b.value); self.reduce(&result) @@ -75,8 +75,8 @@ impl, const D: usize> CircuitBuilder { pub fn neg_nonnative( &mut self, - x: &ForeignFieldTarget, - ) -> ForeignFieldTarget { + x: &NonNativeTarget, + ) -> NonNativeTarget { let neg_one = FF::order() - BigUint::one(); let neg_one_target = self.constant_biguint(&neg_one); let neg_one_ff = self.biguint_to_nonnative(&neg_one_target); @@ -84,13 +84,13 @@ impl, const D: usize> CircuitBuilder { self.mul_nonnative(&neg_one_ff, x) } - /// Returns `x % |FF|` as a `ForeignFieldTarget`. - fn reduce(&mut self, x: &BigUintTarget) -> ForeignFieldTarget { + /// Returns `x % |FF|` as a `NonNativeTarget`. + fn reduce(&mut self, x: &BigUintTarget) -> NonNativeTarget { let modulus = FF::order(); let order_target = self.constant_biguint(&modulus); let value = self.rem_biguint(x, &order_target); - ForeignFieldTarget { + NonNativeTarget { value, _phantom: PhantomData, } @@ -99,8 +99,8 @@ impl, const D: usize> CircuitBuilder { #[allow(dead_code)] fn reduce_nonnative( &mut self, - x: &ForeignFieldTarget, - ) -> ForeignFieldTarget { + x: &NonNativeTarget, + ) -> NonNativeTarget { let x_biguint = self.nonnative_to_biguint(x); self.reduce(&x_biguint) } From d9868de6932478a987889c521121566908f565d9 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 19 Nov 2021 15:29:22 -0800 Subject: [PATCH 142/202] merge --- src/plonk/circuit_builder.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 3cb62cf3..6126ab90 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -15,8 +15,9 @@ use crate::gadgets::arithmetic::BaseArithmeticOperation; use crate::gadgets::arithmetic_extension::ExtensionArithmeticOperation; use crate::gadgets::arithmetic_u32::U32Target; use crate::gates::arithmetic_base::ArithmeticGate; -use crate::gates::arithmetic_extension::ArithmeticExtensionGate; -use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; +use crate::gates::arithmetic::ArithmeticExtensionGate; +use crate::gates::arithmetic_u32::{NUM_U32_ARITHMETIC_OPS, U32ArithmeticGate}; +use crate::gates::subtraction_u32::{NUM_U32_SUBTRACTION_OPS, U32SubtractionGate}; use crate::gates::constant::ConstantGate; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; use crate::gates::gate_tree::Tree; From fd2e276405fe3457bb8e027e77375687ef48d025 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 11:47:20 -0800 Subject: [PATCH 143/202] merge --- src/plonk/circuit_builder.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 6126ab90..693abe0a 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -16,8 +16,7 @@ use crate::gadgets::arithmetic_extension::ExtensionArithmeticOperation; use crate::gadgets::arithmetic_u32::U32Target; use crate::gates::arithmetic_base::ArithmeticGate; use crate::gates::arithmetic::ArithmeticExtensionGate; -use crate::gates::arithmetic_u32::{NUM_U32_ARITHMETIC_OPS, U32ArithmeticGate}; -use crate::gates::subtraction_u32::{NUM_U32_SUBTRACTION_OPS, U32SubtractionGate}; +use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; use crate::gates::constant::ConstantGate; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; use crate::gates::gate_tree::Tree; From f9c9cc83f4df212d78e933e3c3f8a90b7d04d87b Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 19 Oct 2021 15:39:39 -0700 Subject: [PATCH 144/202] fix: run all U32SubtractionGate generators --- src/gadgets/biguint.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index fa5aa0b6..05a5406e 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -368,4 +368,28 @@ mod tests { let proof = data.prove(pw).unwrap(); verify(proof, &data.verifier_only, &data.common) } + + #[test] + fn test_biguint_sub() -> Result<()> { + let x_value = BigUint::from_u128(33333333333333333333333333333333333333).unwrap(); + let y_value = BigUint::from_u128(22222222222222222222222222222222222222).unwrap(); + let expected_z_value = &x_value - &y_value; + + type F = CrandallField; + let config = CircuitConfig::large_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_biguint(x_value); + let y = builder.constant_biguint(y_value); + let z = builder.sub_biguint(x, y); + let expected_z = builder.constant_biguint(expected_z_value); + + builder.connect_biguint(z, expected_z); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } } From ebce0799a2abca0ef0e58d062edc285c527afe18 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 28 Oct 2021 11:48:14 -0700 Subject: [PATCH 145/202] initial curve_types and curve_adds --- src/curve/curve_adds.rs | 129 +++++++++++++++++ src/curve/curve_types.rs | 299 +++++++++++++++++++++++++++++++++++++++ src/curve/mod.rs | 2 + src/field/field_types.rs | 8 ++ src/lib.rs | 1 + 5 files changed, 439 insertions(+) create mode 100644 src/curve/curve_adds.rs create mode 100644 src/curve/curve_types.rs create mode 100644 src/curve/mod.rs diff --git a/src/curve/curve_adds.rs b/src/curve/curve_adds.rs new file mode 100644 index 00000000..c5fc9ba4 --- /dev/null +++ b/src/curve/curve_adds.rs @@ -0,0 +1,129 @@ +use std::ops::Add; + +use crate::field::field_types::Field; +use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; + +impl Add> for ProjectivePoint { + type Output = ProjectivePoint; + + fn add(self, rhs: ProjectivePoint) -> Self::Output { + let ProjectivePoint { x: x1, y: y1, z: z1, zero: zero1 } = self; + let ProjectivePoint { x: x2, y: y2, z: z2, zero: zero2 } = rhs; + + if zero1 { + return rhs; + } + if zero2 { + return self; + } + + let x1z2 = x1 * z2; + let y1z2 = y1 * z2; + let x2z1 = x2 * z1; + let y2z1 = y2 * z1; + + // Check if we're doubling or adding inverses. + if x1z2 == x2z1 { + if y1z2 == y2z1 { + // TODO: inline to avoid redundant muls. + return self.double(); + } + if y1z2 == -y2z1 { + return ProjectivePoint::ZERO; + } + } + + let z1z2 = z1 * z2; + let u = y2z1 - y1z2; + let uu = u.square(); + let v = x2z1 - x1z2; + let vv = v.square(); + let vvv = v * vv; + let r = vv * x1z2; + let a = uu * z1z2 - vvv - r.double(); + let x3 = v * a; + let y3 = u * (r - a) - vvv * y1z2; + let z3 = vvv * z1z2; + ProjectivePoint::nonzero(x3, y3, z3) + } +} + +impl Add> for ProjectivePoint { + type Output = ProjectivePoint; + + fn add(self, rhs: AffinePoint) -> Self::Output { + let ProjectivePoint { x: x1, y: y1, z: z1, zero: zero1 } = self; + let AffinePoint { x: x2, y: y2, zero: zero2 } = rhs; + + if zero1 { + return rhs.to_projective(); + } + if zero2 { + return self; + } + + let x2z1 = x2 * z1; + let y2z1 = y2 * z1; + + // Check if we're doubling or adding inverses. + if x1 == x2z1 { + if y1 == y2z1 { + // TODO: inline to avoid redundant muls. + return self.double(); + } + if y1 == -y2z1 { + return ProjectivePoint::ZERO; + } + } + + let u = y2z1 - y1; + let uu = u.square(); + let v = x2z1 - x1; + let vv = v.square(); + let vvv = v * vv; + let r = vv * x1; + let a = uu * z1 - vvv - r.double(); + let x3 = v * a; + let y3 = u * (r - a) - vvv * y1; + let z3 = vvv * z1; + ProjectivePoint::nonzero(x3, y3, z3) + } +} + +impl Add> for AffinePoint { + type Output = ProjectivePoint; + + fn add(self, rhs: AffinePoint) -> Self::Output { + let AffinePoint { x: x1, y: y1, zero: zero1 } = self; + let AffinePoint { x: x2, y: y2, zero: zero2 } = rhs; + + if zero1 { + return rhs.to_projective(); + } + if zero2 { + return self.to_projective(); + } + + // Check if we're doubling or adding inverses. + if x1 == x2 { + if y1 == y2 { + return self.to_projective().double(); + } + if y1 == -y2 { + return ProjectivePoint::ZERO; + } + } + + let u = y2 - y1; + let uu = u.square(); + let v = x2 - x1; + let vv = v.square(); + let vvv = v * vv; + let r = vv * x1; + let a = uu - vvv - r.double(); + let x3 = v * a; + let y3 = u * (r - a) - vvv * y1; + let z3 = vvv; + ProjectivePoint::nonzero(x3, y3, z3) + } +} diff --git a/src/curve/curve_types.rs b/src/curve/curve_types.rs new file mode 100644 index 00000000..f7d55c0c --- /dev/null +++ b/src/curve/curve_types.rs @@ -0,0 +1,299 @@ +use std::ops::Neg; + +use anyhow::Result; + +use crate::field::field_types::Field; +use std::fmt::Debug; + +// To avoid implementation conflicts from associated types, +// see https://github.com/rust-lang/rust/issues/20400 +pub struct CurveScalar(pub ::ScalarField); + +/// A short Weierstrass curve. +pub trait Curve: 'static + Sync + Sized + Copy + Debug { + type BaseField: Field; + type ScalarField: Field; + + const A: Self::BaseField; + const B: Self::BaseField; + + const GENERATOR_AFFINE: AffinePoint; + + const GENERATOR_PROJECTIVE: ProjectivePoint = ProjectivePoint { + x: Self::GENERATOR_AFFINE.x, + y: Self::GENERATOR_AFFINE.y, + z: Self::BaseField::ONE, + zero: false, + }; + + fn convert(x: Self::ScalarField) -> CurveScalar { + CurveScalar(x) + } + + /*fn try_convert_b2s(x: Self::BaseField) -> Result { + x.try_convert::() + } + + fn try_convert_s2b(x: Self::ScalarField) -> Result { + x.try_convert::() + } + + fn try_convert_s2b_slice(s: &[Self::ScalarField]) -> Result> { + let mut res = Vec::with_capacity(s.len()); + for &x in s { + res.push(Self::try_convert_s2b(x)?); + } + Ok(res) + } + + fn try_convert_b2s_slice(s: &[Self::BaseField]) -> Result> { + let mut res = Vec::with_capacity(s.len()); + for &x in s { + res.push(Self::try_convert_b2s(x)?); + } + Ok(res) + }*/ + + fn is_safe_curve() -> bool{ + // Added additional check to prevent using vulnerabilties in case a discriminant is equal to 0. + (Self::A.cube().double().double() + Self::B.square().triple().triple().triple()).is_nonzero() + } +} + +/// A point on a short Weierstrass curve, represented in affine coordinates. +#[derive(Copy, Clone, Debug)] +pub struct AffinePoint { + pub x: C::BaseField, + pub y: C::BaseField, + pub zero: bool, +} + +impl AffinePoint { + pub const ZERO: Self = Self { + x: C::BaseField::ZERO, + y: C::BaseField::ZERO, + zero: true, + }; + + pub fn nonzero(x: C::BaseField, y: C::BaseField) -> Self { + let point = Self { x, y, zero: false }; + debug_assert!(point.is_valid()); + point + } + + pub fn is_valid(&self) -> bool { + let Self { x, y, zero } = *self; + zero || y.square() == x.cube() + C::A * x + C::B + } + + pub fn to_projective(&self) -> ProjectivePoint { + let Self { x, y, zero } = *self; + ProjectivePoint { + x, + y, + z: C::BaseField::ONE, + zero, + } + } + + pub fn batch_to_projective(affine_points: &[Self]) -> Vec> { + affine_points.iter().map(Self::to_projective).collect() + } + + pub fn double(&self) -> Self { + let AffinePoint { + x: x1, + y: y1, + zero, + } = *self; + + if zero { + return AffinePoint::ZERO; + } + + let double_y = y1.double(); + let inv_double_y = double_y.inverse(); // (2y)^(-1) + let triple_xx = x1.square().triple(); // 3x^2 + let lambda = (triple_xx + C::A) * inv_double_y; + let x3 = lambda.square() - self.x.double(); + let y3 = lambda * (x1 - x3) - y1; + + Self { + x: x3, + y: y3, + zero: false, + } + } + +} + +impl PartialEq for AffinePoint { + fn eq(&self, other: &Self) -> bool { + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = *self; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = *other; + if zero1 || zero2 { + return zero1 == zero2; + } + x1 == x2 && y1 == y2 + } +} + +impl Eq for AffinePoint {} + +/// A point on a short Weierstrass curve, represented in projective coordinates. +#[derive(Copy, Clone, Debug)] +pub struct ProjectivePoint { + pub x: C::BaseField, + pub y: C::BaseField, + pub z: C::BaseField, + pub zero: bool, +} + +impl ProjectivePoint { + pub const ZERO: Self = Self { + x: C::BaseField::ZERO, + y: C::BaseField::ZERO, + z: C::BaseField::ZERO, + zero: true, + }; + + pub fn nonzero(x: C::BaseField, y: C::BaseField, z: C::BaseField) -> Self { + let point = Self { + x, + y, + z, + zero: false, + }; + debug_assert!(point.is_valid()); + point + } + + pub fn is_valid(&self) -> bool { + self.to_affine().is_valid() + } + + pub fn to_affine(&self) -> AffinePoint { + let Self { x, y, z, zero } = *self; + if zero { + AffinePoint::ZERO + } else { + let z_inv = z.inverse(); + AffinePoint::nonzero(x * z_inv, y * z_inv) + } + } + + pub fn batch_to_affine(proj_points: &[Self]) -> Vec> { + let n = proj_points.len(); + let zs: Vec = proj_points.iter().map(|pp| pp.z).collect(); + let z_invs = C::BaseField::batch_multiplicative_inverse(&zs); + + let mut result = Vec::with_capacity(n); + for i in 0..n { + let Self { x, y, z: _, zero } = proj_points[i]; + result.push(if zero { + AffinePoint::ZERO + } else { + let z_inv = z_invs[i]; + AffinePoint::nonzero(x * z_inv, y * z_inv) + }); + } + result + } + + pub fn double(&self) -> Self { + let Self { x, y, z, zero } = *self; + if zero { + return ProjectivePoint::ZERO; + } + + let xx = x.square(); + let zz = z.square(); + let mut w = xx.triple(); + if C::A.is_nonzero() { + w = w + C::A * zz; + } + let s = y.double() * z; + let r = y * s; + let rr = r.square(); + let b = (x + r).square() - (xx + rr); + let h = w.square() - b.double(); + let x3 = h * s; + let y3 = w * (b - h) - rr.double(); + let z3 = s.cube(); + Self { + x: x3, + y: y3, + z: z3, + zero: false, + } + } + + pub fn add_slices(a: &[Self], b: &[Self]) -> Vec { + assert_eq!(a.len(), b.len()); + a.iter() + .zip(b.iter()) + .map(|(&a_i, &b_i)| a_i + b_i) + .collect() + } + + pub fn neg(&self) -> Self { + Self { + x: self.x, + y: -self.y, + z: self.z, + zero: self.zero, + } + } +} + +impl PartialEq for ProjectivePoint { + fn eq(&self, other: &Self) -> bool { + let ProjectivePoint { + x: x1, + y: y1, + z: z1, + zero: zero1, + } = *self; + let ProjectivePoint { + x: x2, + y: y2, + z: z2, + zero: zero2, + } = *other; + if zero1 || zero2 { + return zero1 == zero2; + } + + // We want to compare (x1/z1, y1/z1) == (x2/z2, y2/z2). + // But to avoid field division, it is better to compare (x1*z2, y1*z2) == (x2*z1, y2*z1). + x1 * z2 == x2 * z1 && y1 * z2 == y2 * z1 + } +} + +impl Eq for ProjectivePoint {} + +impl Neg for AffinePoint { + type Output = AffinePoint; + + fn neg(self) -> Self::Output { + let AffinePoint { x, y, zero } = self; + AffinePoint { x, y: -y, zero } + } +} + +impl Neg for ProjectivePoint { + type Output = ProjectivePoint; + + fn neg(self) -> Self::Output { + let ProjectivePoint { x, y, z, zero } = self; + ProjectivePoint { x, y: -y, z, zero } + } +} diff --git a/src/curve/mod.rs b/src/curve/mod.rs new file mode 100644 index 00000000..1e536564 --- /dev/null +++ b/src/curve/mod.rs @@ -0,0 +1,2 @@ +pub mod curve_adds; +pub mod curve_types; diff --git a/src/field/field_types.rs b/src/field/field_types.rs index 036793cc..250338fb 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -91,6 +91,14 @@ pub trait Field: self.square() * *self } + fn double(&self) -> Self { + *self * Self::TWO + } + + fn triple(&self) -> Self { + *self * (Self::ONE + Self::TWO) + } + /// Compute the multiplicative inverse of this field element. fn try_inverse(&self) -> Option; diff --git a/src/lib.rs b/src/lib.rs index e76e312c..b0158d7a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,7 @@ #![feature(specialization)] #![feature(stdsimd)] +pub mod curve; pub mod field; pub mod fri; pub mod gadgets; From db464f739e3cac1ff3eac054dd90dbfe9d43b3fc Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 11:49:30 -0800 Subject: [PATCH 146/202] merge --- src/curve/curve_adds.rs | 41 ++++- src/curve/curve_multiplication.rs | 89 ++++++++++ src/curve/curve_summation.rs | 236 +++++++++++++++++++++++++ src/curve/curve_types.rs | 14 +- src/curve/mod.rs | 2 + src/field/extension_field/quadratic.rs | 2 + src/field/extension_field/quartic.rs | 2 + src/field/field_types.rs | 2 + src/field/goldilocks_field.rs | 2 + src/field/secp256k1.rs | 2 + 10 files changed, 376 insertions(+), 16 deletions(-) create mode 100644 src/curve/curve_multiplication.rs create mode 100644 src/curve/curve_summation.rs diff --git a/src/curve/curve_adds.rs b/src/curve/curve_adds.rs index c5fc9ba4..32a5adcc 100644 --- a/src/curve/curve_adds.rs +++ b/src/curve/curve_adds.rs @@ -1,14 +1,24 @@ use std::ops::Add; -use crate::field::field_types::Field; use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; +use crate::field::field_types::Field; impl Add> for ProjectivePoint { type Output = ProjectivePoint; fn add(self, rhs: ProjectivePoint) -> Self::Output { - let ProjectivePoint { x: x1, y: y1, z: z1, zero: zero1 } = self; - let ProjectivePoint { x: x2, y: y2, z: z2, zero: zero2 } = rhs; + let ProjectivePoint { + x: x1, + y: y1, + z: z1, + zero: zero1, + } = self; + let ProjectivePoint { + x: x2, + y: y2, + z: z2, + zero: zero2, + } = rhs; if zero1 { return rhs; @@ -52,8 +62,17 @@ impl Add> for ProjectivePoint { type Output = ProjectivePoint; fn add(self, rhs: AffinePoint) -> Self::Output { - let ProjectivePoint { x: x1, y: y1, z: z1, zero: zero1 } = self; - let AffinePoint { x: x2, y: y2, zero: zero2 } = rhs; + let ProjectivePoint { + x: x1, + y: y1, + z: z1, + zero: zero1, + } = self; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = rhs; if zero1 { return rhs.to_projective(); @@ -94,8 +113,16 @@ impl Add> for AffinePoint { type Output = ProjectivePoint; fn add(self, rhs: AffinePoint) -> Self::Output { - let AffinePoint { x: x1, y: y1, zero: zero1 } = self; - let AffinePoint { x: x2, y: y2, zero: zero2 } = rhs; + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = self; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = rhs; if zero1 { return rhs.to_projective(); diff --git a/src/curve/curve_multiplication.rs b/src/curve/curve_multiplication.rs new file mode 100644 index 00000000..e5ac0eb3 --- /dev/null +++ b/src/curve/curve_multiplication.rs @@ -0,0 +1,89 @@ +use std::ops::Mul; + +use crate::curve::curve_summation::affine_summation_batch_inversion; +use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar, ProjectivePoint}; +use crate::field::field_types::Field; + +const WINDOW_BITS: usize = 4; +const BASE: usize = 1 << WINDOW_BITS; + +fn digits_per_scalar() -> usize { + (C::ScalarField::BITS + WINDOW_BITS - 1) / WINDOW_BITS +} + +/// Precomputed state used for scalar x ProjectivePoint multiplications, +/// specific to a particular generator. +#[derive(Clone)] +pub struct MultiplicationPrecomputation { + /// [(2^w)^i] g for each i < digits_per_scalar. + powers: Vec>, +} + +impl ProjectivePoint { + pub fn mul_precompute(&self) -> MultiplicationPrecomputation { + let num_digits = digits_per_scalar::(); + let mut powers_proj = Vec::with_capacity(num_digits); + powers_proj.push(*self); + for i in 1..num_digits { + let mut power_i_proj = powers_proj[i - 1]; + for _j in 0..WINDOW_BITS { + power_i_proj = power_i_proj.double(); + } + powers_proj.push(power_i_proj); + } + + let powers = ProjectivePoint::batch_to_affine(&powers_proj); + MultiplicationPrecomputation { powers } + } + + pub fn mul_with_precomputation( + &self, + scalar: C::ScalarField, + precomputation: MultiplicationPrecomputation, + ) -> Self { + // Yao's method; see https://koclab.cs.ucsb.edu/teaching/ecc/eccPapers/Doche-ch09.pdf + let precomputed_powers = precomputation.powers; + + let digits = to_digits::(&scalar); + + let mut y = ProjectivePoint::ZERO; + let mut u = ProjectivePoint::ZERO; + for j in (1..BASE).rev() { + let mut u_summands = Vec::new(); + for (i, &digit) in digits.iter().enumerate() { + if digit == j as u64 { + u_summands.push(precomputed_powers[i]); + } + } + u = u + affine_summation_batch_inversion(u_summands); + y = y + u; + } + y + } +} + +impl Mul> for CurveScalar { + type Output = ProjectivePoint; + + fn mul(self, rhs: ProjectivePoint) -> Self::Output { + let precomputation = rhs.mul_precompute(); + rhs.mul_with_precomputation(self.0, precomputation) + } +} + +#[allow(clippy::assertions_on_constants)] +fn to_digits(x: &C::ScalarField) -> Vec { + debug_assert!( + 64 % WINDOW_BITS == 0, + "For simplicity, only power-of-two window sizes are handled for now" + ); + let digits_per_u64 = 64 / WINDOW_BITS; + let mut digits = Vec::with_capacity(digits_per_scalar::()); + for limb in x.to_biguint().to_u64_digits() { + for j in 0..digits_per_u64 { + digits.push((limb >> (j * WINDOW_BITS) as u64) % BASE as u64); + } + } + + digits +} diff --git a/src/curve/curve_summation.rs b/src/curve/curve_summation.rs new file mode 100644 index 00000000..501a4977 --- /dev/null +++ b/src/curve/curve_summation.rs @@ -0,0 +1,236 @@ +use std::iter::Sum; + +use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; +use crate::field::field_types::Field; + +impl Sum> for ProjectivePoint { + fn sum>>(iter: I) -> ProjectivePoint { + let points: Vec<_> = iter.collect(); + affine_summation_best(points) + } +} + +impl Sum for ProjectivePoint { + fn sum>>(iter: I) -> ProjectivePoint { + iter.fold(ProjectivePoint::ZERO, |acc, x| acc + x) + } +} + +pub fn affine_summation_best(summation: Vec>) -> ProjectivePoint { + let result = affine_multisummation_best(vec![summation]); + debug_assert_eq!(result.len(), 1); + result[0] +} + +pub fn affine_multisummation_best( + summations: Vec>>, +) -> Vec> { + let pairwise_sums: usize = summations.iter().map(|summation| summation.len() / 2).sum(); + + // This threshold is chosen based on data from the summation benchmarks. + if pairwise_sums < 70 { + affine_multisummation_pairwise(summations) + } else { + affine_multisummation_batch_inversion(summations) + } +} + +/// Adds each pair of points using an affine + affine = projective formula, then adds up the +/// intermediate sums using a projective formula. +pub fn affine_multisummation_pairwise( + summations: Vec>>, +) -> Vec> { + summations + .into_iter() + .map(affine_summation_pairwise) + .collect() +} + +/// Adds each pair of points using an affine + affine = projective formula, then adds up the +/// intermediate sums using a projective formula. +pub fn affine_summation_pairwise(points: Vec>) -> ProjectivePoint { + let mut reduced_points: Vec> = Vec::new(); + for chunk in points.chunks(2) { + match chunk.len() { + 1 => reduced_points.push(chunk[0].to_projective()), + 2 => reduced_points.push(chunk[0] + chunk[1]), + _ => panic!(), + } + } + // TODO: Avoid copying (deref) + reduced_points + .iter() + .fold(ProjectivePoint::ZERO, |sum, x| sum + *x) +} + +/// Computes several summations of affine points by applying an affine group law, except that the +/// divisions are batched via Montgomery's trick. +pub fn affine_summation_batch_inversion( + summation: Vec>, +) -> ProjectivePoint { + let result = affine_multisummation_batch_inversion(vec![summation]); + debug_assert_eq!(result.len(), 1); + result[0] +} + +/// Computes several summations of affine points by applying an affine group law, except that the +/// divisions are batched via Montgomery's trick. +pub fn affine_multisummation_batch_inversion( + summations: Vec>>, +) -> Vec> { + let mut elements_to_invert = Vec::new(); + + // For each pair of points, (x1, y1) and (x2, y2), that we're going to add later, we want to + // invert either y (if the points are equal) or x1 - x2 (otherwise). We will use these later. + for summation in &summations { + let n = summation.len(); + // The special case for n=0 is to avoid underflow. + let range_end = if n == 0 { 0 } else { n - 1 }; + + for i in (0..range_end).step_by(2) { + let p1 = summation[i]; + let p2 = summation[i + 1]; + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = p1; + let AffinePoint { + x: x2, + y: _y2, + zero: zero2, + } = p2; + + if zero1 || zero2 || p1 == -p2 { + // These are trivial cases where we won't need any inverse. + } else if p1 == p2 { + elements_to_invert.push(y1.double()); + } else { + elements_to_invert.push(x1 - x2); + } + } + } + + let inverses: Vec = + C::BaseField::batch_multiplicative_inverse(&elements_to_invert); + + let mut all_reduced_points = Vec::with_capacity(summations.len()); + let mut inverse_index = 0; + for summation in summations { + let n = summation.len(); + let mut reduced_points = Vec::with_capacity((n + 1) / 2); + + // The special case for n=0 is to avoid underflow. + let range_end = if n == 0 { 0 } else { n - 1 }; + + for i in (0..range_end).step_by(2) { + let p1 = summation[i]; + let p2 = summation[i + 1]; + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = p1; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = p2; + + let sum = if zero1 { + p2 + } else if zero2 { + p1 + } else if p1 == -p2 { + AffinePoint::ZERO + } else { + // It's a non-trivial case where we need one of the inverses we computed earlier. + let inverse = inverses[inverse_index]; + inverse_index += 1; + + if p1 == p2 { + // This is the doubling case. + let mut numerator = x1.square().triple(); + if C::A.is_nonzero() { + numerator = numerator + C::A; + } + let quotient = numerator * inverse; + let x3 = quotient.square() - x1.double(); + let y3 = quotient * (x1 - x3) - y1; + AffinePoint::nonzero(x3, y3) + } else { + // This is the general case. We use the incomplete addition formulas 4.3 and 4.4. + let quotient = (y1 - y2) * inverse; + let x3 = quotient.square() - x1 - x2; + let y3 = quotient * (x1 - x3) - y1; + AffinePoint::nonzero(x3, y3) + } + }; + reduced_points.push(sum); + } + + // If n is odd, the last point was not part of a pair. + if n % 2 == 1 { + reduced_points.push(summation[n - 1]); + } + + all_reduced_points.push(reduced_points); + } + + // We should have consumed all of the inverses from the batch computation. + debug_assert_eq!(inverse_index, inverses.len()); + + // Recurse with our smaller set of points. + affine_multisummation_best(all_reduced_points) +} + +#[cfg(test)] +mod tests { + use crate::{ + affine_summation_batch_inversion, affine_summation_pairwise, Bls12377, Curve, + ProjectivePoint, + }; + + #[test] + fn test_pairwise_affine_summation() { + let g_affine = Bls12377::GENERATOR_AFFINE; + let g2_affine = (g_affine + g_affine).to_affine(); + let g3_affine = (g_affine + g_affine + g_affine).to_affine(); + let g2_proj = g2_affine.to_projective(); + let g3_proj = g3_affine.to_projective(); + assert_eq!( + affine_summation_pairwise::(vec![g_affine, g_affine]), + g2_proj + ); + assert_eq!( + affine_summation_pairwise::(vec![g_affine, g2_affine]), + g3_proj + ); + assert_eq!( + affine_summation_pairwise::(vec![g_affine, g_affine, g_affine]), + g3_proj + ); + assert_eq!( + affine_summation_pairwise::(vec![]), + ProjectivePoint::ZERO + ); + } + + #[test] + fn test_pairwise_affine_summation_batch_inversion() { + let g = Bls12377::GENERATOR_AFFINE; + let g_proj = g.to_projective(); + assert_eq!( + affine_summation_batch_inversion::(vec![g, g]), + g_proj + g_proj + ); + assert_eq!( + affine_summation_batch_inversion::(vec![g, g, g]), + g_proj + g_proj + g_proj + ); + assert_eq!( + affine_summation_batch_inversion::(vec![]), + ProjectivePoint::ZERO + ); + } +} diff --git a/src/curve/curve_types.rs b/src/curve/curve_types.rs index f7d55c0c..830dc7c1 100644 --- a/src/curve/curve_types.rs +++ b/src/curve/curve_types.rs @@ -1,9 +1,9 @@ +use std::fmt::Debug; use std::ops::Neg; use anyhow::Result; use crate::field::field_types::Field; -use std::fmt::Debug; // To avoid implementation conflicts from associated types, // see https://github.com/rust-lang/rust/issues/20400 @@ -54,9 +54,10 @@ pub trait Curve: 'static + Sync + Sized + Copy + Debug { Ok(res) }*/ - fn is_safe_curve() -> bool{ + fn is_safe_curve() -> bool { // Added additional check to prevent using vulnerabilties in case a discriminant is equal to 0. - (Self::A.cube().double().double() + Self::B.square().triple().triple().triple()).is_nonzero() + (Self::A.cube().double().double() + Self::B.square().triple().triple().triple()) + .is_nonzero() } } @@ -101,11 +102,7 @@ impl AffinePoint { } pub fn double(&self) -> Self { - let AffinePoint { - x: x1, - y: y1, - zero, - } = *self; + let AffinePoint { x: x1, y: y1, zero } = *self; if zero { return AffinePoint::ZERO; @@ -124,7 +121,6 @@ impl AffinePoint { zero: false, } } - } impl PartialEq for AffinePoint { diff --git a/src/curve/mod.rs b/src/curve/mod.rs index 1e536564..8b9df88e 100644 --- a/src/curve/mod.rs +++ b/src/curve/mod.rs @@ -1,2 +1,4 @@ pub mod curve_adds; +pub mod curve_multiplication; +pub mod curve_summation; pub mod curve_types; diff --git a/src/field/extension_field/quadratic.rs b/src/field/extension_field/quadratic.rs index e2794330..b724095a 100644 --- a/src/field/extension_field/quadratic.rs +++ b/src/field/extension_field/quadratic.rs @@ -67,6 +67,8 @@ impl> Field for QuadraticExtension { const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(F::EXT_MULTIPLICATIVE_GROUP_GENERATOR); const POWER_OF_TWO_GENERATOR: Self = Self(F::EXT_POWER_OF_TWO_GENERATOR); + const BITS: usize = F::BITS * 2; + fn order() -> BigUint { F::order() * F::order() } diff --git a/src/field/extension_field/quartic.rs b/src/field/extension_field/quartic.rs index 01918ff3..0d221401 100644 --- a/src/field/extension_field/quartic.rs +++ b/src/field/extension_field/quartic.rs @@ -69,6 +69,8 @@ impl> Field for QuarticExtension { const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(F::EXT_MULTIPLICATIVE_GROUP_GENERATOR); const POWER_OF_TWO_GENERATOR: Self = Self(F::EXT_POWER_OF_TWO_GENERATOR); + const BITS: usize = F::BITS * 4; + fn order() -> BigUint { F::order().pow(4u32) } diff --git a/src/field/field_types.rs b/src/field/field_types.rs index 250338fb..80eeecff 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -59,6 +59,8 @@ pub trait Field: /// Generator of a multiplicative subgroup of order `2^TWO_ADICITY`. const POWER_OF_TWO_GENERATOR: Self; + const BITS: usize; + fn order() -> BigUint; #[inline] diff --git a/src/field/goldilocks_field.rs b/src/field/goldilocks_field.rs index cb85d56d..058b6db8 100644 --- a/src/field/goldilocks_field.rs +++ b/src/field/goldilocks_field.rs @@ -82,6 +82,8 @@ impl Field for GoldilocksField { // ``` const POWER_OF_TWO_GENERATOR: Self = Self(1753635133440165772); + const BITS: usize = 64; + fn order() -> BigUint { Self::ORDER.into() } diff --git a/src/field/secp256k1.rs b/src/field/secp256k1.rs index 75221a1f..acb1df4e 100644 --- a/src/field/secp256k1.rs +++ b/src/field/secp256k1.rs @@ -91,6 +91,8 @@ impl Field for Secp256K1Base { // Sage: `g_2 = g^((p - 1) / 2)` const POWER_OF_TWO_GENERATOR: Self = Self::NEG_ONE; + const BITS: usize = 256; + fn order() -> BigUint { BigUint::from_slice(&[ 0xFFFFFC2F, 0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, From 869a5860f424268b3b7368d98944acb1b638d003 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 28 Oct 2021 19:54:39 -0700 Subject: [PATCH 147/202] Secp256K1 scalar field --- src/curve/mod.rs | 1 + src/curve/secp256k1_curve.rs | 81 ++++++ src/field/mod.rs | 3 +- src/field/{secp256k1.rs => secp256k1_base.rs} | 2 +- src/field/secp256k1_scalar.rs | 253 ++++++++++++++++++ 5 files changed, 338 insertions(+), 2 deletions(-) create mode 100644 src/curve/secp256k1_curve.rs rename src/field/{secp256k1.rs => secp256k1_base.rs} (99%) create mode 100644 src/field/secp256k1_scalar.rs diff --git a/src/curve/mod.rs b/src/curve/mod.rs index 8b9df88e..01841018 100644 --- a/src/curve/mod.rs +++ b/src/curve/mod.rs @@ -2,3 +2,4 @@ pub mod curve_adds; pub mod curve_multiplication; pub mod curve_summation; pub mod curve_types; +//pub mod secp256k1_curve; \ No newline at end of file diff --git a/src/curve/secp256k1_curve.rs b/src/curve/secp256k1_curve.rs new file mode 100644 index 00000000..78ce993e --- /dev/null +++ b/src/curve/secp256k1_curve.rs @@ -0,0 +1,81 @@ +use crate::curve::curve_types::{AffinePoint, Curve}; +use crate::field::field_types::Field; +use crate::field::secp256k1_base::Secp256K1Base; +use crate::field::secp256k1_scalar::Secp256K1Scalar; + +// Parameters taken from the implementation of Bls12-377 in Zexe found here: +// https://github.com/scipr-lab/zexe/blob/master/algebra/src/curves/bls12_377/g1.rs + +#[derive(Debug, Copy, Clone)] +pub struct Secp256K1; + +impl Curve for Bls12377 { + type BaseField = Bls12377Base; + type ScalarField = Bls12377Scalar; + + const A: Bls12377Base = Bls12377Base::ZERO; + const B: Bls12377Base = Bls12377Base::ONE; + const GENERATOR_AFFINE: AffinePoint = AffinePoint { + x: BLS12_377_GENERATOR_X, + y: BLS12_377_GENERATOR_Y, + zero: false, + }; +} + +/// 81937999373150964239938255573465948239988671502647976594219695644855304257327692006745978603320413799295628339695 +const BLS12_377_GENERATOR_X: Bls12377Base = Bls12377Base { + limbs: [2742467569752756724, 14217256487979144792, 6635299530028159197, 8509097278468658840, + 14518893593143693938, 46181716169194829] +}; + +/// 241266749859715473739788878240585681733927191168601896383759122102112907357779751001206799952863815012735208165030 +const BLS12_377_GENERATOR_Y: Bls12377Base = Bls12377Base { + limbs: [9336971515457667571, 28021381849722296, 18085035374859187530, 14013031479170682136, + 3369780711397861396, 35370409237953649] +}; + +#[cfg(test)] +mod tests { + use crate::{blake_hash_usize_to_curve, Bls12377, Bls12377Scalar, Curve, Field, ProjectivePoint}; + + #[test] + fn test_double_affine() { + for i in 0..100 { + let p = blake_hash_usize_to_curve::(i); + assert_eq!( + p.double(), + p.to_projective().double().to_affine()); + } + } + + #[test] + fn test_naive_multiplication() { + let g = Bls12377::GENERATOR_PROJECTIVE; + let ten = Bls12377Scalar::from_canonical_u64(10); + let product = mul_naive(ten, g); + let sum = g + g + g + g + g + g + g + g + g + g; + assert_eq!(product, sum); + } + + #[test] + fn test_g1_multiplication() { + let lhs = Bls12377Scalar::from_canonical([11111111, 22222222, 33333333, 44444444]); + assert_eq!(Bls12377::convert(lhs) * Bls12377::GENERATOR_PROJECTIVE, mul_naive(lhs, Bls12377::GENERATOR_PROJECTIVE)); + } + + /// A simple, somewhat inefficient implementation of multiplication which is used as a reference + /// for correctness. + fn mul_naive(lhs: Bls12377Scalar, rhs: ProjectivePoint) -> ProjectivePoint { + let mut g = rhs; + let mut sum = ProjectivePoint::ZERO; + for limb in lhs.to_canonical().iter() { + for j in 0..64 { + if (limb >> j & 1u64) != 0u64 { + sum = sum + g; + } + g = g.double(); + } + } + sum + } +} diff --git a/src/field/mod.rs b/src/field/mod.rs index 5ed64a54..74e0fbf4 100644 --- a/src/field/mod.rs +++ b/src/field/mod.rs @@ -7,7 +7,8 @@ pub(crate) mod interpolation; mod inversion; pub(crate) mod packable; pub(crate) mod packed_field; -pub mod secp256k1; +pub mod secp256k1_base; +pub mod secp256k1_scalar; #[cfg(target_feature = "avx2")] pub(crate) mod packed_avx2; diff --git a/src/field/secp256k1.rs b/src/field/secp256k1_base.rs similarity index 99% rename from src/field/secp256k1.rs rename to src/field/secp256k1_base.rs index acb1df4e..a09edc30 100644 --- a/src/field/secp256k1.rs +++ b/src/field/secp256k1_base.rs @@ -88,7 +88,7 @@ impl Field for Secp256K1Base { // Sage: `g = GF(p).multiplicative_generator()` const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([5, 0, 0, 0]); - // Sage: `g_2 = g^((p - 1) / 2)` + // Sage: `g_2 = power_mod(g, (p - 1) // 2), p)` const POWER_OF_TWO_GENERATOR: Self = Self::NEG_ONE; const BITS: usize = 256; diff --git a/src/field/secp256k1_scalar.rs b/src/field/secp256k1_scalar.rs new file mode 100644 index 00000000..4423f726 --- /dev/null +++ b/src/field/secp256k1_scalar.rs @@ -0,0 +1,253 @@ +use std::convert::TryInto; +use std::fmt; +use std::fmt::{Debug, Display, Formatter}; +use std::hash::{Hash, Hasher}; +use std::iter::{Product, Sum}; +use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +use itertools::Itertools; +use num::bigint::{BigUint, RandBigInt}; +use num::{Integer, One}; +use rand::Rng; +use serde::{Deserialize, Serialize}; + +use crate::field::field_types::Field; +use crate::field::goldilocks_field::GoldilocksField; + +/// The base field of the secp256k1 elliptic curve. +/// +/// Its order is +/// ```ignore +/// P = 0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE BAAEDCE6 AF48A03B BFD25E8C D0364141 +/// = 115792089237316195423570985008687907852837564279074904382605163141518161494337 +/// = 2**256 - 432420386565659656852420866394968145599 +/// ``` +#[derive(Copy, Clone, Serialize, Deserialize)] +pub struct Secp256K1Scalar(pub [u64; 4]); + +fn biguint_from_array(arr: [u64; 4]) -> BigUint { + BigUint::from_slice(&[ + arr[0] as u32, + (arr[0] >> 32) as u32, + arr[1] as u32, + (arr[1] >> 32) as u32, + arr[2] as u32, + (arr[2] >> 32) as u32, + arr[3] as u32, + (arr[3] >> 32) as u32, + ]) +} + +impl Default for Secp256K1Scalar { + fn default() -> Self { + Self::ZERO + } +} + +impl PartialEq for Secp256K1Scalar { + fn eq(&self, other: &Self) -> bool { + self.to_biguint() == other.to_biguint() + } +} + +impl Eq for Secp256K1Scalar {} + +impl Hash for Secp256K1Scalar { + fn hash(&self, state: &mut H) { + self.to_biguint().hash(state) + } +} + +impl Display for Secp256K1Scalar { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(&self.to_biguint(), f) + } +} + +impl Debug for Secp256K1Scalar { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Debug::fmt(&self.to_biguint(), f) + } +} + +impl Field for Secp256K1Scalar { + // TODO: fix + type PrimeField = GoldilocksField; + + const ZERO: Self = Self([0; 4]); + const ONE: Self = Self([1, 0, 0, 0]); + const TWO: Self = Self([2, 0, 0, 0]); + const NEG_ONE: Self = Self([ + 0xBFD25E8CD0364140, + 0xBAAEDCE6AF48A03B, + 0xFFFFFFFFFFFFFC2F, + 0xFFFFFFFFFFFFFFFF + ]); + + // TODO: fix + const CHARACTERISTIC: u64 = 0; + + const TWO_ADICITY: usize = 6; + + // Sage: `g = GF(p).multiplicative_generator()` + const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([7, 0, 0, 0]); + + // Sage: `g_2 = power_mod(g, (p - 1) // 2^6), p)` + // 5480320495727936603795231718619559942670027629901634955707709633242980176626 + const POWER_OF_TWO_GENERATOR: Self = Self([ + 0x992f4b5402b052f2, + 0x98BDEAB680756045, + 0xDF9879A3FBC483A8, + 0xC1DC060E7A91986, + ]); + + const BITS: usize = 256; + + fn order() -> BigUint { + BigUint::from_slice(&[ + 0xD0364141, 0xBFD25E8C, 0xAF48A03B, 0xBAAEDCE6, 0xFFFFFC2F, 0xFFFFFFFF, 0xFFFFFFFF, + 0xFFFFFFFF + ]) + } + + fn try_inverse(&self) -> Option { + if self.is_zero() { + return None; + } + + // Fermat's Little Theorem + Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one()))) + } + + fn to_biguint(&self) -> BigUint { + let mut result = biguint_from_array(self.0); + if result >= Self::order() { + result -= Self::order(); + } + result + } + + fn from_biguint(val: BigUint) -> Self { + Self( + val.to_u64_digits() + .into_iter() + .pad_using(4, |_| 0) + .collect::>()[..] + .try_into() + .expect("error converting to u64 array"), + ) + } + + #[inline] + fn from_canonical_u64(n: u64) -> Self { + Self([n, 0, 0, 0]) + } + + #[inline] + fn from_noncanonical_u128(n: u128) -> Self { + Self([n as u64, (n >> 64) as u64, 0, 0]) + } + + #[inline] + fn from_noncanonical_u96(n: (u64, u32)) -> Self { + Self([n.0, n.1 as u64, 0, 0]) + } + + fn rand_from_rng(rng: &mut R) -> Self { + Self::from_biguint(rng.gen_biguint_below(&Self::order())) + } +} + +impl Neg for Secp256K1Scalar { + type Output = Self; + + #[inline] + fn neg(self) -> Self { + if self.is_zero() { + Self::ZERO + } else { + Self::from_biguint(Self::order() - self.to_biguint()) + } + } +} + +impl Add for Secp256K1Scalar { + type Output = Self; + + #[inline] + fn add(self, rhs: Self) -> Self { + let mut result = self.to_biguint() + rhs.to_biguint(); + if result >= Self::order() { + result -= Self::order(); + } + Self::from_biguint(result) + } +} + +impl AddAssign for Secp256K1Scalar { + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl Sum for Secp256K1Scalar { + fn sum>(iter: I) -> Self { + iter.fold(Self::ZERO, |acc, x| acc + x) + } +} + +impl Sub for Secp256K1Scalar { + type Output = Self; + + #[inline] + #[allow(clippy::suspicious_arithmetic_impl)] + fn sub(self, rhs: Self) -> Self { + self + -rhs + } +} + +impl SubAssign for Secp256K1Scalar { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl Mul for Secp256K1Scalar { + type Output = Self; + + #[inline] + fn mul(self, rhs: Self) -> Self { + Self::from_biguint((self.to_biguint() * rhs.to_biguint()).mod_floor(&Self::order())) + } +} + +impl MulAssign for Secp256K1Scalar { + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl Product for Secp256K1Scalar { + #[inline] + fn product>(iter: I) -> Self { + iter.reduce(|acc, x| acc * x).unwrap_or(Self::ONE) + } +} + +impl Div for Secp256K1Scalar { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, rhs: Self) -> Self::Output { + self * rhs.inverse() + } +} + +impl DivAssign for Secp256K1Scalar { + fn div_assign(&mut self, rhs: Self) { + *self = *self / rhs; + } +} From 50db118718ac8bd7fa2f05e417cdb1786593444a Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 28 Oct 2021 20:10:19 -0700 Subject: [PATCH 148/202] Secp256K1 curve (in progress) --- src/curve/mod.rs | 2 +- src/curve/secp256k1_curve.rs | 56 ++++++++++++++++++++---------------- 2 files changed, 32 insertions(+), 26 deletions(-) diff --git a/src/curve/mod.rs b/src/curve/mod.rs index 01841018..e1bcb291 100644 --- a/src/curve/mod.rs +++ b/src/curve/mod.rs @@ -2,4 +2,4 @@ pub mod curve_adds; pub mod curve_multiplication; pub mod curve_summation; pub mod curve_types; -//pub mod secp256k1_curve; \ No newline at end of file +pub mod secp256k1_curve; \ No newline at end of file diff --git a/src/curve/secp256k1_curve.rs b/src/curve/secp256k1_curve.rs index 78ce993e..4e74a5f7 100644 --- a/src/curve/secp256k1_curve.rs +++ b/src/curve/secp256k1_curve.rs @@ -9,39 +9,45 @@ use crate::field::secp256k1_scalar::Secp256K1Scalar; #[derive(Debug, Copy, Clone)] pub struct Secp256K1; -impl Curve for Bls12377 { - type BaseField = Bls12377Base; - type ScalarField = Bls12377Scalar; +impl Curve for Secp256K1 { + type BaseField = Secp256K1Base; + type ScalarField = Secp256K1Scalar; - const A: Bls12377Base = Bls12377Base::ZERO; - const B: Bls12377Base = Bls12377Base::ONE; + const A: Secp256K1Base = Secp256K1Base::ZERO; + const B: Secp256K1Base = Secp256K1Base::ONE; const GENERATOR_AFFINE: AffinePoint = AffinePoint { - x: BLS12_377_GENERATOR_X, - y: BLS12_377_GENERATOR_Y, + x: SECP256K1_GENERATOR_X, + y: SECP256K1_GENERATOR_Y, zero: false, }; } -/// 81937999373150964239938255573465948239988671502647976594219695644855304257327692006745978603320413799295628339695 -const BLS12_377_GENERATOR_X: Bls12377Base = Bls12377Base { - limbs: [2742467569752756724, 14217256487979144792, 6635299530028159197, 8509097278468658840, - 14518893593143693938, 46181716169194829] -}; +const SECP256K1_GENERATOR_X: Secp256K1Base = Secp256K1Base([ + 0x59F2815B16F81798, + 0x029BFCDB2DCE28D9, + 0x55A06295CE870B07, + 0x79BE667EF9DCBBAC, +]); /// 241266749859715473739788878240585681733927191168601896383759122102112907357779751001206799952863815012735208165030 -const BLS12_377_GENERATOR_Y: Bls12377Base = Bls12377Base { - limbs: [9336971515457667571, 28021381849722296, 18085035374859187530, 14013031479170682136, - 3369780711397861396, 35370409237953649] -}; +const SECP256K1_GENERATOR_Y: Secp256K1Base = Secp256K1Base([ + 0x9C47D08FFB10D4B8, + 0xFD17B448A6855419, + 0x5DA4FBFC0E1108A8, + 0x483ADA7726A3C465, +]); #[cfg(test)] mod tests { - use crate::{blake_hash_usize_to_curve, Bls12377, Bls12377Scalar, Curve, Field, ProjectivePoint}; + use crate::field::field_types::Field; + use crate::field::secp256k1_scalar::Secp256K1Scalar; + use crate::curve::curve_types::{Curve, ProjectivePoint}; + use crate::curve::secp256k1_curve::Secp256K1; - #[test] + /*#[test] fn test_double_affine() { for i in 0..100 { - let p = blake_hash_usize_to_curve::(i); + //let p = blake_hash_usize_to_curve::(i); assert_eq!( p.double(), p.to_projective().double().to_affine()); @@ -50,8 +56,8 @@ mod tests { #[test] fn test_naive_multiplication() { - let g = Bls12377::GENERATOR_PROJECTIVE; - let ten = Bls12377Scalar::from_canonical_u64(10); + let g = Secp256K1::GENERATOR_PROJECTIVE; + let ten = Secp256K1Scalar::from_canonical_u64(10); let product = mul_naive(ten, g); let sum = g + g + g + g + g + g + g + g + g + g; assert_eq!(product, sum); @@ -59,13 +65,13 @@ mod tests { #[test] fn test_g1_multiplication() { - let lhs = Bls12377Scalar::from_canonical([11111111, 22222222, 33333333, 44444444]); - assert_eq!(Bls12377::convert(lhs) * Bls12377::GENERATOR_PROJECTIVE, mul_naive(lhs, Bls12377::GENERATOR_PROJECTIVE)); + let lhs = Secp256K1Scalar::from_canonical([11111111, 22222222, 33333333, 44444444]); + assert_eq!(Secp256K1::convert(lhs) * Secp256K1::GENERATOR_PROJECTIVE, mul_naive(lhs, Secp256K1::GENERATOR_PROJECTIVE)); } /// A simple, somewhat inefficient implementation of multiplication which is used as a reference /// for correctness. - fn mul_naive(lhs: Bls12377Scalar, rhs: ProjectivePoint) -> ProjectivePoint { + fn mul_naive(lhs: Secp256K1Scalar, rhs: ProjectivePoint) -> ProjectivePoint { let mut g = rhs; let mut sum = ProjectivePoint::ZERO; for limb in lhs.to_canonical().iter() { @@ -77,5 +83,5 @@ mod tests { } } sum - } + }*/ } From 2c2d36a6be23fed25af854259c78e2c1b11f039f Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 11:49:50 -0800 Subject: [PATCH 149/202] merge --- src/curve/mod.rs | 2 +- src/curve/secp256k1_curve.rs | 6 +++--- src/field/secp256k1_scalar.rs | 4 ++-- src/gadgets/mod.rs | 1 + src/gadgets/secp256k1.rs | 32 ++++++++++++++++++++++++++++++++ 5 files changed, 39 insertions(+), 6 deletions(-) create mode 100644 src/gadgets/secp256k1.rs diff --git a/src/curve/mod.rs b/src/curve/mod.rs index e1bcb291..c65f2acd 100644 --- a/src/curve/mod.rs +++ b/src/curve/mod.rs @@ -2,4 +2,4 @@ pub mod curve_adds; pub mod curve_multiplication; pub mod curve_summation; pub mod curve_types; -pub mod secp256k1_curve; \ No newline at end of file +pub mod secp256k1_curve; diff --git a/src/curve/secp256k1_curve.rs b/src/curve/secp256k1_curve.rs index 4e74a5f7..21340c64 100644 --- a/src/curve/secp256k1_curve.rs +++ b/src/curve/secp256k1_curve.rs @@ -22,7 +22,7 @@ impl Curve for Secp256K1 { }; } -const SECP256K1_GENERATOR_X: Secp256K1Base = Secp256K1Base([ +const SECP256K1_GENERATOR_X: Secp256K1Base = Secp256K1Base([ 0x59F2815B16F81798, 0x029BFCDB2DCE28D9, 0x55A06295CE870B07, @@ -39,10 +39,10 @@ const SECP256K1_GENERATOR_Y: Secp256K1Base = Secp256K1Base([ #[cfg(test)] mod tests { - use crate::field::field_types::Field; - use crate::field::secp256k1_scalar::Secp256K1Scalar; use crate::curve::curve_types::{Curve, ProjectivePoint}; use crate::curve::secp256k1_curve::Secp256K1; + use crate::field::field_types::Field; + use crate::field::secp256k1_scalar::Secp256K1Scalar; /*#[test] fn test_double_affine() { diff --git a/src/field/secp256k1_scalar.rs b/src/field/secp256k1_scalar.rs index 4423f726..0c406b86 100644 --- a/src/field/secp256k1_scalar.rs +++ b/src/field/secp256k1_scalar.rs @@ -81,7 +81,7 @@ impl Field for Secp256K1Scalar { 0xBFD25E8CD0364140, 0xBAAEDCE6AF48A03B, 0xFFFFFFFFFFFFFC2F, - 0xFFFFFFFFFFFFFFFF + 0xFFFFFFFFFFFFFFFF, ]); // TODO: fix @@ -106,7 +106,7 @@ impl Field for Secp256K1Scalar { fn order() -> BigUint { BigUint::from_slice(&[ 0xD0364141, 0xBFD25E8C, 0xAF48A03B, 0xBAAEDCE6, 0xFFFFFC2F, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF + 0xFFFFFFFF, ]) } diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index 8b6e60f6..42b3044c 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -11,6 +11,7 @@ pub mod permutation; pub mod polynomial; pub mod random_access; pub mod range_check; +pub mod secp256k1; pub mod select; pub mod sorting; pub mod split_base; diff --git a/src/gadgets/secp256k1.rs b/src/gadgets/secp256k1.rs new file mode 100644 index 00000000..36d8d145 --- /dev/null +++ b/src/gadgets/secp256k1.rs @@ -0,0 +1,32 @@ +use crate::curve::curve_types::{AffinePoint, Curve}; +use crate::field::extension_field::Extendable; +use crate::field::field_types::RichField; +use crate::gadgets::nonnative::ForeignFieldTarget; +use crate::plonk::circuit_builder::CircuitBuilder; + +#[derive(Clone, Debug)] +pub struct AffinePointTarget { + pub x: ForeignFieldTarget, + pub y: ForeignFieldTarget, +} + +impl AffinePointTarget { + pub fn to_vec(&self) -> Vec> { + vec![self.x.clone(), self.y.clone()] + } +} + +impl, const D: usize> CircuitBuilder { + pub fn constant_affine_point>( + &mut self, + point: AffinePoint, + ) -> AffinePointTarget { + debug_assert!(!point.zero); + AffinePointTarget { + x: self.constant_ff(point.x), + y: self.constant_ff(point.y), + } + } +} + +mod tests {} From 0e1f0c556293b66fd647290e59dc154f5ca5aaa7 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 11:50:04 -0800 Subject: [PATCH 150/202] merge --- src/curve/curve_summation.rs | 25 ++++++++++++------------ src/gadgets/secp256k1.rs | 38 +++++++++++++++++++++++++++++------- 2 files changed, 44 insertions(+), 19 deletions(-) diff --git a/src/curve/curve_summation.rs b/src/curve/curve_summation.rs index 501a4977..ad4232ce 100644 --- a/src/curve/curve_summation.rs +++ b/src/curve/curve_summation.rs @@ -186,50 +186,51 @@ pub fn affine_multisummation_batch_inversion( #[cfg(test)] mod tests { - use crate::{ - affine_summation_batch_inversion, affine_summation_pairwise, Bls12377, Curve, - ProjectivePoint, + use crate::curve::curve_summation::{ + affine_summation_batch_inversion, affine_summation_pairwise, }; + use crate::curve::curve_types::{Curve, ProjectivePoint}; + use crate::curve::secp256k1_curve::Secp256K1; #[test] fn test_pairwise_affine_summation() { - let g_affine = Bls12377::GENERATOR_AFFINE; + let g_affine = Secp256K1::GENERATOR_AFFINE; let g2_affine = (g_affine + g_affine).to_affine(); let g3_affine = (g_affine + g_affine + g_affine).to_affine(); let g2_proj = g2_affine.to_projective(); let g3_proj = g3_affine.to_projective(); assert_eq!( - affine_summation_pairwise::(vec![g_affine, g_affine]), + affine_summation_pairwise::(vec![g_affine, g_affine]), g2_proj ); assert_eq!( - affine_summation_pairwise::(vec![g_affine, g2_affine]), + affine_summation_pairwise::(vec![g_affine, g2_affine]), g3_proj ); assert_eq!( - affine_summation_pairwise::(vec![g_affine, g_affine, g_affine]), + affine_summation_pairwise::(vec![g_affine, g_affine, g_affine]), g3_proj ); assert_eq!( - affine_summation_pairwise::(vec![]), + affine_summation_pairwise::(vec![]), ProjectivePoint::ZERO ); } #[test] fn test_pairwise_affine_summation_batch_inversion() { - let g = Bls12377::GENERATOR_AFFINE; + let g = Secp256K1::GENERATOR_AFFINE; let g_proj = g.to_projective(); assert_eq!( - affine_summation_batch_inversion::(vec![g, g]), + affine_summation_batch_inversion::(vec![g, g]), g_proj + g_proj ); assert_eq!( - affine_summation_batch_inversion::(vec![g, g, g]), + affine_summation_batch_inversion::(vec![g, g, g]), g_proj + g_proj + g_proj ); assert_eq!( - affine_summation_batch_inversion::(vec![]), + affine_summation_batch_inversion::(vec![]), ProjectivePoint::ZERO ); } diff --git a/src/gadgets/secp256k1.rs b/src/gadgets/secp256k1.rs index 36d8d145..3294a954 100644 --- a/src/gadgets/secp256k1.rs +++ b/src/gadgets/secp256k1.rs @@ -4,29 +4,53 @@ use crate::field::field_types::RichField; use crate::gadgets::nonnative::ForeignFieldTarget; use crate::plonk::circuit_builder::CircuitBuilder; +/// A Target representing an affine point on the curve `C`. #[derive(Clone, Debug)] pub struct AffinePointTarget { - pub x: ForeignFieldTarget, - pub y: ForeignFieldTarget, + pub x: ForeignFieldTarget, + pub y: ForeignFieldTarget, } impl AffinePointTarget { - pub fn to_vec(&self) -> Vec> { + pub fn to_vec(&self) -> Vec> { vec![self.x.clone(), self.y.clone()] } } impl, const D: usize> CircuitBuilder { - pub fn constant_affine_point>( + pub fn constant_affine_point( &mut self, - point: AffinePoint, + point: AffinePoint, ) -> AffinePointTarget { debug_assert!(!point.zero); AffinePointTarget { - x: self.constant_ff(point.x), - y: self.constant_ff(point.y), + x: self.constant_nonnative(point.x), + y: self.constant_nonnative(point.y), } } + + pub fn connect_affine_point( + &mut self, + lhs: AffinePointTarget, + rhs: AffinePointTarget, + ) { + self.connect_nonnative(&lhs.x, &rhs.x); + self.connect_nonnative(&lhs.y, &rhs.y); + } + + pub fn curve_assert_valid(&mut self, p: AffinePointTarget) { + let a = self.constant_nonnative(C::A); + let b = self.constant_nonnative(C::B); + + let y_squared = self.mul_nonnative(&p.y, &p.y); + let x_squared = self.mul_nonnative(&p.x, &p.x); + let x_cubed = self.mul_nonnative(&x_squared, &p.x); + let a_x = self.mul_nonnative(&a, &p.x); + let a_x_plus_b = self.add_nonnative(&a_x, &b); + let rhs = self.add_nonnative(&x_cubed, &a_x_plus_b); + + self.connect_nonnative(&y_squared, &rhs); + } } mod tests {} From d1ad3fdbad00aae07379b91e5a52ea2890c46c65 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 1 Nov 2021 16:53:48 -0700 Subject: [PATCH 151/202] fix: generator value --- src/curve/secp256k1_curve.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/curve/secp256k1_curve.rs b/src/curve/secp256k1_curve.rs index 21340c64..89497a7d 100644 --- a/src/curve/secp256k1_curve.rs +++ b/src/curve/secp256k1_curve.rs @@ -14,7 +14,7 @@ impl Curve for Secp256K1 { type ScalarField = Secp256K1Scalar; const A: Secp256K1Base = Secp256K1Base::ZERO; - const B: Secp256K1Base = Secp256K1Base::ONE; + const B: Secp256K1Base = Secp256K1Base([7, 0, 0, 0]); const GENERATOR_AFFINE: AffinePoint = AffinePoint { x: SECP256K1_GENERATOR_X, y: SECP256K1_GENERATOR_Y, From a5f21de0beade0f5b966b33151e7f3616e78b4d9 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 1 Nov 2021 17:52:52 -0700 Subject: [PATCH 152/202] fixed curve_summation tests --- src/curve/secp256k1_curve.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/curve/secp256k1_curve.rs b/src/curve/secp256k1_curve.rs index 89497a7d..1e721640 100644 --- a/src/curve/secp256k1_curve.rs +++ b/src/curve/secp256k1_curve.rs @@ -39,6 +39,8 @@ const SECP256K1_GENERATOR_Y: Secp256K1Base = Secp256K1Base([ #[cfg(test)] mod tests { + use num::BigUint; + use crate::curve::curve_types::{Curve, ProjectivePoint}; use crate::curve::secp256k1_curve::Secp256K1; use crate::field::field_types::Field; @@ -52,7 +54,7 @@ mod tests { p.double(), p.to_projective().double().to_affine()); } - } + }*/ #[test] fn test_naive_multiplication() { @@ -65,7 +67,9 @@ mod tests { #[test] fn test_g1_multiplication() { - let lhs = Secp256K1Scalar::from_canonical([11111111, 22222222, 33333333, 44444444]); + let lhs = Secp256K1Scalar::from_biguint( + BigUint::from_slice(&[1111, 2222, 3333, 4444, 5555, 6666, 7777, 8888]) + ); assert_eq!(Secp256K1::convert(lhs) * Secp256K1::GENERATOR_PROJECTIVE, mul_naive(lhs, Secp256K1::GENERATOR_PROJECTIVE)); } @@ -74,7 +78,7 @@ mod tests { fn mul_naive(lhs: Secp256K1Scalar, rhs: ProjectivePoint) -> ProjectivePoint { let mut g = rhs; let mut sum = ProjectivePoint::ZERO; - for limb in lhs.to_canonical().iter() { + for limb in lhs.to_biguint().to_u64_digits().iter() { for j in 0..64 { if (limb >> j & 1u64) != 0u64 { sum = sum + g; @@ -83,5 +87,5 @@ mod tests { } } sum - }*/ + } } From f11fe2a92889f0760903c4182a3755dbc6f3b776 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 1 Nov 2021 17:53:01 -0700 Subject: [PATCH 153/202] fmt --- src/curve/secp256k1_curve.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/curve/secp256k1_curve.rs b/src/curve/secp256k1_curve.rs index 1e721640..7b84855b 100644 --- a/src/curve/secp256k1_curve.rs +++ b/src/curve/secp256k1_curve.rs @@ -67,15 +67,21 @@ mod tests { #[test] fn test_g1_multiplication() { - let lhs = Secp256K1Scalar::from_biguint( - BigUint::from_slice(&[1111, 2222, 3333, 4444, 5555, 6666, 7777, 8888]) + let lhs = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ + 1111, 2222, 3333, 4444, 5555, 6666, 7777, 8888, + ])); + assert_eq!( + Secp256K1::convert(lhs) * Secp256K1::GENERATOR_PROJECTIVE, + mul_naive(lhs, Secp256K1::GENERATOR_PROJECTIVE) ); - assert_eq!(Secp256K1::convert(lhs) * Secp256K1::GENERATOR_PROJECTIVE, mul_naive(lhs, Secp256K1::GENERATOR_PROJECTIVE)); } /// A simple, somewhat inefficient implementation of multiplication which is used as a reference /// for correctness. - fn mul_naive(lhs: Secp256K1Scalar, rhs: ProjectivePoint) -> ProjectivePoint { + fn mul_naive( + lhs: Secp256K1Scalar, + rhs: ProjectivePoint, + ) -> ProjectivePoint { let mut g = rhs; let mut sum = ProjectivePoint::ZERO; for limb in lhs.to_biguint().to_u64_digits().iter() { From 0e6c5bb80c98ed601b86cbbe8d7b38c39227dd9d Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 2 Nov 2021 12:03:24 -0700 Subject: [PATCH 154/202] curve gadget changes --- src/curve/curve_summation.rs | 2 +- src/curve/mod.rs | 2 +- .../{secp256k1_curve.rs => secp256k1.rs} | 2 +- src/gadgets/{secp256k1.rs => curve.rs} | 42 ++++++++++++++++++- src/gadgets/mod.rs | 2 +- 5 files changed, 45 insertions(+), 5 deletions(-) rename src/curve/{secp256k1_curve.rs => secp256k1.rs} (98%) rename src/gadgets/{secp256k1.rs => curve.rs} (63%) diff --git a/src/curve/curve_summation.rs b/src/curve/curve_summation.rs index ad4232ce..8f347eda 100644 --- a/src/curve/curve_summation.rs +++ b/src/curve/curve_summation.rs @@ -190,7 +190,7 @@ mod tests { affine_summation_batch_inversion, affine_summation_pairwise, }; use crate::curve::curve_types::{Curve, ProjectivePoint}; - use crate::curve::secp256k1_curve::Secp256K1; + use crate::curve::secp256k1::Secp256K1; #[test] fn test_pairwise_affine_summation() { diff --git a/src/curve/mod.rs b/src/curve/mod.rs index c65f2acd..6555404e 100644 --- a/src/curve/mod.rs +++ b/src/curve/mod.rs @@ -2,4 +2,4 @@ pub mod curve_adds; pub mod curve_multiplication; pub mod curve_summation; pub mod curve_types; -pub mod secp256k1_curve; +pub mod secp256k1; diff --git a/src/curve/secp256k1_curve.rs b/src/curve/secp256k1.rs similarity index 98% rename from src/curve/secp256k1_curve.rs rename to src/curve/secp256k1.rs index 7b84855b..2fa476e1 100644 --- a/src/curve/secp256k1_curve.rs +++ b/src/curve/secp256k1.rs @@ -42,7 +42,7 @@ mod tests { use num::BigUint; use crate::curve::curve_types::{Curve, ProjectivePoint}; - use crate::curve::secp256k1_curve::Secp256K1; + use crate::curve::secp256k1::Secp256K1; use crate::field::field_types::Field; use crate::field::secp256k1_scalar::Secp256K1Scalar; diff --git a/src/gadgets/secp256k1.rs b/src/gadgets/curve.rs similarity index 63% rename from src/gadgets/secp256k1.rs rename to src/gadgets/curve.rs index 3294a954..83f73a3f 100644 --- a/src/gadgets/secp256k1.rs +++ b/src/gadgets/curve.rs @@ -51,6 +51,46 @@ impl, const D: usize> CircuitBuilder { self.connect_nonnative(&y_squared, &rhs); } + + pub fn curve_neg(&mut self, p: AffinePointTarget) { + let neg_y = self.neg_nonnative(p.y); + AffinePointTarget { + x: p.x, + y: neg_y, + } + } } -mod tests {} +mod tests { + use anyhow::Result; + + + + #[test] + fn test_curve_gadget_is_valid() -> Result<()> { + type F = CrandallField; + const D: usize = 4; + + let config = CircuitConfig::large_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let + + let lst: Vec = (0..size * 2).map(|n| F::from_canonical_usize(n)).collect(); + let a: Vec> = lst[..] + .chunks(2) + .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) + .collect(); + let mut b = a.clone(); + b.shuffle(&mut thread_rng()); + + builder.assert_permutation(a, b); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } +} diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index 42b3044c..2518e1ab 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -11,7 +11,7 @@ pub mod permutation; pub mod polynomial; pub mod random_access; pub mod range_check; -pub mod secp256k1; +pub mod curve; pub mod select; pub mod sorting; pub mod split_base; From 86573fc65c83f905a817a90e43399d84f5aff7ac Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 9 Nov 2021 17:51:04 -0800 Subject: [PATCH 155/202] resolve --- src/curve/secp256k1.rs | 15 ++++++++++- src/gadgets/curve.rs | 61 +++++++++++++++++++++++++++++------------- src/gadgets/mod.rs | 2 +- 3 files changed, 58 insertions(+), 20 deletions(-) diff --git a/src/curve/secp256k1.rs b/src/curve/secp256k1.rs index 2fa476e1..7102b5c9 100644 --- a/src/curve/secp256k1.rs +++ b/src/curve/secp256k1.rs @@ -41,11 +41,24 @@ const SECP256K1_GENERATOR_Y: Secp256K1Base = Secp256K1Base([ mod tests { use num::BigUint; - use crate::curve::curve_types::{Curve, ProjectivePoint}; + use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; use crate::curve::secp256k1::Secp256K1; use crate::field::field_types::Field; use crate::field::secp256k1_scalar::Secp256K1Scalar; + #[test] + fn test_generator() { + let g = Secp256K1::GENERATOR_AFFINE; + assert!(g.is_valid()); + + let neg_g = AffinePoint:: { + x: g.x, + y: -g.y, + zero: g.zero, + }; + assert!(neg_g.is_valid()); + } + /*#[test] fn test_double_affine() { for i in 0..100 { diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs index 83f73a3f..2c617b20 100644 --- a/src/gadgets/curve.rs +++ b/src/gadgets/curve.rs @@ -52,22 +52,27 @@ impl, const D: usize> CircuitBuilder { self.connect_nonnative(&y_squared, &rhs); } - pub fn curve_neg(&mut self, p: AffinePointTarget) { - let neg_y = self.neg_nonnative(p.y); - AffinePointTarget { - x: p.x, - y: neg_y, - } + pub fn curve_neg(&mut self, p: AffinePointTarget) -> AffinePointTarget { + let neg_y = self.neg_nonnative(&p.y); + AffinePointTarget { x: p.x, y: neg_y } } } mod tests { use anyhow::Result; - + use crate::curve::curve_types::{AffinePoint, Curve}; + use crate::curve::secp256k1::Secp256K1; + use crate::field::crandall_field::CrandallField; + use crate::field::field_types::Field; + use crate::field::secp256k1_base::Secp256K1Base; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::verifier::verify; #[test] - fn test_curve_gadget_is_valid() -> Result<()> { + fn test_curve_point_is_valid() -> Result<()> { type F = CrandallField; const D: usize = 4; @@ -76,21 +81,41 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let + let g = Secp256K1::GENERATOR_AFFINE; + let g_target = builder.constant_affine_point(g); - let lst: Vec = (0..size * 2).map(|n| F::from_canonical_usize(n)).collect(); - let a: Vec> = lst[..] - .chunks(2) - .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) - .collect(); - let mut b = a.clone(); - b.shuffle(&mut thread_rng()); - - builder.assert_permutation(a, b); + builder.curve_assert_valid(g_target); let data = builder.build(); let proof = data.prove(pw).unwrap(); verify(proof, &data.verifier_only, &data.common) } + + #[test] + #[should_panic] + fn test_curve_point_is_not_valid() { + type F = CrandallField; + const D: usize = 4; + + let config = CircuitConfig::large_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let not_g = AffinePoint:: { + x: g.x, + y: g.y + Secp256K1Base::ONE, + zero: g.zero, + }; + let g_target = builder.constant_affine_point(not_g); + + builder.curve_assert_valid(g_target); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common).unwrap(); + } } diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index 2518e1ab..09acb9de 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -2,6 +2,7 @@ pub mod arithmetic; pub mod arithmetic_extension; pub mod arithmetic_u32; pub mod biguint; +pub mod curve; pub mod hash; pub mod insert; pub mod interpolation; @@ -11,7 +12,6 @@ pub mod permutation; pub mod polynomial; pub mod random_access; pub mod range_check; -pub mod curve; pub mod select; pub mod sorting; pub mod split_base; From fa480854fec755afc61ccc12130fe375cb13dc8a Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 2 Nov 2021 15:04:53 -0700 Subject: [PATCH 156/202] updates --- src/gadgets/curve.rs | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs index 2c617b20..42bfd8a9 100644 --- a/src/gadgets/curve.rs +++ b/src/gadgets/curve.rs @@ -31,14 +31,14 @@ impl, const D: usize> CircuitBuilder { pub fn connect_affine_point( &mut self, - lhs: AffinePointTarget, - rhs: AffinePointTarget, + lhs: &AffinePointTarget, + rhs: &AffinePointTarget, ) { self.connect_nonnative(&lhs.x, &rhs.x); self.connect_nonnative(&lhs.y, &rhs.y); } - pub fn curve_assert_valid(&mut self, p: AffinePointTarget) { + pub fn curve_assert_valid(&mut self, p: &AffinePointTarget) { let a = self.constant_nonnative(C::A); let b = self.constant_nonnative(C::B); @@ -52,9 +52,12 @@ impl, const D: usize> CircuitBuilder { self.connect_nonnative(&y_squared, &rhs); } - pub fn curve_neg(&mut self, p: AffinePointTarget) -> AffinePointTarget { + pub fn curve_neg(&mut self, p: &AffinePointTarget) -> AffinePointTarget { let neg_y = self.neg_nonnative(&p.y); - AffinePointTarget { x: p.x, y: neg_y } + AffinePointTarget { + x: p.x.clone(), + y: neg_y, + } } } @@ -83,8 +86,10 @@ mod tests { let g = Secp256K1::GENERATOR_AFFINE; let g_target = builder.constant_affine_point(g); + let neg_g_target = builder.curve_neg(&g_target); - builder.curve_assert_valid(g_target); + builder.curve_assert_valid(&g_target); + builder.curve_assert_valid(&neg_g_target); let data = builder.build(); let proof = data.prove(pw).unwrap(); @@ -109,9 +114,9 @@ mod tests { y: g.y + Secp256K1Base::ONE, zero: g.zero, }; - let g_target = builder.constant_affine_point(not_g); + let not_g_target = builder.constant_affine_point(not_g); - builder.curve_assert_valid(g_target); + builder.curve_assert_valid(¬_g_target); let data = builder.build(); let proof = data.prove(pw).unwrap(); From 4d4605af1f31e80c9b933a7f75d09a5e11d84526 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 11:50:25 -0800 Subject: [PATCH 157/202] merge --- src/gadgets/curve.rs | 32 ++++++++++++++++++++++++ src/gadgets/nonnative.rs | 54 ++++++++++++++++++++++++++++++++++++++-- src/iop/generator.rs | 5 ++++ src/iop/witness.rs | 6 +++++ 4 files changed, 95 insertions(+), 2 deletions(-) diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs index 42bfd8a9..eeb966a5 100644 --- a/src/gadgets/curve.rs +++ b/src/gadgets/curve.rs @@ -59,6 +59,38 @@ impl, const D: usize> CircuitBuilder { y: neg_y, } } + + pub fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget { + let AffinePointTarget { x, y } = p; + let double_y = self.add_nonnative(y, y); + let inv_double_y = self.inv_nonnative(&double_y); + let x_squared = self.mul_nonnative(x, x); + let double_x_squared = self.add_nonnative(&x_squared, &x_squared); + let triple_x_squared = self.add_nonnative(&double_x_squared, &x_squared); + + let a = self.constant_nonnative(C::A); + let triple_xx_a = self.add_nonnative(&triple_x_squared, &a); + let lambda = self.mul_nonnative(&triple_xx_a, &inv_double_y); + let lambda_squared = self.mul_nonnative(&lambda, &lambda); + let x_double = self.add_nonnative(x, x); + + let x3 = self.sub_nonnative(&lambda_squared, &x_double); + + let x_diff = self.sub_nonnative(x, &x3); + let lambda_x_diff = self.mul_nonnative(&lambda, &x_diff); + + let y3 = self.sub_nonnative(&lambda_x_diff, y); + + AffinePointTarget { x: x3, y: y3 } + } + + pub fn curve_add( + &mut self, + a: &AffinePointTarget, + b: &AffinePointTarget, + ) -> AffinePointTarget { + todo!() + } } mod tests { diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index bc090cd5..d6b20f14 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -5,6 +5,9 @@ use num::{BigUint, One}; use crate::field::field_types::RichField; use crate::field::{extension_field::Extendable, field_types::Field}; use crate::gadgets::biguint::BigUintTarget; +use crate::iop::generator::{GeneratedValues, SimpleGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; pub struct NonNativeTarget { @@ -46,6 +49,7 @@ impl, const D: usize> CircuitBuilder { ) -> NonNativeTarget { let result = self.add_biguint(&a.value, &b.value); + // TODO: reduce add result with only one conditional subtraction self.reduce(&result) } @@ -84,8 +88,32 @@ impl, const D: usize> CircuitBuilder { self.mul_nonnative(&neg_one_ff, x) } - /// Returns `x % |FF|` as a `NonNativeTarget`. - fn reduce(&mut self, x: &BigUintTarget) -> NonNativeTarget { + pub fn inv_nonnative( + &mut self, + x: &ForeignFieldTarget, + ) -> ForeignFieldTarget { + let num_limbs = x.value.num_limbs(); + let inv_biguint = self.add_virtual_biguint_target(num_limbs); + let inv = ForeignFieldTarget:: { + value: inv_biguint, + _phantom: PhantomData, + }; + + self.add_simple_generator(NonNativeInverseGenerator:: { + x: x.clone(), + inv: inv.clone(), + _phantom: PhantomData, + }); + + let product = self.mul_nonnative(&x, &inv); + let one = self.constant_nonnative(FF::ONE); + self.connect_nonnative_reduced(&product, &one); + + inv + } + + /// Returns `x % |FF|` as a `ForeignFieldTarget`. + fn reduce(&mut self, x: &BigUintTarget) -> ForeignFieldTarget { let modulus = FF::order(); let order_target = self.constant_biguint(&modulus); let value = self.rem_biguint(x, &order_target); @@ -106,6 +134,28 @@ impl, const D: usize> CircuitBuilder { } } +#[derive(Debug)] +struct NonNativeInverseGenerator, const D: usize, FF: Field> { + x: ForeignFieldTarget, + inv: ForeignFieldTarget, + _phantom: PhantomData, +} + +impl, const D: usize, FF: Field> SimpleGenerator + for NonNativeInverseGenerator +{ + fn dependencies(&self) -> Vec { + self.x.value.limbs.iter().map(|&l| l.0).collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let x = witness.get_nonnative_target(self.x.clone()); + let inv = x.inverse(); + + out_buffer.set_nonnative_target(self.inv.clone(), inv); + } +} + #[cfg(test)] mod tests { use anyhow::Result; diff --git a/src/iop/generator.rs b/src/iop/generator.rs index ae973d7c..ff7f66e0 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -8,6 +8,7 @@ use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::{Field, RichField}; use crate::gadgets::arithmetic_u32::U32Target; use crate::gadgets::biguint::BigUintTarget; +use crate::gadgets::nonnative::ForeignFieldTarget; use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::iop::target::Target; use crate::iop::wire::Wire; @@ -168,6 +169,10 @@ impl GeneratedValues { } } + pub fn set_nonnative_target(&mut self, target: ForeignFieldTarget, value: FF) { + self.set_biguint_target(target.value, value.to_biguint()) + } + pub fn set_hash_target(&mut self, ht: HashOutTarget, value: HashOut) { ht.elements .iter() diff --git a/src/iop/witness.rs b/src/iop/witness.rs index a773c1a9..d6f4fb59 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -6,6 +6,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::Field; use crate::gadgets::biguint::BigUintTarget; +use crate::gadgets::nonnative::ForeignFieldTarget; use crate::hash::hash_types::HashOutTarget; use crate::hash::hash_types::{HashOut, MerkleCapTarget}; use crate::hash::merkle_tree::MerkleCap; @@ -68,6 +69,11 @@ pub trait Witness { result } + fn get_nonnative_target(&self, target: ForeignFieldTarget) -> FF { + let val = self.get_biguint_target(target.value); + FF::from_biguint(val) + } + fn get_hash_target(&self, ht: HashOutTarget) -> HashOut { HashOut { elements: self.get_targets(&ht.elements).try_into().unwrap(), From a4b7772c3456e38322e84ffd10621f3ba6a6bf01 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 9 Nov 2021 17:51:38 -0800 Subject: [PATCH 158/202] resolve --- src/gadgets/biguint.rs | 2 +- src/gadgets/curve.rs | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 05a5406e..fb7eb4e0 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -110,8 +110,8 @@ impl, const D: usize> CircuitBuilder { // Subtract two `BigUintTarget`s. We assume that the first is larger than the second. pub fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { - let num_limbs = a.limbs.len(); let (a, b) = self.pad_biguints(a, b); + let num_limbs = a.limbs.len(); let mut result_limbs = vec![]; diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs index eeb966a5..abb1b39a 100644 --- a/src/gadgets/curve.rs +++ b/src/gadgets/curve.rs @@ -155,4 +155,30 @@ mod tests { verify(proof, &data.verifier_only, &data.common).unwrap(); } + + #[test] + fn test_curve_double() -> Result<()> { + type F = CrandallField; + const D: usize = 4; + + let config = CircuitConfig::large_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let g_target = builder.constant_affine_point(g); + let neg_g_target = builder.curve_neg(&g_target); + + let double_g = builder.curve_double(&g_target); + let double_neg_g = builder.curve_double(&neg_g_target); + + builder.curve_assert_valid(&double_g); + builder.curve_assert_valid(&double_neg_g); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } } From dfad7708af81187c9636fce11685e34de9ee6f6d Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 11:50:58 -0800 Subject: [PATCH 159/202] merge --- src/gadgets/biguint.rs | 7 ++++--- src/gadgets/curve.rs | 23 +++++++++++++++++------ src/gadgets/nonnative.rs | 31 +++++++++++++++++++++++++++---- 3 files changed, 48 insertions(+), 13 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index fb7eb4e0..2d7ed693 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -160,9 +160,10 @@ impl, const D: usize> CircuitBuilder { a: &BigUintTarget, b: &BigUintTarget, ) -> (BigUintTarget, BigUintTarget) { - let num_limbs = a.limbs.len(); - let div = self.add_virtual_biguint_target(num_limbs); - let rem = self.add_virtual_biguint_target(num_limbs); + let a_len = a.limbs.len(); + let b_len = b.limbs.len(); + let div = self.add_virtual_biguint_target(a_len - b_len + 1); + let rem = self.add_virtual_biguint_target(b_len); self.add_simple_generator(BigUintDivRemGenerator:: { a: a.clone(), diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs index abb1b39a..bf9a1ac0 100644 --- a/src/gadgets/curve.rs +++ b/src/gadgets/curve.rs @@ -1,6 +1,6 @@ use crate::curve::curve_types::{AffinePoint, Curve}; use crate::field::extension_field::Extendable; -use crate::field::field_types::RichField; +use crate::field::field_types::{Field, RichField}; use crate::gadgets::nonnative::ForeignFieldTarget; use crate::plonk::circuit_builder::CircuitBuilder; @@ -60,7 +60,11 @@ impl, const D: usize> CircuitBuilder { } } - pub fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget { + pub fn curve_double( + &mut self, + p: &AffinePointTarget, + p_orig: AffinePoint, + ) -> AffinePointTarget { let AffinePointTarget { x, y } = p; let double_y = self.add_nonnative(y, y); let inv_double_y = self.inv_nonnative(&double_y); @@ -94,6 +98,8 @@ impl, const D: usize> CircuitBuilder { } mod tests { + use std::ops::Neg; + use anyhow::Result; use crate::curve::curve_types::{AffinePoint, Curve}; @@ -167,14 +173,19 @@ mod tests { let mut builder = CircuitBuilder::::new(config); let g = Secp256K1::GENERATOR_AFFINE; + let neg_g = g.neg(); let g_target = builder.constant_affine_point(g); let neg_g_target = builder.curve_neg(&g_target); - let double_g = builder.curve_double(&g_target); - let double_neg_g = builder.curve_double(&neg_g_target); + let double_g = g.double(); + let double_g_other_target = builder.constant_affine_point(double_g); + builder.curve_assert_valid(&double_g_other_target); - builder.curve_assert_valid(&double_g); - builder.curve_assert_valid(&double_neg_g); + let double_g_target = builder.curve_double(&g_target, g); + let double_neg_g_target = builder.curve_double(&neg_g_target, neg_g); + + builder.curve_assert_valid(&double_g_target); + builder.curve_assert_valid(&double_neg_g_target); let data = builder.build(); let proof = data.prove(pw).unwrap(); diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index d6b20f14..a3fa7fcd 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -81,6 +81,7 @@ impl, const D: usize> CircuitBuilder { &mut self, x: &NonNativeTarget, ) -> NonNativeTarget { + // TODO: zero - x would be more efficient but doesn't seem to work? let neg_one = FF::order() - BigUint::one(); let neg_one_target = self.constant_biguint(&neg_one); let neg_one_ff = self.biguint_to_nonnative(&neg_one_target); @@ -90,11 +91,11 @@ impl, const D: usize> CircuitBuilder { pub fn inv_nonnative( &mut self, - x: &ForeignFieldTarget, - ) -> ForeignFieldTarget { + x: &NonNativeTarget, + ) -> NonNativeTarget { let num_limbs = x.value.num_limbs(); let inv_biguint = self.add_virtual_biguint_target(num_limbs); - let inv = ForeignFieldTarget:: { + let inv = NonNativeTarget:: { value: inv_biguint, _phantom: PhantomData, }; @@ -107,7 +108,7 @@ impl, const D: usize> CircuitBuilder { let product = self.mul_nonnative(&x, &inv); let one = self.constant_nonnative(FF::ONE); - self.connect_nonnative_reduced(&product, &one); + self.connect_nonnative(&product, &one); inv } @@ -264,4 +265,26 @@ mod tests { let proof = data.prove(pw).unwrap(); verify(proof, &data.verifier_only, &data.common) } + + #[test] + fn test_nonnative_inv() -> Result<()> { + type FF = Secp256K1Base; + let x_ff = FF::rand(); + let inv_x_ff = x_ff.inverse(); + + type F = CrandallField; + let config = CircuitConfig::large_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let inv_x = builder.inv_nonnative(&x); + + let inv_x_expected = builder.constant_nonnative(inv_x_ff); + builder.connect_nonnative(&inv_x, &inv_x_expected); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } } From 051b79db2c383336825ea1041d07cb36a4e47238 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 8 Nov 2021 15:17:12 -0800 Subject: [PATCH 160/202] curve_add_two_affine --- src/gadgets/curve.rs | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs index bf9a1ac0..a3fd5b90 100644 --- a/src/gadgets/curve.rs +++ b/src/gadgets/curve.rs @@ -88,12 +88,38 @@ impl, const D: usize> CircuitBuilder { AffinePointTarget { x: x3, y: y3 } } - pub fn curve_add( + pub fn curve_add_two_affine( &mut self, - a: &AffinePointTarget, - b: &AffinePointTarget, + p1: &AffinePointTarget, + p2: &AffinePointTarget, ) -> AffinePointTarget { - todo!() + let AffinePointTarget { x: x1, y: y1 } = p1; + let AffinePointTarget { x: x2, y: y2 } = p2; + + let u = self.sub_nonnative(y2, y1); + let uu = self.mul_nonnative(&u, &u); + let v = self.sub_nonnative(x2, x1); + let vv = self.mul_nonnative(&v, &v); + let vvv = self.mul_nonnative(&v, &vv); + let r = self.mul_nonnative(&vv, x1); + let diff = self.sub_nonnative(&uu, &vvv); + let r2 = self.add_nonnative(&r, &r); + let a = self.sub_nonnative(&diff, &r2); + let x3 = self.mul_nonnative(&v, &a); + + let r_a = self.sub_nonnative(&r, &a); + let y3_first = self.mul_nonnative(&u, &r_a); + let y3_second = self.mul_nonnative(&vvv, y1); + let y3 = self.sub_nonnative(&y3_first, &y3_second); + + let z3_inv = self.inv_nonnative(&vvv); + let x3_norm = self.mul_nonnative(&x3, &z3_inv); + let y3_norm = self.mul_nonnative(&y3, &z3_inv); + + AffinePointTarget { + x: x3_norm, + y: y3_norm, + } } } From d6630869e124bf929a0d523952cb0460ea5ae8ab Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 9 Nov 2021 16:23:22 -0800 Subject: [PATCH 161/202] msm (outside circuit) --- src/curve/curve_msm.rs | 263 +++++++++++++++++++++++++++++++++++++++++ src/curve/mod.rs | 1 + src/gadgets/curve.rs | 2 +- 3 files changed, 265 insertions(+), 1 deletion(-) create mode 100644 src/curve/curve_msm.rs diff --git a/src/curve/curve_msm.rs b/src/curve/curve_msm.rs new file mode 100644 index 00000000..d2cb8049 --- /dev/null +++ b/src/curve/curve_msm.rs @@ -0,0 +1,263 @@ +use itertools::Itertools; +use rayon::prelude::*; + +use crate::curve::curve_summation::affine_multisummation_best; +use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; +use crate::field::field_types::Field; + +/// In Yao's method, we compute an affine summation for each digit. In a parallel setting, it would +/// be easiest to assign individual summations to threads, but this would be sub-optimal because +/// multi-summations can be more efficient than repeating individual summations (see +/// `affine_multisummation_best`). Thus we divide digits into large chunks, and assign chunks of +/// digits to threads. Note that there is a delicate balance here, as large chunks can result in +/// uneven distributions of work among threads. +const DIGITS_PER_CHUNK: usize = 80; + +#[derive(Clone, Debug)] +pub struct MsmPrecomputation { + /// For each generator (in the order they were passed to `msm_precompute`), contains a vector + /// of powers, i.e. [(2^w)^i] for i < DIGITS. + // TODO: Use compressed coordinates here. + powers_per_generator: Vec>>, + + /// The window size. + w: usize, +} + +pub fn msm_precompute( + generators: &[ProjectivePoint], + w: usize, +) -> MsmPrecomputation { + MsmPrecomputation { + powers_per_generator: generators + .into_par_iter() + .map(|&g| precompute_single_generator(g, w)) + .collect(), + w, + } +} + +fn precompute_single_generator(g: ProjectivePoint, w: usize) -> Vec> { + let digits = (C::ScalarField::BITS + w - 1) / w; + let mut powers: Vec> = Vec::with_capacity(digits); + powers.push(g); + for i in 1..digits { + let mut power_i_proj = powers[i - 1]; + for _j in 0..w { + power_i_proj = power_i_proj.double(); + } + powers.push(power_i_proj); + } + ProjectivePoint::batch_to_affine(&powers) +} + +pub fn msm_parallel( + scalars: &[C::ScalarField], + generators: &[ProjectivePoint], + w: usize, +) -> ProjectivePoint { + let precomputation = msm_precompute(generators, w); + msm_execute_parallel(&precomputation, scalars) +} + +pub fn msm_execute( + precomputation: &MsmPrecomputation, + scalars: &[C::ScalarField], +) -> ProjectivePoint { + assert_eq!(precomputation.powers_per_generator.len(), scalars.len()); + let w = precomputation.w; + let digits = (C::ScalarField::BITS + w - 1) / w; + let base = 1 << w; + + // This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use + // extremely large windows, the repeated scans in Yao's method could be more expensive than the + // actual group operations. To avoid this, we store a multimap from each possible digit to the + // positions in which that digit occurs in the scalars. These positions have the form (i, j), + // where i is the index of the generator and j is an index into the digits of the scalar + // associated with that generator. + let mut digit_occurrences: Vec> = Vec::with_capacity(digits); + for _i in 0..base { + digit_occurrences.push(Vec::new()); + } + for (i, scalar) in scalars.iter().enumerate() { + let digits = to_digits::(scalar, w); + for (j, &digit) in digits.iter().enumerate() { + digit_occurrences[digit].push((i, j)); + } + } + + let mut y = ProjectivePoint::ZERO; + let mut u = ProjectivePoint::ZERO; + + for digit in (1..base).rev() { + for &(i, j) in &digit_occurrences[digit] { + u = u + precomputation.powers_per_generator[i][j]; + } + y = y + u; + } + + y +} + +pub fn msm_execute_parallel( + precomputation: &MsmPrecomputation, + scalars: &[C::ScalarField], +) -> ProjectivePoint { + assert_eq!(precomputation.powers_per_generator.len(), scalars.len()); + let w = precomputation.w; + let digits = (C::ScalarField::BITS + w - 1) / w; + let base = 1 << w; + + // This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use + // extremely large windows, the repeated scans in Yao's method could be more expensive than the + // actual group operations. To avoid this, we store a multimap from each possible digit to the + // positions in which that digit occurs in the scalars. These positions have the form (i, j), + // where i is the index of the generator and j is an index into the digits of the scalar + // associated with that generator. + let mut digit_occurrences: Vec> = Vec::with_capacity(digits); + for _i in 0..base { + digit_occurrences.push(Vec::new()); + } + for (i, scalar) in scalars.iter().enumerate() { + let digits = to_digits::(scalar, w); + for (j, &digit) in digits.iter().enumerate() { + digit_occurrences[digit].push((i, j)); + } + } + + // For each digit, we add up the powers associated with all occurrences that digit. + let digits: Vec = (0..base).collect(); + let digit_acc: Vec> = digits + .par_chunks(DIGITS_PER_CHUNK) + .flat_map(|chunk| { + let summations: Vec>> = chunk + .iter() + .map(|&digit| { + digit_occurrences[digit] + .iter() + .map(|&(i, j)| precomputation.powers_per_generator[i][j]) + .collect() + }) + .collect(); + affine_multisummation_best(summations) + }) + .collect(); + // println!("Computing the per-digit summations (in parallel) took {}s", start.elapsed().as_secs_f64()); + + let mut y = ProjectivePoint::ZERO; + let mut u = ProjectivePoint::ZERO; + for digit in (1..base).rev() { + u = u + digit_acc[digit]; + y = y + u; + } + // println!("Final summation (sequential) {}s", start.elapsed().as_secs_f64()); + y +} + +pub(crate) fn to_digits(x: &C::ScalarField, w: usize) -> Vec { + let scalar_bits = C::ScalarField::BITS; + let num_digits = (scalar_bits + w - 1) / w; + + // Convert x to a bool array. + let x_canonical: Vec<_> = x + .to_biguint() + .to_u64_digits() + .iter() + .cloned() + .pad_using(scalar_bits / 64, |_| 0) + .collect(); + let mut x_bits = Vec::with_capacity(scalar_bits); + for i in 0..scalar_bits { + x_bits.push((x_canonical[i / 64] >> (i as u64 % 64) & 1) != 0); + } + + let mut digits = Vec::with_capacity(num_digits); + for i in 0..num_digits { + let mut digit = 0; + for j in ((i * w)..((i + 1) * w).min(scalar_bits)).rev() { + digit <<= 1; + digit |= x_bits[j] as usize; + } + digits.push(digit); + } + digits +} + +#[cfg(test)] +mod tests { + use num::BigUint; + + use crate::curve::curve_msm::{msm_execute, msm_precompute, to_digits}; + use crate::curve::curve_types::Curve; + use crate::curve::secp256k1::Secp256K1; + use crate::field::field_types::Field; + use crate::field::secp256k1_scalar::Secp256K1Scalar; + + #[test] + fn test_to_digits() { + let x_canonical = [ + 0b10101010101010101010101010101010, + 0b10101010101010101010101010101010, + 0b11001100110011001100110011001100, + 0b11001100110011001100110011001100, + 0b11110000111100001111000011110000, + 0b11110000111100001111000011110000, + 0b00001111111111111111111111111111, + 0b11111111111111111111111111111111, + ]; + let x = Secp256K1Scalar::from_biguint(BigUint::from_slice(&x_canonical)); + assert_eq!(x.to_biguint().to_u32_digits(), x_canonical); + assert_eq!( + to_digits::(&x, 17), + vec![ + 0b01010101010101010, + 0b10101010101010101, + 0b01010101010101010, + 0b11001010101010101, + 0b01100110011001100, + 0b00110011001100110, + 0b10011001100110011, + 0b11110000110011001, + 0b01111000011110000, + 0b00111100001111000, + 0b00011110000111100, + 0b11111111111111110, + 0b01111111111111111, + 0b11111111111111000, + 0b11111111111111111, + 0b1, + ] + ); + } + + #[test] + fn test_msm() { + let w = 5; + + let generator_1 = Secp256K1::GENERATOR_PROJECTIVE; + let generator_2 = generator_1 + generator_1; + let generator_3 = generator_1 + generator_2; + + let scalar_1 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ + 11111111, 22222222, 33333333, 44444444, + ])); + let scalar_2 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ + 22222222, 22222222, 33333333, 44444444, + ])); + let scalar_3 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ + 33333333, 22222222, 33333333, 44444444, + ])); + + let generators = vec![generator_1, generator_2, generator_3]; + let scalars = vec![scalar_1, scalar_2, scalar_3]; + + let precomputation = msm_precompute(&generators, w); + let result_msm = msm_execute(&precomputation, &scalars); + + let result_naive = Secp256K1::convert(scalar_1) * generator_1 + + Secp256K1::convert(scalar_2) * generator_2 + + Secp256K1::convert(scalar_3) * generator_3; + + assert_eq!(result_msm, result_naive); + } +} diff --git a/src/curve/mod.rs b/src/curve/mod.rs index 6555404e..d31e373e 100644 --- a/src/curve/mod.rs +++ b/src/curve/mod.rs @@ -1,4 +1,5 @@ pub mod curve_adds; +pub mod curve_msm; pub mod curve_multiplication; pub mod curve_summation; pub mod curve_types; diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs index a3fd5b90..fe6ae306 100644 --- a/src/gadgets/curve.rs +++ b/src/gadgets/curve.rs @@ -88,7 +88,7 @@ impl, const D: usize> CircuitBuilder { AffinePointTarget { x: x3, y: y3 } } - pub fn curve_add_two_affine( + pub fn curve_add( &mut self, p1: &AffinePointTarget, p2: &AffinePointTarget, From e4b894cb12867c5e9b09c9212e284040d7184ef2 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 19 Nov 2021 15:29:48 -0800 Subject: [PATCH 162/202] merge --- src/gadgets/nonnative.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index a3fa7fcd..0f16ddc3 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -77,10 +77,7 @@ impl, const D: usize> CircuitBuilder { self.reduce(&result) } - pub fn neg_nonnative( - &mut self, - x: &NonNativeTarget, - ) -> NonNativeTarget { + pub fn neg_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { // TODO: zero - x would be more efficient but doesn't seem to work? let neg_one = FF::order() - BigUint::one(); let neg_one_target = self.constant_biguint(&neg_one); @@ -89,10 +86,7 @@ impl, const D: usize> CircuitBuilder { self.mul_nonnative(&neg_one_ff, x) } - pub fn inv_nonnative( - &mut self, - x: &NonNativeTarget, - ) -> NonNativeTarget { + pub fn inv_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { let num_limbs = x.value.num_limbs(); let inv_biguint = self.add_virtual_biguint_target(num_limbs); let inv = NonNativeTarget:: { From c7fda246ca4c3f009d419ffce5103736be247b2f Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 11:53:45 -0800 Subject: [PATCH 163/202] fixes --- src/field/field_types.rs | 4 ---- src/gadgets/curve.rs | 16 ++++++++-------- src/gadgets/nonnative.rs | 11 ++++++----- src/iop/generator.rs | 4 ++-- src/iop/witness.rs | 4 ++-- 5 files changed, 18 insertions(+), 21 deletions(-) diff --git a/src/field/field_types.rs b/src/field/field_types.rs index 80eeecff..b7b9ddf4 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -93,10 +93,6 @@ pub trait Field: self.square() * *self } - fn double(&self) -> Self { - *self * Self::TWO - } - fn triple(&self) -> Self { *self * (Self::ONE + Self::TWO) } diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs index fe6ae306..5acd38d2 100644 --- a/src/gadgets/curve.rs +++ b/src/gadgets/curve.rs @@ -1,18 +1,18 @@ use crate::curve::curve_types::{AffinePoint, Curve}; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; -use crate::gadgets::nonnative::ForeignFieldTarget; +use crate::gadgets::nonnative::NonNativeTarget; use crate::plonk::circuit_builder::CircuitBuilder; /// A Target representing an affine point on the curve `C`. #[derive(Clone, Debug)] pub struct AffinePointTarget { - pub x: ForeignFieldTarget, - pub y: ForeignFieldTarget, + pub x: NonNativeTarget, + pub y: NonNativeTarget, } impl AffinePointTarget { - pub fn to_vec(&self) -> Vec> { + pub fn to_vec(&self) -> Vec> { vec![self.x.clone(), self.y.clone()] } } @@ -130,8 +130,8 @@ mod tests { use crate::curve::curve_types::{AffinePoint, Curve}; use crate::curve::secp256k1::Secp256K1; - use crate::field::crandall_field::CrandallField; use crate::field::field_types::Field; + use crate::field::goldilocks_field::GoldilocksField; use crate::field::secp256k1_base::Secp256K1Base; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; @@ -140,7 +140,7 @@ mod tests { #[test] fn test_curve_point_is_valid() -> Result<()> { - type F = CrandallField; + type F = GoldilocksField; const D: usize = 4; let config = CircuitConfig::large_config(); @@ -164,7 +164,7 @@ mod tests { #[test] #[should_panic] fn test_curve_point_is_not_valid() { - type F = CrandallField; + type F = GoldilocksField; const D: usize = 4; let config = CircuitConfig::large_config(); @@ -190,7 +190,7 @@ mod tests { #[test] fn test_curve_double() -> Result<()> { - type F = CrandallField; + type F = GoldilocksField; const D: usize = 4; let config = CircuitConfig::large_config(); diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 0f16ddc3..10629ad9 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -10,8 +10,9 @@ use crate::iop::target::Target; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; +#[derive(Clone, Debug)] pub struct NonNativeTarget { - value: BigUintTarget, + pub(crate) value: BigUintTarget, _phantom: PhantomData, } @@ -107,8 +108,8 @@ impl, const D: usize> CircuitBuilder { inv } - /// Returns `x % |FF|` as a `ForeignFieldTarget`. - fn reduce(&mut self, x: &BigUintTarget) -> ForeignFieldTarget { + /// Returns `x % |FF|` as a `NonNativeTarget`. + fn reduce(&mut self, x: &BigUintTarget) -> NonNativeTarget { let modulus = FF::order(); let order_target = self.constant_biguint(&modulus); let value = self.rem_biguint(x, &order_target); @@ -131,8 +132,8 @@ impl, const D: usize> CircuitBuilder { #[derive(Debug)] struct NonNativeInverseGenerator, const D: usize, FF: Field> { - x: ForeignFieldTarget, - inv: ForeignFieldTarget, + x: NonNativeTarget, + inv: NonNativeTarget, _phantom: PhantomData, } diff --git a/src/iop/generator.rs b/src/iop/generator.rs index ff7f66e0..ea4ac1f6 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -8,7 +8,7 @@ use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::{Field, RichField}; use crate::gadgets::arithmetic_u32::U32Target; use crate::gadgets::biguint::BigUintTarget; -use crate::gadgets::nonnative::ForeignFieldTarget; +use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::iop::target::Target; use crate::iop::wire::Wire; @@ -169,7 +169,7 @@ impl GeneratedValues { } } - pub fn set_nonnative_target(&mut self, target: ForeignFieldTarget, value: FF) { + pub fn set_nonnative_target(&mut self, target: NonNativeTarget, value: FF) { self.set_biguint_target(target.value, value.to_biguint()) } diff --git a/src/iop/witness.rs b/src/iop/witness.rs index d6f4fb59..8b6df90a 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -6,7 +6,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::Field; use crate::gadgets::biguint::BigUintTarget; -use crate::gadgets::nonnative::ForeignFieldTarget; +use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::HashOutTarget; use crate::hash::hash_types::{HashOut, MerkleCapTarget}; use crate::hash::merkle_tree::MerkleCap; @@ -69,7 +69,7 @@ pub trait Witness { result } - fn get_nonnative_target(&self, target: ForeignFieldTarget) -> FF { + fn get_nonnative_target(&self, target: NonNativeTarget) -> FF { let val = self.get_biguint_target(target.value); FF::from_biguint(val) } From f6954704d97d9922dd9f9c2e0c5957c85b00e3f3 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 11:54:25 -0800 Subject: [PATCH 164/202] fix --- src/gadgets/curve.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs index 5acd38d2..3c205e2f 100644 --- a/src/gadgets/curve.rs +++ b/src/gadgets/curve.rs @@ -143,7 +143,7 @@ mod tests { type F = GoldilocksField; const D: usize = 4; - let config = CircuitConfig::large_config(); + let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -167,7 +167,7 @@ mod tests { type F = GoldilocksField; const D: usize = 4; - let config = CircuitConfig::large_config(); + let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -193,7 +193,7 @@ mod tests { type F = GoldilocksField; const D: usize = 4; - let config = CircuitConfig::large_config(); + let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); From 7da99ad4d431faf39a7c0945fec4aa63c51fdaa6 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 12:59:04 -0800 Subject: [PATCH 165/202] test fixes --- src/gadgets/biguint.rs | 24 ------------------------ src/gadgets/nonnative.rs | 6 +++--- 2 files changed, 3 insertions(+), 27 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 2d7ed693..3aa96235 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -369,28 +369,4 @@ mod tests { let proof = data.prove(pw).unwrap(); verify(proof, &data.verifier_only, &data.common) } - - #[test] - fn test_biguint_sub() -> Result<()> { - let x_value = BigUint::from_u128(33333333333333333333333333333333333333).unwrap(); - let y_value = BigUint::from_u128(22222222222222222222222222222222222222).unwrap(); - let expected_z_value = &x_value - &y_value; - - type F = CrandallField; - let config = CircuitConfig::large_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.constant_biguint(x_value); - let y = builder.constant_biguint(y_value); - let z = builder.sub_biguint(x, y); - let expected_z = builder.constant_biguint(expected_z_value); - - builder.connect_biguint(z, expected_z); - - let data = builder.build(); - let proof = data.prove(pw).unwrap(); - - verify(proof, &data.verifier_only, &data.common) - } } diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 10629ad9..7cae727d 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -158,7 +158,7 @@ mod tests { use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; - use crate::field::secp256k1::Secp256K1Base; + use crate::field::secp256k1_base::Secp256K1Base; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; @@ -267,8 +267,8 @@ mod tests { let x_ff = FF::rand(); let inv_x_ff = x_ff.inverse(); - type F = CrandallField; - let config = CircuitConfig::large_config(); + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); From 0f49f6461e528c6dfd855cb55d806ad8286cd0e4 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 15 Nov 2021 11:29:06 -0800 Subject: [PATCH 166/202] removed from ProjectivePoint --- src/curve/curve_adds.rs | 9 +++----- src/curve/curve_types.rs | 47 +++++++++++++++------------------------- 2 files changed, 21 insertions(+), 35 deletions(-) diff --git a/src/curve/curve_adds.rs b/src/curve/curve_adds.rs index 32a5adcc..66e66bd0 100644 --- a/src/curve/curve_adds.rs +++ b/src/curve/curve_adds.rs @@ -11,19 +11,17 @@ impl Add> for ProjectivePoint { x: x1, y: y1, z: z1, - zero: zero1, } = self; let ProjectivePoint { x: x2, y: y2, z: z2, - zero: zero2, } = rhs; - if zero1 { + if z1 == C::BaseField::ZERO { return rhs; } - if zero2 { + if z2 == C::BaseField::ZERO { return self; } @@ -66,7 +64,6 @@ impl Add> for ProjectivePoint { x: x1, y: y1, z: z1, - zero: zero1, } = self; let AffinePoint { x: x2, @@ -74,7 +71,7 @@ impl Add> for ProjectivePoint { zero: zero2, } = rhs; - if zero1 { + if z1 == C::BaseField::ZERO { return rhs.to_projective(); } if zero2 { diff --git a/src/curve/curve_types.rs b/src/curve/curve_types.rs index 830dc7c1..3c16651e 100644 --- a/src/curve/curve_types.rs +++ b/src/curve/curve_types.rs @@ -23,7 +23,6 @@ pub trait Curve: 'static + Sync + Sized + Copy + Debug { x: Self::GENERATOR_AFFINE.x, y: Self::GENERATOR_AFFINE.y, z: Self::BaseField::ONE, - zero: false, }; fn convert(x: Self::ScalarField) -> CurveScalar { @@ -89,12 +88,13 @@ impl AffinePoint { pub fn to_projective(&self) -> ProjectivePoint { let Self { x, y, zero } = *self; - ProjectivePoint { - x, - y, - z: C::BaseField::ONE, - zero, - } + let z = if zero { + C::BaseField::ZERO + } else { + C::BaseField::ONE + }; + + ProjectivePoint { x, y, z } } pub fn batch_to_projective(affine_points: &[Self]) -> Vec> { @@ -150,7 +150,6 @@ pub struct ProjectivePoint { pub x: C::BaseField, pub y: C::BaseField, pub z: C::BaseField, - pub zero: bool, } impl ProjectivePoint { @@ -158,16 +157,10 @@ impl ProjectivePoint { x: C::BaseField::ZERO, y: C::BaseField::ZERO, z: C::BaseField::ZERO, - zero: true, }; pub fn nonzero(x: C::BaseField, y: C::BaseField, z: C::BaseField) -> Self { - let point = Self { - x, - y, - z, - zero: false, - }; + let point = Self { x, y, z }; debug_assert!(point.is_valid()); point } @@ -177,8 +170,8 @@ impl ProjectivePoint { } pub fn to_affine(&self) -> AffinePoint { - let Self { x, y, z, zero } = *self; - if zero { + let Self { x, y, z } = *self; + if z == C::BaseField::ZERO { AffinePoint::ZERO } else { let z_inv = z.inverse(); @@ -193,8 +186,8 @@ impl ProjectivePoint { let mut result = Vec::with_capacity(n); for i in 0..n { - let Self { x, y, z: _, zero } = proj_points[i]; - result.push(if zero { + let Self { x, y, z } = proj_points[i]; + result.push(if z == C::BaseField::ZERO { AffinePoint::ZERO } else { let z_inv = z_invs[i]; @@ -205,8 +198,8 @@ impl ProjectivePoint { } pub fn double(&self) -> Self { - let Self { x, y, z, zero } = *self; - if zero { + let Self { x, y, z } = *self; + if z == C::BaseField::ZERO { return ProjectivePoint::ZERO; } @@ -228,7 +221,6 @@ impl ProjectivePoint { x: x3, y: y3, z: z3, - zero: false, } } @@ -245,7 +237,6 @@ impl ProjectivePoint { x: self.x, y: -self.y, z: self.z, - zero: self.zero, } } } @@ -256,16 +247,14 @@ impl PartialEq for ProjectivePoint { x: x1, y: y1, z: z1, - zero: zero1, } = *self; let ProjectivePoint { x: x2, y: y2, z: z2, - zero: zero2, } = *other; - if zero1 || zero2 { - return zero1 == zero2; + if z1 == C::BaseField::ZERO || z2 == C::BaseField::ZERO { + return z1 == z2; } // We want to compare (x1/z1, y1/z1) == (x2/z2, y2/z2). @@ -289,7 +278,7 @@ impl Neg for ProjectivePoint { type Output = ProjectivePoint; fn neg(self) -> Self::Output { - let ProjectivePoint { x, y, z, zero } = self; - ProjectivePoint { x, y: -y, z, zero } + let ProjectivePoint { x, y, z } = self; + ProjectivePoint { x, y: -y, z } } } From 70abf3e9cbcd921d92ed0ed06c69f22111252cbe Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 16 Nov 2021 14:26:50 -0800 Subject: [PATCH 167/202] addressed comments --- src/curve/curve_adds.rs | 3 +++ src/curve/curve_types.rs | 1 + src/curve/secp256k1.rs | 3 ++- src/field/secp256k1_base.rs | 2 +- src/gadgets/curve.rs | 5 +++-- src/gadgets/nonnative.rs | 10 ++++------ 6 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/curve/curve_adds.rs b/src/curve/curve_adds.rs index 66e66bd0..f25d3847 100644 --- a/src/curve/curve_adds.rs +++ b/src/curve/curve_adds.rs @@ -41,6 +41,7 @@ impl Add> for ProjectivePoint { } } + // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/add-1998-cmo-2 let z1z2 = z1 * z2; let u = y2z1 - y1z2; let uu = u.square(); @@ -92,6 +93,7 @@ impl Add> for ProjectivePoint { } } + // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/madd-1998-cmo let u = y2z1 - y1; let uu = u.square(); let v = x2z1 - x1; @@ -138,6 +140,7 @@ impl Add> for AffinePoint { } } + // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/mmadd-1998-cmo let u = y2 - y1; let uu = u.square(); let v = x2 - x1; diff --git a/src/curve/curve_types.rs b/src/curve/curve_types.rs index 3c16651e..f2bb24b5 100644 --- a/src/curve/curve_types.rs +++ b/src/curve/curve_types.rs @@ -197,6 +197,7 @@ impl ProjectivePoint { result } + // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/doubling/dbl-2007-bl pub fn double(&self) -> Self { let Self { x, y, z } = *self; if z == C::BaseField::ZERO { diff --git a/src/curve/secp256k1.rs b/src/curve/secp256k1.rs index 7102b5c9..47c6ebb2 100644 --- a/src/curve/secp256k1.rs +++ b/src/curve/secp256k1.rs @@ -22,6 +22,7 @@ impl Curve for Secp256K1 { }; } +// 55066263022277343669578718895168534326250603453777594175500187360389116729240 const SECP256K1_GENERATOR_X: Secp256K1Base = Secp256K1Base([ 0x59F2815B16F81798, 0x029BFCDB2DCE28D9, @@ -29,7 +30,7 @@ const SECP256K1_GENERATOR_X: Secp256K1Base = Secp256K1Base([ 0x79BE667EF9DCBBAC, ]); -/// 241266749859715473739788878240585681733927191168601896383759122102112907357779751001206799952863815012735208165030 +/// 32670510020758816978083085130507043184471273380659243275938904335757337482424 const SECP256K1_GENERATOR_Y: Secp256K1Base = Secp256K1Base([ 0x9C47D08FFB10D4B8, 0xFD17B448A6855419, diff --git a/src/field/secp256k1_base.rs b/src/field/secp256k1_base.rs index a09edc30..acb1df4e 100644 --- a/src/field/secp256k1_base.rs +++ b/src/field/secp256k1_base.rs @@ -88,7 +88,7 @@ impl Field for Secp256K1Base { // Sage: `g = GF(p).multiplicative_generator()` const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([5, 0, 0, 0]); - // Sage: `g_2 = power_mod(g, (p - 1) // 2), p)` + // Sage: `g_2 = g^((p - 1) / 2)` const POWER_OF_TWO_GENERATOR: Self = Self::NEG_ONE; const BITS: usize = 256; diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs index 3c205e2f..5a458a56 100644 --- a/src/gadgets/curve.rs +++ b/src/gadgets/curve.rs @@ -1,10 +1,11 @@ use crate::curve::curve_types::{AffinePoint, Curve}; use crate::field::extension_field::Extendable; -use crate::field::field_types::{Field, RichField}; +use crate::field::field_types::RichField; use crate::gadgets::nonnative::NonNativeTarget; use crate::plonk::circuit_builder::CircuitBuilder; -/// A Target representing an affine point on the curve `C`. +/// A Target representing an affine point on the curve `C`. We use incomplete arithmetic for efficiency, +/// so we assume these points are not zero. #[derive(Clone, Debug)] pub struct AffinePointTarget { pub x: NonNativeTarget, diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 7cae727d..90735a61 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use num::{BigUint, One}; +use num::{BigUint, One, Zero}; use crate::field::field_types::RichField; use crate::field::{extension_field::Extendable, field_types::Field}; @@ -79,12 +79,10 @@ impl, const D: usize> CircuitBuilder { } pub fn neg_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { - // TODO: zero - x would be more efficient but doesn't seem to work? - let neg_one = FF::order() - BigUint::one(); - let neg_one_target = self.constant_biguint(&neg_one); - let neg_one_ff = self.biguint_to_nonnative(&neg_one_target); + let zero_target = self.constant_biguint(&BigUint::zero()); + let zero_ff = self.biguint_to_nonnative(&zero_target); - self.mul_nonnative(&neg_one_ff, x) + self.sub_nonnative(&zero_ff, x) } pub fn inv_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { From 284f9a412ca385959a8e156541276fced5e80e1c Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 18 Nov 2021 10:30:57 -0800 Subject: [PATCH 168/202] curve multiply; test for curve add; addressed comments --- src/curve/curve_multiplication.rs | 10 +- src/curve/curve_types.rs | 31 +----- src/gadgets/biguint.rs | 11 ++ src/gadgets/curve.rs | 179 +++++++++++++++++++++++++++--- src/gadgets/nonnative.rs | 32 ++++++ 5 files changed, 220 insertions(+), 43 deletions(-) diff --git a/src/curve/curve_multiplication.rs b/src/curve/curve_multiplication.rs index e5ac0eb3..b09b8a0f 100644 --- a/src/curve/curve_multiplication.rs +++ b/src/curve/curve_multiplication.rs @@ -1,6 +1,6 @@ use std::ops::Mul; -use crate::curve::curve_summation::affine_summation_batch_inversion; +use crate::curve::curve_summation::affine_multisummation_batch_inversion; use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar, ProjectivePoint}; use crate::field::field_types::Field; @@ -48,6 +48,7 @@ impl ProjectivePoint { let mut y = ProjectivePoint::ZERO; let mut u = ProjectivePoint::ZERO; + let mut all_summands = Vec::new(); for j in (1..BASE).rev() { let mut u_summands = Vec::new(); for (i, &digit) in digits.iter().enumerate() { @@ -55,7 +56,12 @@ impl ProjectivePoint { u_summands.push(precomputed_powers[i]); } } - u = u + affine_summation_batch_inversion(u_summands); + all_summands.push(u_summands); + } + + let all_sums = affine_multisummation_batch_inversion(all_summands); + for i in 0..all_sums.len() { + u = u + all_sums[i]; y = y + u; } y diff --git a/src/curve/curve_types.rs b/src/curve/curve_types.rs index f2bb24b5..c9a04ab2 100644 --- a/src/curve/curve_types.rs +++ b/src/curve/curve_types.rs @@ -1,8 +1,6 @@ use std::fmt::Debug; use std::ops::Neg; -use anyhow::Result; - use crate::field::field_types::Field; // To avoid implementation conflicts from associated types, @@ -29,30 +27,6 @@ pub trait Curve: 'static + Sync + Sized + Copy + Debug { CurveScalar(x) } - /*fn try_convert_b2s(x: Self::BaseField) -> Result { - x.try_convert::() - } - - fn try_convert_s2b(x: Self::ScalarField) -> Result { - x.try_convert::() - } - - fn try_convert_s2b_slice(s: &[Self::ScalarField]) -> Result> { - let mut res = Vec::with_capacity(s.len()); - for &x in s { - res.push(Self::try_convert_s2b(x)?); - } - Ok(res) - } - - fn try_convert_b2s_slice(s: &[Self::BaseField]) -> Result> { - let mut res = Vec::with_capacity(s.len()); - for &x in s { - res.push(Self::try_convert_b2s(x)?); - } - Ok(res) - }*/ - fn is_safe_curve() -> bool { // Added additional check to prevent using vulnerabilties in case a discriminant is equal to 0. (Self::A.cube().double().double() + Self::B.square().triple().triple().triple()) @@ -155,7 +129,7 @@ pub struct ProjectivePoint { impl ProjectivePoint { pub const ZERO: Self = Self { x: C::BaseField::ZERO, - y: C::BaseField::ZERO, + y: C::BaseField::ONE, z: C::BaseField::ZERO, }; @@ -166,7 +140,8 @@ impl ProjectivePoint { } pub fn is_valid(&self) -> bool { - self.to_affine().is_valid() + let Self { x, y, z } = *self; + z.is_zero() || y.square() * z == x.cube() + C::A * x * z.square() + C::B * z.cube() } pub fn to_affine(&self) -> AffinePoint { diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 3aa96235..b67c85a5 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -155,6 +155,17 @@ impl, const D: usize> CircuitBuilder { } } + // Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function). + pub fn mul_add_biguint( + &mut self, + x: &BigUintTarget, + y: &BigUintTarget, + z: &BigUintTarget, + ) -> BigUintTarget { + let prod = self.mul_biguint(x, y); + self.add_biguint(&prod, z) + } + pub fn div_rem_biguint( &mut self, a: &BigUintTarget, diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs index 5a458a56..eda0e5e0 100644 --- a/src/gadgets/curve.rs +++ b/src/gadgets/curve.rs @@ -1,6 +1,6 @@ use crate::curve::curve_types::{AffinePoint, Curve}; use crate::field::extension_field::Extendable; -use crate::field::field_types::RichField; +use crate::field::field_types::{Field, RichField}; use crate::gadgets::nonnative::NonNativeTarget; use crate::plonk::circuit_builder::CircuitBuilder; @@ -18,6 +18,17 @@ impl AffinePointTarget { } } +const WINDOW_BITS: usize = 4; +const BASE: usize = 1 << WINDOW_BITS; + +fn digits_per_scalar() -> usize { + (C::ScalarField::BITS + WINDOW_BITS - 1) / WINDOW_BITS +} + +pub struct MulPrecomputationTarget { + powers: Vec>, +} + impl, const D: usize> CircuitBuilder { pub fn constant_affine_point( &mut self, @@ -39,6 +50,13 @@ impl, const D: usize> CircuitBuilder { self.connect_nonnative(&lhs.y, &rhs.y); } + pub fn add_virtual_affine_point_target(&mut self) -> AffinePointTarget { + let x = self.add_virtual_nonnative_target(); + let y = self.add_virtual_nonnative_target(); + + AffinePointTarget { x, y } + } + pub fn curve_assert_valid(&mut self, p: &AffinePointTarget) { let a = self.constant_nonnative(C::A); let b = self.constant_nonnative(C::B); @@ -61,11 +79,7 @@ impl, const D: usize> CircuitBuilder { } } - pub fn curve_double( - &mut self, - p: &AffinePointTarget, - p_orig: AffinePoint, - ) -> AffinePointTarget { + pub fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget { let AffinePointTarget { x, y } = p; let double_y = self.add_nonnative(y, y); let inv_double_y = self.inv_nonnative(&double_y); @@ -89,6 +103,7 @@ impl, const D: usize> CircuitBuilder { AffinePointTarget { x: x3, y: y3 } } + // Add two points, which are assumed to be non-equal. pub fn curve_add( &mut self, p1: &AffinePointTarget, @@ -122,6 +137,110 @@ impl, const D: usize> CircuitBuilder { y: y3_norm, } } + + pub fn mul_precompute( + &mut self, + p: &AffinePointTarget, + ) -> MulPrecomputationTarget { + let num_digits = digits_per_scalar::(); + + let mut powers = Vec::with_capacity(num_digits); + powers.push(p.clone()); + for i in 1..num_digits { + let mut power_i = powers[i - 1].clone(); + for _j in 0..WINDOW_BITS { + power_i = self.curve_double(&power_i); + } + powers.push(power_i); + } + + MulPrecomputationTarget { powers } + } + + /*fn to_digits(&mut self, x: &NonNativeTarget) -> Vec> { + debug_assert!( + 64 % WINDOW_BITS == 0, + "For simplicity, only power-of-two window sizes are handled for now" + ); + + let base = self.constant_nonnative(C::ScalarField::from_canonical_u64(BASE as u64)); + + let num_digits = digits_per_scalar::(); + let mut digits = Vec::with_capacity(num_digits); + + let (rest, limb) = self.div_rem_nonnative(&x, &base); + for _ in 0..num_digits { + digits.push(limb); + + let (rest, limb) = self.div_rem_nonnative(&rest, &base); + } + + digits + } + + pub fn mul_with_precomputation( + &mut self, + p: &AffinePointTarget, + n: &NonNativeTarget, + precomputation: MulPrecomputationTarget, + ) -> AffinePointTarget { + // Yao's method; see https://koclab.cs.ucsb.edu/teaching/ecc/eccPapers/Doche-ch09.pdf + let precomputed_powers = precomputation.powers; + + let digits = self.to_digits(n); + + + }*/ + + pub fn curve_scalar_mul( + &mut self, + p: &AffinePointTarget, + n: &NonNativeTarget, + ) -> AffinePointTarget { + let one = self.constant_nonnative(C::BaseField::ONE); + let two = self.constant_nonnative(C::ScalarField::TWO); + let num_bits = C::ScalarField::BITS; + + // Result starts at p, which is later subtracted, because we don't support arithmetic with the zero point. + let mut result = self.add_virtual_affine_point_target(); + self.connect_affine_point(p, &result); + let mut two_i_times_p = self.add_virtual_affine_point_target(); + self.connect_affine_point(p, &two_i_times_p); + + let mut cur_n = self.add_virtual_nonnative_target::(); + for _i in 0..num_bits { + let (bit_scalar, new_n) = self.div_rem_nonnative(&cur_n, &two); + let bit_biguint = self.nonnative_to_biguint(&bit_scalar); + let bit = self.biguint_to_nonnative::(&bit_biguint); + let not_bit = self.sub_nonnative(&one, &bit); + + let result_plus_2_i_p = self.curve_add(&result, &two_i_times_p); + + let result_x = result.x; + let result_y = result.y; + let result_plus_2_i_p_x = result_plus_2_i_p.x; + let result_plus_2_i_p_y = result_plus_2_i_p.y; + + let new_x_if_bit = self.mul_nonnative(&bit, &result_plus_2_i_p_x); + let new_x_if_not_bit = self.mul_nonnative(¬_bit, &result_x); + let new_y_if_bit = self.mul_nonnative(&bit, &result_plus_2_i_p_y); + let new_y_if_not_bit = self.mul_nonnative(¬_bit, &result_y); + + let new_x = self.add_nonnative(&new_x_if_bit, &new_x_if_not_bit); + let new_y = self.add_nonnative(&new_y_if_bit, &new_y_if_not_bit); + + result = AffinePointTarget { x: new_x, y: new_y }; + + two_i_times_p = self.curve_double(&two_i_times_p); + cur_n = new_n; + } + + // Subtract off result's intial value of p. + let neg_p = self.curve_neg(&p); + result = self.curve_add(&result, &neg_p); + + result + } } mod tests { @@ -200,19 +319,53 @@ mod tests { let mut builder = CircuitBuilder::::new(config); let g = Secp256K1::GENERATOR_AFFINE; - let neg_g = g.neg(); let g_target = builder.constant_affine_point(g); let neg_g_target = builder.curve_neg(&g_target); let double_g = g.double(); - let double_g_other_target = builder.constant_affine_point(double_g); - builder.curve_assert_valid(&double_g_other_target); + let double_g_expected = builder.constant_affine_point(double_g); + builder.curve_assert_valid(&double_g_expected); - let double_g_target = builder.curve_double(&g_target, g); - let double_neg_g_target = builder.curve_double(&neg_g_target, neg_g); + let double_neg_g = (-g).double(); + let double_neg_g_expected = builder.constant_affine_point(double_neg_g); + builder.curve_assert_valid(&double_neg_g_expected); - builder.curve_assert_valid(&double_g_target); - builder.curve_assert_valid(&double_neg_g_target); + let double_g_actual = builder.curve_double(&g_target); + let double_neg_g_actual = builder.curve_double(&neg_g_target); + builder.curve_assert_valid(&double_g_actual); + builder.curve_assert_valid(&double_neg_g_actual); + + builder.connect_affine_point(&double_g_expected, &double_g_actual); + builder.connect_affine_point(&double_neg_g_expected, &double_neg_g_actual); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_curve_add() -> Result<()> { + type F = GoldilocksField; + const D: usize = 4; + + let config = CircuitConfig::standard_recursion_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let double_g = g.double(); + let g_plus_2g = (g + double_g).to_affine(); + let g_plus_2g_expected = builder.constant_affine_point(g_plus_2g); + builder.curve_assert_valid(&g_plus_2g_expected); + + let g_target = builder.constant_affine_point(g); + let double_g_target = builder.curve_double(&g_target); + let g_plus_2g_actual = builder.curve_add(&g_target, &double_g_target); + builder.curve_assert_valid(&g_plus_2g_actual); + + builder.connect_affine_point(&g_plus_2g_expected, &g_plus_2g_actual); let data = builder.build(); let proof = data.prove(pw).unwrap(); diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 90735a61..19d86658 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -17,6 +17,14 @@ pub struct NonNativeTarget { } impl, const D: usize> CircuitBuilder { + fn num_nonnative_limbs() -> usize { + let ff_size = FF::order(); + let f_size = F::order(); + let num_limbs = ((ff_size + f_size.clone() - BigUint::one()) / f_size).to_u32_digits()[0]; + + num_limbs as usize + } + pub fn biguint_to_nonnative(&mut self, x: &BigUintTarget) -> NonNativeTarget { NonNativeTarget { value: x.clone(), @@ -42,6 +50,16 @@ impl, const D: usize> CircuitBuilder { self.connect_biguint(&lhs.value, &rhs.value); } + pub fn add_virtual_nonnative_target(&mut self) -> NonNativeTarget { + let num_limbs = Self::num_nonnative_limbs::(); + let value = self.add_virtual_biguint_target(num_limbs); + + NonNativeTarget { + value, + _phantom: PhantomData, + } + } + // Add two `NonNativeTarget`s. pub fn add_nonnative( &mut self, @@ -106,6 +124,20 @@ impl, const D: usize> CircuitBuilder { inv } + pub fn div_rem_nonnative( + &mut self, + x: &NonNativeTarget, + y: &NonNativeTarget, + ) -> (NonNativeTarget, NonNativeTarget) { + let x_biguint = self.nonnative_to_biguint(x); + let y_biguint = self.nonnative_to_biguint(y); + + let (div_biguint, rem_biguint) = self.div_rem_biguint(&x_biguint, &y_biguint); + let div = self.biguint_to_nonnative(&div_biguint); + let rem = self.biguint_to_nonnative(&rem_biguint); + (div, rem) + } + /// Returns `x % |FF|` as a `NonNativeTarget`. fn reduce(&mut self, x: &BigUintTarget) -> NonNativeTarget { let modulus = FF::order(); From 2ec3ea8634e7f25d9c5d3f70f5511c36048d663c Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 18 Nov 2021 15:48:28 -0800 Subject: [PATCH 169/202] new curve_mul --- src/gadgets/curve.rs | 52 +++++++++++++++++++++++++++++++--------- src/gadgets/nonnative.rs | 35 ++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 12 deletions(-) diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs index eda0e5e0..0982d5f9 100644 --- a/src/gadgets/curve.rs +++ b/src/gadgets/curve.rs @@ -201,18 +201,18 @@ impl, const D: usize> CircuitBuilder { let two = self.constant_nonnative(C::ScalarField::TWO); let num_bits = C::ScalarField::BITS; + let bits = self.split_nonnative_to_bits(&n); + let bits_as_base: Vec> = + bits.iter().map(|b| self.bool_to_nonnative(b)).collect(); + // Result starts at p, which is later subtracted, because we don't support arithmetic with the zero point. let mut result = self.add_virtual_affine_point_target(); self.connect_affine_point(p, &result); let mut two_i_times_p = self.add_virtual_affine_point_target(); self.connect_affine_point(p, &two_i_times_p); - let mut cur_n = self.add_virtual_nonnative_target::(); - for _i in 0..num_bits { - let (bit_scalar, new_n) = self.div_rem_nonnative(&cur_n, &two); - let bit_biguint = self.nonnative_to_biguint(&bit_scalar); - let bit = self.biguint_to_nonnative::(&bit_biguint); - let not_bit = self.sub_nonnative(&one, &bit); + for bit in bits_as_base.iter() { + let not_bit = self.sub_nonnative(&one, bit); let result_plus_2_i_p = self.curve_add(&result, &two_i_times_p); @@ -221,9 +221,9 @@ impl, const D: usize> CircuitBuilder { let result_plus_2_i_p_x = result_plus_2_i_p.x; let result_plus_2_i_p_y = result_plus_2_i_p.y; - let new_x_if_bit = self.mul_nonnative(&bit, &result_plus_2_i_p_x); + let new_x_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p_x); let new_x_if_not_bit = self.mul_nonnative(¬_bit, &result_x); - let new_y_if_bit = self.mul_nonnative(&bit, &result_plus_2_i_p_y); + let new_y_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p_y); let new_y_if_not_bit = self.mul_nonnative(¬_bit, &result_y); let new_x = self.add_nonnative(&new_x_if_bit, &new_x_if_not_bit); @@ -232,7 +232,6 @@ impl, const D: usize> CircuitBuilder { result = AffinePointTarget { x: new_x, y: new_y }; two_i_times_p = self.curve_double(&two_i_times_p); - cur_n = new_n; } // Subtract off result's intial value of p. @@ -244,15 +243,16 @@ impl, const D: usize> CircuitBuilder { } mod tests { - use std::ops::Neg; + use std::ops::{Mul, Neg}; use anyhow::Result; - use crate::curve::curve_types::{AffinePoint, Curve}; + use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; use crate::curve::secp256k1::Secp256K1; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; use crate::field::secp256k1_base::Secp256K1Base; + use crate::field::secp256k1_scalar::Secp256K1Scalar; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; @@ -372,4 +372,34 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + + #[test] + fn test_curve_mul() -> Result<()> { + type F = GoldilocksField; + const D: usize = 4; + + let config = CircuitConfig::standard_recursion_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let five = Secp256K1Scalar::from_canonical_usize(5); + let five_scalar = CurveScalar::(five); + let five_g = (five_scalar * g.to_projective()).to_affine(); + let five_g_expected = builder.constant_affine_point(five_g); + builder.curve_assert_valid(&five_g_expected); + + let g_target = builder.constant_affine_point(g); + let five_target = builder.constant_nonnative(five); + let five_g_actual = builder.curve_scalar_mul(&g_target, &five_target); + builder.curve_assert_valid(&five_g_actual); + + builder.connect_affine_point(&five_g_expected, &five_g_actual); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } } diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 19d86658..88250093 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -4,9 +4,10 @@ use num::{BigUint, One, Zero}; use crate::field::field_types::RichField; use crate::field::{extension_field::Extendable, field_types::Field}; +use crate::gadgets::arithmetic_u32::U32Target; use crate::gadgets::biguint::BigUintTarget; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; -use crate::iop::target::Target; +use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; @@ -158,6 +159,38 @@ impl, const D: usize> CircuitBuilder { let x_biguint = self.nonnative_to_biguint(x); self.reduce(&x_biguint) } + + pub fn bool_to_nonnative(&mut self, b: &BoolTarget) -> NonNativeTarget { + let limbs = vec![U32Target(b.target)]; + let value = BigUintTarget { limbs }; + + NonNativeTarget { + value, + _phantom: PhantomData, + } + } + + // Split a nonnative field element to bits. + pub fn split_nonnative_to_bits( + &mut self, + x: &NonNativeTarget, + ) -> Vec { + let num_limbs = x.value.num_limbs(); + let mut result = Vec::with_capacity(num_limbs * 32); + + for i in 0..num_limbs { + let limb = x.value.get_limb(i); + let bit_targets = self.split_le_base::<2>(limb.0, 32); + let mut bits: Vec<_> = bit_targets + .iter() + .map(|&t| BoolTarget::new_unsafe(t)) + .collect(); + + result.append(&mut bits); + } + + result + } } #[derive(Debug)] From a6ddc2ed5dc8275758cbdd18764e8e4b41204a7c Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 19 Nov 2021 15:27:02 -0800 Subject: [PATCH 170/202] curve_mul testing --- src/gadgets/biguint.rs | 7 ++++++- src/gadgets/curve.rs | 11 ++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index b67c85a5..9e14cdb7 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -173,7 +173,12 @@ impl, const D: usize> CircuitBuilder { ) -> (BigUintTarget, BigUintTarget) { let a_len = a.limbs.len(); let b_len = b.limbs.len(); - let div = self.add_virtual_biguint_target(a_len - b_len + 1); + let div_num_limbs = if b_len > a_len + 1 { + 0 + } else { + a_len - b_len + 1 + }; + let div = self.add_virtual_biguint_target(div_num_limbs); let rem = self.add_virtual_biguint_target(b_len); self.add_simple_generator(BigUintDivRemGenerator:: { diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs index 0982d5f9..37d99997 100644 --- a/src/gadgets/curve.rs +++ b/src/gadgets/curve.rs @@ -198,8 +198,6 @@ impl, const D: usize> CircuitBuilder { n: &NonNativeTarget, ) -> AffinePointTarget { let one = self.constant_nonnative(C::BaseField::ONE); - let two = self.constant_nonnative(C::ScalarField::TWO); - let num_bits = C::ScalarField::BITS; let bits = self.split_nonnative_to_bits(&n); let bits_as_base: Vec> = @@ -378,7 +376,10 @@ mod tests { type F = GoldilocksField; const D: usize = 4; - let config = CircuitConfig::standard_recursion_config(); + let config = CircuitConfig { + num_routed_wires: 33, + ..CircuitConfig::standard_recursion_config() + }; let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -393,9 +394,9 @@ mod tests { let g_target = builder.constant_affine_point(g); let five_target = builder.constant_nonnative(five); let five_g_actual = builder.curve_scalar_mul(&g_target, &five_target); - builder.curve_assert_valid(&five_g_actual); + /*builder.curve_assert_valid(&five_g_actual); - builder.connect_affine_point(&five_g_expected, &five_g_actual); + builder.connect_affine_point(&five_g_expected, &five_g_actual);*/ let data = builder.build(); let proof = data.prove(pw).unwrap(); From 5029f87b80ee5ff769f4e84a6fa42592f06dec58 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 19 Nov 2021 15:45:42 -0800 Subject: [PATCH 171/202] fixes --- src/gadgets/nonnative.rs | 5 +---- src/plonk/circuit_builder.rs | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 88250093..b04b5c1f 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -152,10 +152,7 @@ impl, const D: usize> CircuitBuilder { } #[allow(dead_code)] - fn reduce_nonnative( - &mut self, - x: &NonNativeTarget, - ) -> NonNativeTarget { + fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { let x_biguint = self.nonnative_to_biguint(x); self.reduce(&x_biguint) } diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 693abe0a..3cb62cf3 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -15,7 +15,7 @@ use crate::gadgets::arithmetic::BaseArithmeticOperation; use crate::gadgets::arithmetic_extension::ExtensionArithmeticOperation; use crate::gadgets::arithmetic_u32::U32Target; use crate::gates::arithmetic_base::ArithmeticGate; -use crate::gates::arithmetic::ArithmeticExtensionGate; +use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; use crate::gates::constant::ConstantGate; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; From b1bbe30dac23187434e1ea6b0bc456c77b015ab4 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 23 Nov 2021 18:16:38 -0800 Subject: [PATCH 172/202] Fixed tests -- thanks William! --- src/gadgets/curve.rs | 54 ++++++++-------------------------------- src/gadgets/nonnative.rs | 7 ++---- 2 files changed, 13 insertions(+), 48 deletions(-) diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs index 37d99997..fe2e186a 100644 --- a/src/gadgets/curve.rs +++ b/src/gadgets/curve.rs @@ -1,4 +1,4 @@ -use crate::curve::curve_types::{AffinePoint, Curve}; +use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::gadgets::nonnative::NonNativeTarget; @@ -157,41 +157,6 @@ impl, const D: usize> CircuitBuilder { MulPrecomputationTarget { powers } } - /*fn to_digits(&mut self, x: &NonNativeTarget) -> Vec> { - debug_assert!( - 64 % WINDOW_BITS == 0, - "For simplicity, only power-of-two window sizes are handled for now" - ); - - let base = self.constant_nonnative(C::ScalarField::from_canonical_u64(BASE as u64)); - - let num_digits = digits_per_scalar::(); - let mut digits = Vec::with_capacity(num_digits); - - let (rest, limb) = self.div_rem_nonnative(&x, &base); - for _ in 0..num_digits { - digits.push(limb); - - let (rest, limb) = self.div_rem_nonnative(&rest, &base); - } - - digits - } - - pub fn mul_with_precomputation( - &mut self, - p: &AffinePointTarget, - n: &NonNativeTarget, - precomputation: MulPrecomputationTarget, - ) -> AffinePointTarget { - // Yao's method; see https://koclab.cs.ucsb.edu/teaching/ecc/eccPapers/Doche-ch09.pdf - let precomputed_powers = precomputation.powers; - - let digits = self.to_digits(n); - - - }*/ - pub fn curve_scalar_mul( &mut self, p: &AffinePointTarget, @@ -203,9 +168,12 @@ impl, const D: usize> CircuitBuilder { let bits_as_base: Vec> = bits.iter().map(|b| self.bool_to_nonnative(b)).collect(); - // Result starts at p, which is later subtracted, because we don't support arithmetic with the zero point. + let rando = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine(); + let randot = self.constant_affine_point(rando); + // Result starts at `rando`, which is later subtracted, because we don't support arithmetic with the zero point. let mut result = self.add_virtual_affine_point_target(); - self.connect_affine_point(p, &result); + self.connect_affine_point(&randot, &result); + let mut two_i_times_p = self.add_virtual_affine_point_target(); self.connect_affine_point(p, &two_i_times_p); @@ -232,9 +200,9 @@ impl, const D: usize> CircuitBuilder { two_i_times_p = self.curve_double(&two_i_times_p); } - // Subtract off result's intial value of p. - let neg_p = self.curve_neg(&p); - result = self.curve_add(&result, &neg_p); + // Subtract off result's intial value of `rando`. + let neg_r = self.curve_neg(&randot); + result = self.curve_add(&result, &neg_r); result } @@ -394,9 +362,9 @@ mod tests { let g_target = builder.constant_affine_point(g); let five_target = builder.constant_nonnative(five); let five_g_actual = builder.curve_scalar_mul(&g_target, &five_target); - /*builder.curve_assert_valid(&five_g_actual); + builder.curve_assert_valid(&five_g_actual); - builder.connect_affine_point(&five_g_expected, &five_g_actual);*/ + builder.connect_affine_point(&five_g_expected, &five_g_actual); let data = builder.build(); let proof = data.prove(pw).unwrap(); diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index b04b5c1f..9ee50f3a 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -10,6 +10,7 @@ use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::util::ceil_div_usize; #[derive(Clone, Debug)] pub struct NonNativeTarget { @@ -19,11 +20,7 @@ pub struct NonNativeTarget { impl, const D: usize> CircuitBuilder { fn num_nonnative_limbs() -> usize { - let ff_size = FF::order(); - let f_size = F::order(); - let num_limbs = ((ff_size + f_size.clone() - BigUint::one()) / f_size).to_u32_digits()[0]; - - num_limbs as usize + ceil_div_usize(FF::BITS, 32) } pub fn biguint_to_nonnative(&mut self, x: &BigUintTarget) -> NonNativeTarget { From 39300bcf0142d9e45721c5ef9a27ffb68b68b2af Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 30 Nov 2021 15:00:12 -0800 Subject: [PATCH 173/202] fixed Secp256K1Scalar --- src/field/field_testing.rs | 13 +++++++ src/field/secp256k1_base.rs | 7 ++++ src/field/secp256k1_scalar.rs | 11 +++++- src/gadgets/curve.rs | 71 ++++++++++++++++------------------- 4 files changed, 61 insertions(+), 41 deletions(-) diff --git a/src/field/field_testing.rs b/src/field/field_testing.rs index 767a3cf2..b4ee0595 100644 --- a/src/field/field_testing.rs +++ b/src/field/field_testing.rs @@ -84,6 +84,19 @@ macro_rules! test_field_arithmetic { assert_eq!(base.exp_biguint(&pow), base.exp_biguint(&big_pow)); assert_ne!(base.exp_biguint(&pow), base.exp_biguint(&big_pow_wrong)); } + + #[test] + fn inverses() { + type F = $field; + + let x = F::rand(); + let x1 = x.inverse(); + let x2 = x1.inverse(); + let x3 = x2.inverse(); + + assert_eq!(x, x2); + assert_eq!(x1, x3); + } } }; } diff --git a/src/field/secp256k1_base.rs b/src/field/secp256k1_base.rs index acb1df4e..b3fb0148 100644 --- a/src/field/secp256k1_base.rs +++ b/src/field/secp256k1_base.rs @@ -241,3 +241,10 @@ impl DivAssign for Secp256K1Base { *self = *self / rhs; } } + +#[cfg(test)] +mod tests { + use crate::test_field_arithmetic; + + test_field_arithmetic!(crate::field::secp256k1_base::Secp256K1Base); +} diff --git a/src/field/secp256k1_scalar.rs b/src/field/secp256k1_scalar.rs index 0c406b86..f4f2e6ab 100644 --- a/src/field/secp256k1_scalar.rs +++ b/src/field/secp256k1_scalar.rs @@ -80,7 +80,7 @@ impl Field for Secp256K1Scalar { const NEG_ONE: Self = Self([ 0xBFD25E8CD0364140, 0xBAAEDCE6AF48A03B, - 0xFFFFFFFFFFFFFC2F, + 0xFFFFFFFFFFFFFFFE, 0xFFFFFFFFFFFFFFFF, ]); @@ -105,7 +105,7 @@ impl Field for Secp256K1Scalar { fn order() -> BigUint { BigUint::from_slice(&[ - 0xD0364141, 0xBFD25E8C, 0xAF48A03B, 0xBAAEDCE6, 0xFFFFFC2F, 0xFFFFFFFF, 0xFFFFFFFF, + 0xD0364141, 0xBFD25E8C, 0xAF48A03B, 0xBAAEDCE6, 0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, ]) } @@ -251,3 +251,10 @@ impl DivAssign for Secp256K1Scalar { *self = *self / rhs; } } + +#[cfg(test)] +mod tests { + use crate::test_field_arithmetic; + + test_field_arithmetic!(crate::field::secp256k1_scalar::Secp256K1Scalar); +} diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs index fe2e186a..e5602c5f 100644 --- a/src/gadgets/curve.rs +++ b/src/gadgets/curve.rs @@ -18,17 +18,6 @@ impl AffinePointTarget { } } -const WINDOW_BITS: usize = 4; -const BASE: usize = 1 << WINDOW_BITS; - -fn digits_per_scalar() -> usize { - (C::ScalarField::BITS + WINDOW_BITS - 1) / WINDOW_BITS -} - -pub struct MulPrecomputationTarget { - powers: Vec>, -} - impl, const D: usize> CircuitBuilder { pub fn constant_affine_point( &mut self, @@ -138,25 +127,6 @@ impl, const D: usize> CircuitBuilder { } } - pub fn mul_precompute( - &mut self, - p: &AffinePointTarget, - ) -> MulPrecomputationTarget { - let num_digits = digits_per_scalar::(); - - let mut powers = Vec::with_capacity(num_digits); - powers.push(p.clone()); - for i in 1..num_digits { - let mut power_i = powers[i - 1].clone(); - for _j in 0..WINDOW_BITS { - power_i = self.curve_double(&power_i); - } - powers.push(power_i); - } - - MulPrecomputationTarget { powers } - } - pub fn curve_scalar_mul( &mut self, p: &AffinePointTarget, @@ -182,15 +152,10 @@ impl, const D: usize> CircuitBuilder { let result_plus_2_i_p = self.curve_add(&result, &two_i_times_p); - let result_x = result.x; - let result_y = result.y; - let result_plus_2_i_p_x = result_plus_2_i_p.x; - let result_plus_2_i_p_y = result_plus_2_i_p.y; - - let new_x_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p_x); - let new_x_if_not_bit = self.mul_nonnative(¬_bit, &result_x); - let new_y_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p_y); - let new_y_if_not_bit = self.mul_nonnative(¬_bit, &result_y); + let new_x_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p.x); + let new_x_if_not_bit = self.mul_nonnative(¬_bit, &result.x); + let new_y_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p.y); + let new_y_if_not_bit = self.mul_nonnative(¬_bit, &result.y); let new_x = self.add_nonnative(&new_x_if_bit, &new_x_if_not_bit); let new_y = self.add_nonnative(&new_y_if_bit, &new_y_if_not_bit); @@ -371,4 +336,32 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + + #[test] + fn test_curve_random() -> Result<()> { + type F = GoldilocksField; + const D: usize = 4; + + let config = CircuitConfig { + num_routed_wires: 33, + ..CircuitConfig::standard_recursion_config() + }; + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let rando = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let randot = builder.constant_affine_point(rando); + + let two_target = builder.constant_nonnative(Secp256K1Scalar::TWO); + let randot_doubled = builder.curve_double(&randot); + let randot_times_two = builder.curve_scalar_mul(&randot, &two_target); + builder.connect_affine_point(&randot_doubled, &randot_times_two); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } } From b9868ec782d08770eabba60b1619c9a02773f4d3 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 30 Nov 2021 15:22:06 -0800 Subject: [PATCH 174/202] multiplication using projective --- src/curve/curve_multiplication.rs | 19 +++++++++++-------- src/curve/mod.rs | 1 + src/gadgets/mod.rs | 1 + 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/curve/curve_multiplication.rs b/src/curve/curve_multiplication.rs index b09b8a0f..83c444d8 100644 --- a/src/curve/curve_multiplication.rs +++ b/src/curve/curve_multiplication.rs @@ -16,23 +16,22 @@ fn digits_per_scalar() -> usize { #[derive(Clone)] pub struct MultiplicationPrecomputation { /// [(2^w)^i] g for each i < digits_per_scalar. - powers: Vec>, + powers: Vec>, } impl ProjectivePoint { pub fn mul_precompute(&self) -> MultiplicationPrecomputation { let num_digits = digits_per_scalar::(); - let mut powers_proj = Vec::with_capacity(num_digits); - powers_proj.push(*self); + let mut powers = Vec::with_capacity(num_digits); + powers.push(*self); for i in 1..num_digits { - let mut power_i_proj = powers_proj[i - 1]; + let mut power_i = powers[i - 1]; for _j in 0..WINDOW_BITS { - power_i_proj = power_i_proj.double(); + power_i = power_i.double(); } - powers_proj.push(power_i_proj); + powers.push(power_i); } - let powers = ProjectivePoint::batch_to_affine(&powers_proj); MultiplicationPrecomputation { powers } } @@ -59,7 +58,11 @@ impl ProjectivePoint { all_summands.push(u_summands); } - let all_sums = affine_multisummation_batch_inversion(all_summands); + let all_sums: Vec> = all_summands + .iter() + .cloned() + .map(|vec| vec.iter().fold(ProjectivePoint::ZERO, |a, &b| a + b)) + .collect(); for i in 0..all_sums.len() { u = u + all_sums[i]; y = y + u; diff --git a/src/curve/mod.rs b/src/curve/mod.rs index d31e373e..8dd6f0d6 100644 --- a/src/curve/mod.rs +++ b/src/curve/mod.rs @@ -3,4 +3,5 @@ pub mod curve_msm; pub mod curve_multiplication; pub mod curve_summation; pub mod curve_types; +pub mod ecdsa; pub mod secp256k1; diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index 09acb9de..6bb372a3 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -3,6 +3,7 @@ pub mod arithmetic_extension; pub mod arithmetic_u32; pub mod biguint; pub mod curve; +//pub mod ecdsa; pub mod hash; pub mod insert; pub mod interpolation; From f1dc1d4446c964c8221112143bed3321021f667d Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 30 Nov 2021 15:23:57 -0800 Subject: [PATCH 175/202] fix --- src/curve/mod.rs | 1 - src/gadgets/mod.rs | 1 - 2 files changed, 2 deletions(-) diff --git a/src/curve/mod.rs b/src/curve/mod.rs index 8dd6f0d6..d31e373e 100644 --- a/src/curve/mod.rs +++ b/src/curve/mod.rs @@ -3,5 +3,4 @@ pub mod curve_msm; pub mod curve_multiplication; pub mod curve_summation; pub mod curve_types; -pub mod ecdsa; pub mod secp256k1; diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index 6bb372a3..09acb9de 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -3,7 +3,6 @@ pub mod arithmetic_extension; pub mod arithmetic_u32; pub mod biguint; pub mod curve; -//pub mod ecdsa; pub mod hash; pub mod insert; pub mod interpolation; From 406092f3585967a2bde0bbbb7d64851e6c0c158f Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 30 Nov 2021 15:56:12 -0800 Subject: [PATCH 176/202] clippy fixes --- src/curve/curve_multiplication.rs | 3 +-- src/curve/curve_summation.rs | 2 +- src/curve/curve_types.rs | 2 +- src/gadgets/curve.rs | 5 ++--- src/gadgets/nonnative.rs | 4 ++-- 5 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/curve/curve_multiplication.rs b/src/curve/curve_multiplication.rs index 83c444d8..eb5bade1 100644 --- a/src/curve/curve_multiplication.rs +++ b/src/curve/curve_multiplication.rs @@ -1,7 +1,6 @@ use std::ops::Mul; -use crate::curve::curve_summation::affine_multisummation_batch_inversion; -use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar, ProjectivePoint}; +use crate::curve::curve_types::{Curve, CurveScalar, ProjectivePoint}; use crate::field::field_types::Field; const WINDOW_BITS: usize = 4; diff --git a/src/curve/curve_summation.rs b/src/curve/curve_summation.rs index 8f347eda..c67bc026 100644 --- a/src/curve/curve_summation.rs +++ b/src/curve/curve_summation.rs @@ -152,7 +152,7 @@ pub fn affine_multisummation_batch_inversion( // This is the doubling case. let mut numerator = x1.square().triple(); if C::A.is_nonzero() { - numerator = numerator + C::A; + numerator += C::A; } let quotient = numerator * inverse; let x3 = quotient.square() - x1.double(); diff --git a/src/curve/curve_types.rs b/src/curve/curve_types.rs index c9a04ab2..ef1f6186 100644 --- a/src/curve/curve_types.rs +++ b/src/curve/curve_types.rs @@ -183,7 +183,7 @@ impl ProjectivePoint { let zz = z.square(); let mut w = xx.triple(); if C::A.is_nonzero() { - w = w + C::A * zz; + w += C::A * zz; } let s = y.double() * z; let r = y * s; diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs index e5602c5f..f7c5eaaf 100644 --- a/src/gadgets/curve.rs +++ b/src/gadgets/curve.rs @@ -134,7 +134,7 @@ impl, const D: usize> CircuitBuilder { ) -> AffinePointTarget { let one = self.constant_nonnative(C::BaseField::ONE); - let bits = self.split_nonnative_to_bits(&n); + let bits = self.split_nonnative_to_bits(n); let bits_as_base: Vec> = bits.iter().map(|b| self.bool_to_nonnative(b)).collect(); @@ -173,9 +173,8 @@ impl, const D: usize> CircuitBuilder { } } +#[cfg(test)] mod tests { - use std::ops::{Mul, Neg}; - use anyhow::Result; use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 9ee50f3a..56d717e3 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use num::{BigUint, One, Zero}; +use num::{BigUint, Zero}; use crate::field::field_types::RichField; use crate::field::{extension_field::Extendable, field_types::Field}; @@ -115,7 +115,7 @@ impl, const D: usize> CircuitBuilder { _phantom: PhantomData, }); - let product = self.mul_nonnative(&x, &inv); + let product = self.mul_nonnative(x, &inv); let one = self.constant_nonnative(FF::ONE); self.connect_nonnative(&product, &one); From 5aa5cc9c6559ce16f553c2886999e02f89feac93 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 1 Dec 2021 09:28:00 -0800 Subject: [PATCH 177/202] ignore huge tests --- src/gadgets/curve.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs index f7c5eaaf..c86c3c0d 100644 --- a/src/gadgets/curve.rs +++ b/src/gadgets/curve.rs @@ -304,6 +304,7 @@ mod tests { } #[test] + #[ignore] fn test_curve_mul() -> Result<()> { type F = GoldilocksField; const D: usize = 4; @@ -337,6 +338,7 @@ mod tests { } #[test] + #[ignore] fn test_curve_random() -> Result<()> { type F = GoldilocksField; const D: usize = 4; From 9d8a5fc01e0f4e6bf1573bd5c40146ff89cd6dd1 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 1 Dec 2021 09:28:31 -0800 Subject: [PATCH 178/202] removed outdated comment --- src/curve/secp256k1.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/curve/secp256k1.rs b/src/curve/secp256k1.rs index 47c6ebb2..191343a9 100644 --- a/src/curve/secp256k1.rs +++ b/src/curve/secp256k1.rs @@ -3,9 +3,6 @@ use crate::field::field_types::Field; use crate::field::secp256k1_base::Secp256K1Base; use crate::field::secp256k1_scalar::Secp256K1Scalar; -// Parameters taken from the implementation of Bls12-377 in Zexe found here: -// https://github.com/scipr-lab/zexe/blob/master/algebra/src/curves/bls12_377/g1.rs - #[derive(Debug, Copy, Clone)] pub struct Secp256K1; From 12defa80f4ce62b4f05be905b8d6922621a499b1 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 1 Dec 2021 09:28:47 -0800 Subject: [PATCH 179/202] remove unused test --- src/curve/secp256k1.rs | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/curve/secp256k1.rs b/src/curve/secp256k1.rs index 191343a9..58472eb4 100644 --- a/src/curve/secp256k1.rs +++ b/src/curve/secp256k1.rs @@ -57,16 +57,6 @@ mod tests { assert!(neg_g.is_valid()); } - /*#[test] - fn test_double_affine() { - for i in 0..100 { - //let p = blake_hash_usize_to_curve::(i); - assert_eq!( - p.double(), - p.to_projective().double().to_affine()); - } - }*/ - #[test] fn test_naive_multiplication() { let g = Secp256K1::GENERATOR_PROJECTIVE; From 6df251e14412295b2463462f3f9bf5c220105713 Mon Sep 17 00:00:00 2001 From: Jakub Nabaglo Date: Thu, 2 Dec 2021 00:01:24 -0800 Subject: [PATCH 180/202] Remove `Singleton` type and make every `Field` a `PackedField` (#379) * Remove `Singleton` type and make every `Field` a `PackedField` * Minor: Clippy --- src/field/fft.rs | 4 +- src/field/packable.rs | 6 +- src/field/packed_field.rs | 137 ++++---------------------------------- 3 files changed, 19 insertions(+), 128 deletions(-) diff --git a/src/field/fft.rs b/src/field/fft.rs index 6f5155a4..09672278 100644 --- a/src/field/fft.rs +++ b/src/field/fft.rs @@ -5,7 +5,7 @@ use unroll::unroll_for_loops; use crate::field::field_types::Field; use crate::field::packable::Packable; -use crate::field::packed_field::{PackedField, Singleton}; +use crate::field::packed_field::PackedField; use crate::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::util::{log2_strict, reverse_index_bits}; @@ -201,7 +201,7 @@ pub(crate) fn fft_classic(input: &[F], r: usize, root_table: &FftRootT if lg_n <= lg_packed_width { // Need the slice to be at least the width of two packed vectors for the vectorized version // to work. Do this tiny problem in scalar. - fft_classic_simd::>(&mut values[..], r, lg_n, root_table); + fft_classic_simd::(&mut values[..], r, lg_n, root_table); } else { fft_classic_simd::<::PackedType>(&mut values[..], r, lg_n, root_table); } diff --git a/src/field/packable.rs b/src/field/packable.rs index 94a9c056..e5fc2ac5 100644 --- a/src/field/packable.rs +++ b/src/field/packable.rs @@ -1,15 +1,15 @@ use crate::field::field_types::Field; -use crate::field::packed_field::{PackedField, Singleton}; +use crate::field::packed_field::PackedField; /// Points us to the default packing for a particular field. There may me multiple choices of -/// PackedField for a particular Field (e.g. Singleton works for all fields), but this is the +/// PackedField for a particular Field (e.g. every Field is also a PackedField), but this is the /// recommended one. The recommended packing varies by target_arch and target_feature. pub trait Packable: Field { type PackedType: PackedField; } impl Packable for F { - default type PackedType = Singleton; + default type PackedType = Self; } #[cfg(target_feature = "avx2")] diff --git a/src/field/packed_field.rs b/src/field/packed_field.rs index a4b1945a..69733bca 100644 --- a/src/field/packed_field.rs +++ b/src/field/packed_field.rs @@ -1,5 +1,4 @@ -use std::fmt; -use std::fmt::{Debug, Formatter}; +use std::fmt::Debug; use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; @@ -95,143 +94,35 @@ pub trait PackedField: } } -#[derive(Copy, Clone)] -#[repr(transparent)] -pub struct Singleton(pub F); +impl PackedField for F { + type FieldType = Self; -impl Add for Singleton { - type Output = Self; - fn add(self, rhs: Self) -> Self { - Self(self.0 + rhs.0) - } -} -impl Add for Singleton { - type Output = Self; - fn add(self, rhs: F) -> Self { - self + Self::broadcast(rhs) - } -} -impl AddAssign for Singleton { - fn add_assign(&mut self, rhs: Self) { - *self = *self + rhs; - } -} -impl AddAssign for Singleton { - fn add_assign(&mut self, rhs: F) { - *self = *self + rhs; - } -} - -impl Debug for Singleton { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "({:?})", self.0) - } -} - -impl Default for Singleton { - fn default() -> Self { - Self::zero() - } -} - -impl Mul for Singleton { - type Output = Self; - fn mul(self, rhs: Self) -> Self { - Self(self.0 * rhs.0) - } -} -impl Mul for Singleton { - type Output = Self; - fn mul(self, rhs: F) -> Self { - self * Self::broadcast(rhs) - } -} -impl MulAssign for Singleton { - fn mul_assign(&mut self, rhs: Self) { - *self = *self * rhs; - } -} -impl MulAssign for Singleton { - fn mul_assign(&mut self, rhs: F) { - *self = *self * rhs; - } -} - -impl Neg for Singleton { - type Output = Self; - fn neg(self) -> Self { - Self(-self.0) - } -} - -impl Product for Singleton { - fn product>(iter: I) -> Self { - Self(iter.map(|x| x.0).product()) - } -} - -impl PackedField for Singleton { const LOG2_WIDTH: usize = 0; - type FieldType = F; - fn broadcast(x: F) -> Self { - Self(x) + fn broadcast(x: Self::FieldType) -> Self { + x } fn from_arr(arr: [Self::FieldType; Self::WIDTH]) -> Self { - Self(arr[0]) + arr[0] } - fn to_arr(&self) -> [Self::FieldType; Self::WIDTH] { - [self.0] + [*self] } fn from_slice(slice: &[Self::FieldType]) -> Self { - assert!(slice.len() == 1); - Self(slice[0]) + assert_eq!(slice.len(), 1); + slice[0] } - fn to_vec(&self) -> Vec { - vec![self.0] + vec![*self] } fn interleave(&self, other: Self, r: usize) -> (Self, Self) { - match r { - 0 => (*self, other), // This is a no-op whenever r == LOG2_WIDTH. - _ => panic!("r cannot be more than LOG2_WIDTH"), + if r == 0 { + (*self, other) + } else { + panic!("r > LOG2_WIDTH"); } } - - fn square(&self) -> Self { - Self(self.0.square()) - } -} - -impl Sub for Singleton { - type Output = Self; - fn sub(self, rhs: Self) -> Self { - Self(self.0 - rhs.0) - } -} -impl Sub for Singleton { - type Output = Self; - fn sub(self, rhs: F) -> Self { - self - Self::broadcast(rhs) - } -} -impl SubAssign for Singleton { - fn sub_assign(&mut self, rhs: Self) { - *self = *self - rhs; - } -} -impl SubAssign for Singleton { - fn sub_assign(&mut self, rhs: F) { - *self = *self - rhs; - } -} - -impl Sum for Singleton { - fn sum>(iter: I) -> Self { - Self(iter.map(|x| x.0).sum()) - } } From 93d695d33e5953738a04a94d00db00924f527ec0 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 2 Dec 2021 15:14:25 +0100 Subject: [PATCH 181/202] Variable number of U32 ops --- src/gadgets/arithmetic_u32.rs | 20 ++---- src/gates/arithmetic_u32.rs | 122 ++++++++++++++++++---------------- src/plonk/circuit_builder.rs | 31 +++------ 3 files changed, 80 insertions(+), 93 deletions(-) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index ce7aa121..d15df304 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -76,34 +76,26 @@ impl, const D: usize> CircuitBuilder { return result; } + let gate = U32ArithmeticGate::::new_from_config(&self.config); let (gate_index, copy) = self.find_u32_arithmetic_gate(); self.connect( - Target::wire( - gate_index, - U32ArithmeticGate::::wire_ith_multiplicand_0(copy), - ), + Target::wire(gate_index, gate.wire_ith_multiplicand_0(copy)), x.0, ); self.connect( - Target::wire( - gate_index, - U32ArithmeticGate::::wire_ith_multiplicand_1(copy), - ), + Target::wire(gate_index, gate.wire_ith_multiplicand_1(copy)), y.0, ); - self.connect( - Target::wire(gate_index, U32ArithmeticGate::::wire_ith_addend(copy)), - z.0, - ); + self.connect(Target::wire(gate_index, gate.wire_ith_addend(copy)), z.0); let output_low = U32Target(Target::wire( gate_index, - U32ArithmeticGate::::wire_ith_output_low_half(copy), + gate.wire_ith_output_low_half(copy), )); let output_high = U32Target(Target::wire( gate_index, - U32ArithmeticGate::::wire_ith_output_high_half(copy), + gate.wire_ith_output_high_half(copy), )); (output_low, output_high) diff --git a/src/gates/arithmetic_u32.rs b/src/gates/arithmetic_u32.rs index a5a63047..e88654df 100644 --- a/src/gates/arithmetic_u32.rs +++ b/src/gates/arithmetic_u32.rs @@ -11,43 +11,49 @@ use crate::iop::target::Target; use crate::iop::wire::Wire; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -/// Number of arithmetic operations performed by an arithmetic gate. -pub const NUM_U32_ARITHMETIC_OPS: usize = 3; - /// A gate to perform a basic mul-add on 32-bit values (we assume they are range-checked beforehand). -#[derive(Clone, Debug)] +#[derive(Copy, Clone, Debug)] pub struct U32ArithmeticGate, const D: usize> { + pub num_ops: usize, _phantom: PhantomData, } impl, const D: usize> U32ArithmeticGate { - pub fn new() -> Self { + pub fn new_from_config(config: &CircuitConfig) -> Self { Self { + num_ops: Self::num_ops(config), _phantom: PhantomData, } } - pub fn wire_ith_multiplicand_0(i: usize) -> usize { - debug_assert!(i < NUM_U32_ARITHMETIC_OPS); + pub(crate) fn num_ops(config: &CircuitConfig) -> usize { + let wires_per_op = 5 + Self::num_limbs(); + let routed_wires_per_op = 5; + (config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op) + } + + pub fn wire_ith_multiplicand_0(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i } - pub fn wire_ith_multiplicand_1(i: usize) -> usize { - debug_assert!(i < NUM_U32_ARITHMETIC_OPS); + pub fn wire_ith_multiplicand_1(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 1 } - pub fn wire_ith_addend(i: usize) -> usize { - debug_assert!(i < NUM_U32_ARITHMETIC_OPS); + pub fn wire_ith_addend(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 2 } - pub fn wire_ith_output_low_half(i: usize) -> usize { - debug_assert!(i < NUM_U32_ARITHMETIC_OPS); + pub fn wire_ith_output_low_half(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 3 } - pub fn wire_ith_output_high_half(i: usize) -> usize { - debug_assert!(i < NUM_U32_ARITHMETIC_OPS); + pub fn wire_ith_output_high_half(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 4 } @@ -58,10 +64,10 @@ impl, const D: usize> U32ArithmeticGate { 64 / Self::limb_bits() } - pub fn wire_ith_output_jth_limb(i: usize, j: usize) -> usize { - debug_assert!(i < NUM_U32_ARITHMETIC_OPS); + pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize { + debug_assert!(i < self.num_ops); debug_assert!(j < Self::num_limbs()); - 5 * NUM_U32_ARITHMETIC_OPS + Self::num_limbs() * i + j + 5 * self.num_ops + Self::num_limbs() * i + j } } @@ -72,15 +78,15 @@ impl, const D: usize> Gate for U32ArithmeticG fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..NUM_U32_ARITHMETIC_OPS { - let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; - let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; - let addend = vars.local_wires[Self::wire_ith_addend(i)]; + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[self.wire_ith_addend(i)]; let computed_output = multiplicand_0 * multiplicand_1 + addend; - let output_low = vars.local_wires[Self::wire_ith_output_low_half(i)]; - let output_high = vars.local_wires[Self::wire_ith_output_high_half(i)]; + let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; + let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; let base = F::Extension::from_canonical_u64(1 << 32u64); let combined_output = output_high * base + output_low; @@ -92,7 +98,7 @@ impl, const D: usize> Gate for U32ArithmeticG let midpoint = Self::num_limbs() / 2; let base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits()); for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); let product = (0..max_limb) .map(|x| this_limb - F::Extension::from_canonical_usize(x)) @@ -114,15 +120,15 @@ impl, const D: usize> Gate for U32ArithmeticG fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..NUM_U32_ARITHMETIC_OPS { - let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; - let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; - let addend = vars.local_wires[Self::wire_ith_addend(i)]; + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[self.wire_ith_addend(i)]; let computed_output = multiplicand_0 * multiplicand_1 + addend; - let output_low = vars.local_wires[Self::wire_ith_output_low_half(i)]; - let output_high = vars.local_wires[Self::wire_ith_output_high_half(i)]; + let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; + let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; let base = F::from_canonical_u64(1 << 32u64); let combined_output = output_high * base + output_low; @@ -134,7 +140,7 @@ impl, const D: usize> Gate for U32ArithmeticG let midpoint = Self::num_limbs() / 2; let base = F::from_canonical_u64(1u64 << Self::limb_bits()); for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); let product = (0..max_limb) .map(|x| this_limb - F::from_canonical_usize(x)) @@ -161,15 +167,15 @@ impl, const D: usize> Gate for U32ArithmeticG ) -> Vec> { let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..NUM_U32_ARITHMETIC_OPS { - let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; - let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; - let addend = vars.local_wires[Self::wire_ith_addend(i)]; + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[self.wire_ith_addend(i)]; let computed_output = builder.mul_add_extension(multiplicand_0, multiplicand_1, addend); - let output_low = vars.local_wires[Self::wire_ith_output_low_half(i)]; - let output_high = vars.local_wires[Self::wire_ith_output_high_half(i)]; + let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; + let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; let base: F::Extension = F::from_canonical_u64(1 << 32u64).into(); let base_target = builder.constant_extension(base); @@ -183,7 +189,7 @@ impl, const D: usize> Gate for U32ArithmeticG let base = builder .constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); let mut product = builder.one_extension(); @@ -216,10 +222,11 @@ impl, const D: usize> Gate for U32ArithmeticG gate_index: usize, _local_constants: &[F], ) -> Vec>> { - (0..NUM_U32_ARITHMETIC_OPS) + (0..self.num_ops) .map(|i| { let g: Box> = Box::new( U32ArithmeticGenerator { + gate: *self, gate_index, i, _phantom: PhantomData, @@ -232,7 +239,7 @@ impl, const D: usize> Gate for U32ArithmeticG } fn num_wires(&self) -> usize { - NUM_U32_ARITHMETIC_OPS * (5 + Self::num_limbs()) + self.num_ops * (5 + Self::num_limbs()) } fn num_constants(&self) -> usize { @@ -244,12 +251,13 @@ impl, const D: usize> Gate for U32ArithmeticG } fn num_constraints(&self) -> usize { - NUM_U32_ARITHMETIC_OPS * (3 + Self::num_limbs()) + self.num_ops * (3 + Self::num_limbs()) } } #[derive(Clone, Debug)] struct U32ArithmeticGenerator, const D: usize> { + gate: U32ArithmeticGate, gate_index: usize, i: usize, _phantom: PhantomData, @@ -262,9 +270,9 @@ impl, const D: usize> SimpleGenerator let local_target = |input| Target::wire(self.gate_index, input); vec![ - local_target(U32ArithmeticGate::::wire_ith_multiplicand_0(self.i)), - local_target(U32ArithmeticGate::::wire_ith_multiplicand_1(self.i)), - local_target(U32ArithmeticGate::::wire_ith_addend(self.i)), + local_target(self.gate.wire_ith_multiplicand_0(self.i)), + local_target(self.gate.wire_ith_multiplicand_1(self.i)), + local_target(self.gate.wire_ith_addend(self.i)), ] } @@ -276,11 +284,9 @@ impl, const D: usize> SimpleGenerator let get_local_wire = |input| witness.get_wire(local_wire(input)); - let multiplicand_0 = - get_local_wire(U32ArithmeticGate::::wire_ith_multiplicand_0(self.i)); - let multiplicand_1 = - get_local_wire(U32ArithmeticGate::::wire_ith_multiplicand_1(self.i)); - let addend = get_local_wire(U32ArithmeticGate::::wire_ith_addend(self.i)); + let multiplicand_0 = get_local_wire(self.gate.wire_ith_multiplicand_0(self.i)); + let multiplicand_1 = get_local_wire(self.gate.wire_ith_multiplicand_1(self.i)); + let addend = get_local_wire(self.gate.wire_ith_addend(self.i)); let output = multiplicand_0 * multiplicand_1 + addend; let mut output_u64 = output.to_canonical_u64(); @@ -291,10 +297,8 @@ impl, const D: usize> SimpleGenerator let output_high = F::from_canonical_u64(output_high_u64); let output_low = F::from_canonical_u64(output_low_u64); - let output_high_wire = - local_wire(U32ArithmeticGate::::wire_ith_output_high_half(self.i)); - let output_low_wire = - local_wire(U32ArithmeticGate::::wire_ith_output_low_half(self.i)); + let output_high_wire = local_wire(self.gate.wire_ith_output_high_half(self.i)); + let output_low_wire = local_wire(self.gate.wire_ith_output_low_half(self.i)); out_buffer.set_wire(output_high_wire, output_high); out_buffer.set_wire(output_low_wire, output_low); @@ -310,9 +314,7 @@ impl, const D: usize> SimpleGenerator let output_limbs_f = output_limbs_u64.map(F::from_canonical_u64); for (j, output_limb) in output_limbs_f.enumerate() { - let wire = local_wire(U32ArithmeticGate::::wire_ith_output_jth_limb( - self.i, j, - )); + let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j)); out_buffer.set_wire(wire, output_limb); } } @@ -328,7 +330,7 @@ mod tests { use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; - use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; + use crate::gates::arithmetic_u32::U32ArithmeticGate; use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::hash::hash_types::HashOut; @@ -337,6 +339,7 @@ mod tests { #[test] fn low_degree() { test_low_degree::(U32ArithmeticGate:: { + num_ops: 3, _phantom: PhantomData, }) } @@ -344,6 +347,7 @@ mod tests { #[test] fn eval_fns() -> Result<()> { test_eval_fns::(U32ArithmeticGate:: { + num_ops: 3, _phantom: PhantomData, }) } @@ -353,6 +357,7 @@ mod tests { type F = GoldilocksField; type FF = QuarticExtension; const D: usize = 4; + const NUM_U32_ARITHMETIC_OPS: usize = 3; fn get_wires( multiplicands_0: Vec, @@ -410,6 +415,7 @@ mod tests { .collect(); let gate = U32ArithmeticGate:: { + num_ops: NUM_U32_ARITHMETIC_OPS, _phantom: PhantomData, }; diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 3cb62cf3..7799bc38 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -16,7 +16,7 @@ use crate::gadgets::arithmetic_extension::ExtensionArithmeticOperation; use crate::gadgets::arithmetic_u32::U32Target; use crate::gates::arithmetic_base::ArithmeticGate; use crate::gates::arithmetic_extension::ArithmeticExtensionGate; -use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; +use crate::gates::arithmetic_u32::U32ArithmeticGate; use crate::gates::constant::ConstantGate; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; use crate::gates::gate_tree::Tree; @@ -965,14 +965,14 @@ impl, const D: usize> CircuitBuilder { pub(crate) fn find_u32_arithmetic_gate(&mut self) -> (usize, usize) { let (gate_index, copy) = match self.batched_gates.current_u32_arithmetic_gate { None => { - let gate = U32ArithmeticGate::new(); + let gate = U32ArithmeticGate::new_from_config(&self.config); let gate_index = self.add_gate(gate, vec![]); (gate_index, 0) } Some((gate_index, copy)) => (gate_index, copy), }; - if copy == NUM_U32_ARITHMETIC_OPS - 1 { + if copy == U32ArithmeticGate::::num_ops(&self.config) - 1 { self.batched_gates.current_u32_arithmetic_gate = None; } else { self.batched_gates.current_u32_arithmetic_gate = Some((gate_index, copy + 1)); @@ -1111,23 +1111,12 @@ impl, const D: usize> CircuitBuilder { /// Fill the remaining unused U32 arithmetic operations with zeros, so that all /// `U32ArithmeticGenerator`s are run. fn fill_u32_arithmetic_gates(&mut self) { - let zero = self.zero(); - if let Some((gate_index, copy)) = self.batched_gates.current_u32_arithmetic_gate { - for i in copy..NUM_U32_ARITHMETIC_OPS { - let wire_multiplicand_0 = Target::wire( - gate_index, - U32ArithmeticGate::::wire_ith_multiplicand_0(i), - ); - let wire_multiplicand_1 = Target::wire( - gate_index, - U32ArithmeticGate::::wire_ith_multiplicand_1(i), - ); - let wire_addend = - Target::wire(gate_index, U32ArithmeticGate::::wire_ith_addend(i)); - - self.connect(zero, wire_multiplicand_0); - self.connect(zero, wire_multiplicand_1); - self.connect(zero, wire_addend); + let zero = self.zero_u32(); + if let Some((_gate_index, copy)) = self.batched_gates.current_u32_arithmetic_gate { + for _ in copy..U32ArithmeticGate::::num_ops(&self.config) { + let dummy = self.add_virtual_u32_target(); + self.mul_add_u32(dummy, dummy, dummy); + self.connect_u32(dummy, zero); } } } @@ -1137,7 +1126,7 @@ impl, const D: usize> CircuitBuilder { fn fill_u32_subtraction_gates(&mut self) { let zero = self.zero(); if let Some((gate_index, copy)) = self.batched_gates.current_u32_subtraction_gate { - for i in copy..NUM_U32_ARITHMETIC_OPS { + for i in copy..NUM_U32_SUBTRACTION_OPS { let wire_input_x = Target::wire(gate_index, U32SubtractionGate::::wire_ith_input_x(i)); let wire_input_y = From 29ed0673f2a76b912af970e80761626283851921 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 2 Dec 2021 15:35:59 +0100 Subject: [PATCH 182/202] Variable number of U32 sub ops --- src/gadgets/arithmetic_u32.rs | 32 ++------- src/gates/subtraction_u32.rs | 121 ++++++++++++++++++---------------- src/plonk/circuit_builder.rs | 27 +++----- 3 files changed, 79 insertions(+), 101 deletions(-) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index d15df304..3bf6ce58 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -136,38 +136,18 @@ impl, const D: usize> CircuitBuilder { y: U32Target, borrow: U32Target, ) -> (U32Target, U32Target) { + let gate = U32SubtractionGate::::new_from_config(&self.config); let (gate_index, copy) = self.find_u32_subtraction_gate(); + self.connect(Target::wire(gate_index, gate.wire_ith_input_x(copy)), x.0); + self.connect(Target::wire(gate_index, gate.wire_ith_input_y(copy)), y.0); self.connect( - Target::wire( - gate_index, - U32SubtractionGate::::wire_ith_input_x(copy), - ), - x.0, - ); - self.connect( - Target::wire( - gate_index, - U32SubtractionGate::::wire_ith_input_y(copy), - ), - y.0, - ); - self.connect( - Target::wire( - gate_index, - U32SubtractionGate::::wire_ith_input_borrow(copy), - ), + Target::wire(gate_index, gate.wire_ith_input_borrow(copy)), borrow.0, ); - let output_result = U32Target(Target::wire( - gate_index, - U32SubtractionGate::::wire_ith_output_result(copy), - )); - let output_borrow = U32Target(Target::wire( - gate_index, - U32SubtractionGate::::wire_ith_output_borrow(copy), - )); + let output_result = U32Target(Target::wire(gate_index, gate.wire_ith_output_result(copy))); + let output_borrow = U32Target(Target::wire(gate_index, gate.wire_ith_output_borrow(copy))); (output_result, output_borrow) } diff --git a/src/gates/subtraction_u32.rs b/src/gates/subtraction_u32.rs index 26f6302e..fc4cd646 100644 --- a/src/gates/subtraction_u32.rs +++ b/src/gates/subtraction_u32.rs @@ -9,44 +9,50 @@ use crate::iop::target::Target; use crate::iop::wire::Wire; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -/// Maximum number of subtractions operations performed by a single gate. -pub const NUM_U32_SUBTRACTION_OPS: usize = 3; - /// A gate to perform a subtraction on 32-bit limbs: given `x`, `y`, and `borrow`, it returns /// the result `x - y - borrow` and, if this underflows, a new `borrow`. Inputs are not range-checked. -#[derive(Clone, Debug)] +#[derive(Copy, Clone, Debug)] pub struct U32SubtractionGate, const D: usize> { + pub num_ops: usize, _phantom: PhantomData, } impl, const D: usize> U32SubtractionGate { - pub fn new() -> Self { + pub fn new_from_config(config: &CircuitConfig) -> Self { Self { + num_ops: Self::num_ops(config), _phantom: PhantomData, } } - pub fn wire_ith_input_x(i: usize) -> usize { - debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + pub(crate) fn num_ops(config: &CircuitConfig) -> usize { + let wires_per_op = 5 + Self::num_limbs(); + let routed_wires_per_op = 5; + (config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op) + } + + pub fn wire_ith_input_x(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i } - pub fn wire_ith_input_y(i: usize) -> usize { - debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + pub fn wire_ith_input_y(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 1 } - pub fn wire_ith_input_borrow(i: usize) -> usize { - debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + pub fn wire_ith_input_borrow(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 2 } - pub fn wire_ith_output_result(i: usize) -> usize { - debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + pub fn wire_ith_output_result(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 3 } - pub fn wire_ith_output_borrow(i: usize) -> usize { - debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + pub fn wire_ith_output_borrow(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 4 } @@ -58,10 +64,10 @@ impl, const D: usize> U32SubtractionGate { 32 / Self::limb_bits() } - pub fn wire_ith_output_jth_limb(i: usize, j: usize) -> usize { - debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize { + debug_assert!(i < self.num_ops); debug_assert!(j < Self::num_limbs()); - 5 * NUM_U32_SUBTRACTION_OPS + Self::num_limbs() * i + j + 5 * self.num_ops + Self::num_limbs() * i + j } } @@ -72,16 +78,16 @@ impl, const D: usize> Gate for U32Subtraction fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..NUM_U32_SUBTRACTION_OPS { - let input_x = vars.local_wires[Self::wire_ith_input_x(i)]; - let input_y = vars.local_wires[Self::wire_ith_input_y(i)]; - let input_borrow = vars.local_wires[Self::wire_ith_input_borrow(i)]; + for i in 0..self.num_ops { + let input_x = vars.local_wires[self.wire_ith_input_x(i)]; + let input_y = vars.local_wires[self.wire_ith_input_y(i)]; + let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)]; let result_initial = input_x - input_y - input_borrow; let base = F::Extension::from_canonical_u64(1 << 32u64); - let output_result = vars.local_wires[Self::wire_ith_output_result(i)]; - let output_borrow = vars.local_wires[Self::wire_ith_output_borrow(i)]; + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)]; constraints.push(output_result - (result_initial + base * output_borrow)); @@ -89,7 +95,7 @@ impl, const D: usize> Gate for U32Subtraction let mut combined_limbs = F::Extension::ZERO; let limb_base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits()); for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); let product = (0..max_limb) .map(|x| this_limb - F::Extension::from_canonical_usize(x)) @@ -109,16 +115,16 @@ impl, const D: usize> Gate for U32Subtraction fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..NUM_U32_SUBTRACTION_OPS { - let input_x = vars.local_wires[Self::wire_ith_input_x(i)]; - let input_y = vars.local_wires[Self::wire_ith_input_y(i)]; - let input_borrow = vars.local_wires[Self::wire_ith_input_borrow(i)]; + for i in 0..self.num_ops { + let input_x = vars.local_wires[self.wire_ith_input_x(i)]; + let input_y = vars.local_wires[self.wire_ith_input_y(i)]; + let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)]; let result_initial = input_x - input_y - input_borrow; let base = F::from_canonical_u64(1 << 32u64); - let output_result = vars.local_wires[Self::wire_ith_output_result(i)]; - let output_borrow = vars.local_wires[Self::wire_ith_output_borrow(i)]; + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)]; constraints.push(output_result - (result_initial + base * output_borrow)); @@ -126,7 +132,7 @@ impl, const D: usize> Gate for U32Subtraction let mut combined_limbs = F::ZERO; let limb_base = F::from_canonical_u64(1u64 << Self::limb_bits()); for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); let product = (0..max_limb) .map(|x| this_limb - F::from_canonical_usize(x)) @@ -150,17 +156,17 @@ impl, const D: usize> Gate for U32Subtraction vars: EvaluationTargets, ) -> Vec> { let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..NUM_U32_SUBTRACTION_OPS { - let input_x = vars.local_wires[Self::wire_ith_input_x(i)]; - let input_y = vars.local_wires[Self::wire_ith_input_y(i)]; - let input_borrow = vars.local_wires[Self::wire_ith_input_borrow(i)]; + for i in 0..self.num_ops { + let input_x = vars.local_wires[self.wire_ith_input_x(i)]; + let input_y = vars.local_wires[self.wire_ith_input_y(i)]; + let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)]; let diff = builder.sub_extension(input_x, input_y); let result_initial = builder.sub_extension(diff, input_borrow); let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << 32u64)); - let output_result = vars.local_wires[Self::wire_ith_output_result(i)]; - let output_borrow = vars.local_wires[Self::wire_ith_output_borrow(i)]; + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)]; let computed_output = builder.mul_add_extension(base, output_borrow, result_initial); constraints.push(builder.sub_extension(output_result, computed_output)); @@ -170,7 +176,7 @@ impl, const D: usize> Gate for U32Subtraction let limb_base = builder .constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); let mut product = builder.one_extension(); for x in 0..max_limb { @@ -199,10 +205,11 @@ impl, const D: usize> Gate for U32Subtraction gate_index: usize, _local_constants: &[F], ) -> Vec>> { - (0..NUM_U32_SUBTRACTION_OPS) + (0..self.num_ops) .map(|i| { let g: Box> = Box::new( U32SubtractionGenerator { + gate: *self, gate_index, i, _phantom: PhantomData, @@ -215,7 +222,7 @@ impl, const D: usize> Gate for U32Subtraction } fn num_wires(&self) -> usize { - NUM_U32_SUBTRACTION_OPS * (5 + Self::num_limbs()) + self.num_ops * (5 + Self::num_limbs()) } fn num_constants(&self) -> usize { @@ -227,12 +234,13 @@ impl, const D: usize> Gate for U32Subtraction } fn num_constraints(&self) -> usize { - NUM_U32_SUBTRACTION_OPS * (3 + Self::num_limbs()) + self.num_ops * (3 + Self::num_limbs()) } } #[derive(Clone, Debug)] struct U32SubtractionGenerator, const D: usize> { + gate: U32SubtractionGate, gate_index: usize, i: usize, _phantom: PhantomData, @@ -245,9 +253,9 @@ impl, const D: usize> SimpleGenerator let local_target = |input| Target::wire(self.gate_index, input); vec![ - local_target(U32SubtractionGate::::wire_ith_input_x(self.i)), - local_target(U32SubtractionGate::::wire_ith_input_y(self.i)), - local_target(U32SubtractionGate::::wire_ith_input_borrow(self.i)), + local_target(self.gate.wire_ith_input_x(self.i)), + local_target(self.gate.wire_ith_input_y(self.i)), + local_target(self.gate.wire_ith_input_borrow(self.i)), ] } @@ -259,10 +267,9 @@ impl, const D: usize> SimpleGenerator let get_local_wire = |input| witness.get_wire(local_wire(input)); - let input_x = get_local_wire(U32SubtractionGate::::wire_ith_input_x(self.i)); - let input_y = get_local_wire(U32SubtractionGate::::wire_ith_input_y(self.i)); - let input_borrow = - get_local_wire(U32SubtractionGate::::wire_ith_input_borrow(self.i)); + let input_x = get_local_wire(self.gate.wire_ith_input_x(self.i)); + let input_y = get_local_wire(self.gate.wire_ith_input_y(self.i)); + let input_borrow = get_local_wire(self.gate.wire_ith_input_borrow(self.i)); let result_initial = input_x - input_y - input_borrow; let result_initial_u64 = result_initial.to_canonical_u64(); @@ -275,10 +282,8 @@ impl, const D: usize> SimpleGenerator let base = F::from_canonical_u64(1 << 32u64); let output_result = result_initial + base * output_borrow; - let output_result_wire = - local_wire(U32SubtractionGate::::wire_ith_output_result(self.i)); - let output_borrow_wire = - local_wire(U32SubtractionGate::::wire_ith_output_borrow(self.i)); + let output_result_wire = local_wire(self.gate.wire_ith_output_result(self.i)); + let output_borrow_wire = local_wire(self.gate.wire_ith_output_borrow(self.i)); out_buffer.set_wire(output_result_wire, output_result); out_buffer.set_wire(output_borrow_wire, output_borrow); @@ -296,9 +301,7 @@ impl, const D: usize> SimpleGenerator .collect(); for j in 0..num_limbs { - let wire = local_wire(U32SubtractionGate::::wire_ith_output_jth_limb( - self.i, j, - )); + let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j)); out_buffer.set_wire(wire, output_limbs[j]); } } @@ -316,13 +319,14 @@ mod tests { use crate::field::goldilocks_field::GoldilocksField; use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; - use crate::gates::subtraction_u32::{U32SubtractionGate, NUM_U32_SUBTRACTION_OPS}; + use crate::gates::subtraction_u32::U32SubtractionGate; use crate::hash::hash_types::HashOut; use crate::plonk::vars::EvaluationVars; #[test] fn low_degree() { test_low_degree::(U32SubtractionGate:: { + num_ops: 3, _phantom: PhantomData, }) } @@ -330,6 +334,7 @@ mod tests { #[test] fn eval_fns() -> Result<()> { test_eval_fns::(U32SubtractionGate:: { + num_ops: 3, _phantom: PhantomData, }) } @@ -339,6 +344,7 @@ mod tests { type F = GoldilocksField; type FF = QuarticExtension; const D: usize = 4; + const NUM_U32_SUBTRACTION_OPS: usize = 3; fn get_wires(inputs_x: Vec, inputs_y: Vec, borrows: Vec) -> Vec { let mut v0 = Vec::new(); @@ -399,6 +405,7 @@ mod tests { .collect(); let gate = U32SubtractionGate:: { + num_ops: NUM_U32_SUBTRACTION_OPS, _phantom: PhantomData, }; diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 7799bc38..8b8ce1ff 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -24,7 +24,7 @@ use crate::gates::multiplication_extension::MulExtensionGate; use crate::gates::noop::NoopGate; use crate::gates::public_input::PublicInputGate; use crate::gates::random_access::RandomAccessGate; -use crate::gates::subtraction_u32::{U32SubtractionGate, NUM_U32_SUBTRACTION_OPS}; +use crate::gates::subtraction_u32::U32SubtractionGate; use crate::gates::switch::SwitchGate; use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget}; use crate::hash::hashing::hash_n_to_hash; @@ -984,14 +984,14 @@ impl, const D: usize> CircuitBuilder { pub(crate) fn find_u32_subtraction_gate(&mut self) -> (usize, usize) { let (gate_index, copy) = match self.batched_gates.current_u32_subtraction_gate { None => { - let gate = U32SubtractionGate::new(); + let gate = U32SubtractionGate::new_from_config(&self.config); let gate_index = self.add_gate(gate, vec![]); (gate_index, 0) } Some((gate_index, copy)) => (gate_index, copy), }; - if copy == NUM_U32_SUBTRACTION_OPS - 1 { + if copy == U32SubtractionGate::::num_ops(&self.config) - 1 { self.batched_gates.current_u32_subtraction_gate = None; } else { self.batched_gates.current_u32_subtraction_gate = Some((gate_index, copy + 1)); @@ -1124,21 +1124,12 @@ impl, const D: usize> CircuitBuilder { /// Fill the remaining unused U32 subtraction operations with zeros, so that all /// `U32SubtractionGenerator`s are run. fn fill_u32_subtraction_gates(&mut self) { - let zero = self.zero(); - if let Some((gate_index, copy)) = self.batched_gates.current_u32_subtraction_gate { - for i in copy..NUM_U32_SUBTRACTION_OPS { - let wire_input_x = - Target::wire(gate_index, U32SubtractionGate::::wire_ith_input_x(i)); - let wire_input_y = - Target::wire(gate_index, U32SubtractionGate::::wire_ith_input_y(i)); - let wire_input_borrow = Target::wire( - gate_index, - U32SubtractionGate::::wire_ith_input_borrow(i), - ); - - self.connect(zero, wire_input_x); - self.connect(zero, wire_input_y); - self.connect(zero, wire_input_borrow); + let zero = self.zero_u32(); + if let Some((_gate_index, copy)) = self.batched_gates.current_u32_subtraction_gate { + for _i in copy..U32SubtractionGate::::num_ops(&self.config) { + let dummy = self.add_virtual_u32_target(); + self.sub_u32(dummy, dummy, dummy); + self.connect_u32(dummy, zero); } } } From 817fe1e3a356fef87da102952f0a89aa75d7402c Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 2 Dec 2021 16:53:25 +0100 Subject: [PATCH 183/202] Remove obsolete todos --- src/field/cosets.rs | 2 -- src/gadgets/arithmetic.rs | 1 - src/plonk/prover.rs | 1 - 3 files changed, 4 deletions(-) diff --git a/src/field/cosets.rs b/src/field/cosets.rs index 4ad9ba38..62be67dc 100644 --- a/src/field/cosets.rs +++ b/src/field/cosets.rs @@ -31,8 +31,6 @@ mod tests { #[test] fn distinct_cosets() { - // TODO: Switch to a smaller test field so that collision rejection is likely to occur. - type F = GoldilocksField; const SUBGROUP_BITS: usize = 5; const NUM_SHIFTS: usize = 50; diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 0931cc88..63f149cd 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -184,7 +184,6 @@ impl, const D: usize> CircuitBuilder { } /// Add `n` `Target`s. - // TODO: Can be made `D` times more efficient by using all wires of an `ArithmeticExtensionGate`. pub fn add_many(&mut self, terms: &[Target]) -> Target { let terms_ext = terms .iter() diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 4f281f06..94c2fa29 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -146,7 +146,6 @@ pub(crate) fn prove, const D: usize>( .into_par_iter() .flat_map(|mut quotient_poly| { quotient_poly.trim(); - // TODO: Return Result instead of panicking. quotient_poly.pad(quotient_degree).expect( "Quotient has failed, the vanishing polynomial is not divisible by `Z_H", ); From c2ca106a29f4e778fac265872ce25b34bb7a03b4 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 2 Dec 2021 16:56:58 +0100 Subject: [PATCH 184/202] Rewrite `add_many` --- src/gadgets/arithmetic.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 63f149cd..3d54bdc3 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -185,11 +185,7 @@ impl, const D: usize> CircuitBuilder { /// Add `n` `Target`s. pub fn add_many(&mut self, terms: &[Target]) -> Target { - let terms_ext = terms - .iter() - .map(|&t| self.convert_to_ext(t)) - .collect::>(); - self.add_many_extension(&terms_ext).to_target_array()[0] + terms.iter().fold(self.zero(), |acc, &t| self.add(acc, t)) } /// Computes `x - y`. From 5eaa1ad529dc0ed877cc0d8360da2054485ce7c3 Mon Sep 17 00:00:00 2001 From: Jakub Nabaglo Date: Thu, 2 Dec 2021 16:14:47 -0800 Subject: [PATCH 185/202] Require a `PrimeField` to be its own `PrimeField` (#383) --- src/field/field_types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/field/field_types.rs b/src/field/field_types.rs index b7b9ddf4..a3affc13 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -404,7 +404,7 @@ pub trait Field: } /// A finite field of prime order less than 2^64. -pub trait PrimeField: Field { +pub trait PrimeField: Field { const ORDER: u64; /// The number of bits required to encode any field element. From aff71943c3d34dbf70109e7fea9b552af8fb02d1 Mon Sep 17 00:00:00 2001 From: Jakub Nabaglo Date: Thu, 2 Dec 2021 18:33:43 -0800 Subject: [PATCH 186/202] Minor optimizations to AVX2 multiplication (#378) * Minor optimizations to AVX2 multiplication * Typos (thx Hamish!) --- src/field/packed_avx2/common.rs | 16 +++- src/field/packed_avx2/goldilocks.rs | 10 ++- src/field/packed_avx2/packed_prime_field.rs | 95 +++++++++++---------- 3 files changed, 70 insertions(+), 51 deletions(-) diff --git a/src/field/packed_avx2/common.rs b/src/field/packed_avx2/common.rs index 97674a17..c100e6dc 100644 --- a/src/field/packed_avx2/common.rs +++ b/src/field/packed_avx2/common.rs @@ -3,7 +3,21 @@ use core::arch::x86_64::*; use crate::field::field_types::PrimeField; pub trait ReducibleAVX2: PrimeField { - unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i; + unsafe fn reduce128(x: (__m256i, __m256i)) -> __m256i; +} + +const SIGN_BIT: u64 = 1 << 63; + +#[inline] +unsafe fn sign_bit() -> __m256i { + _mm256_set1_epi64x(SIGN_BIT as i64) +} + +/// Add 2^63 with overflow. Needed to emulate unsigned comparisons (see point 3. in +/// packed_prime_field.rs). +#[inline] +pub unsafe fn shift(x: __m256i) -> __m256i { + _mm256_xor_si256(x, sign_bit()) } #[inline] diff --git a/src/field/packed_avx2/goldilocks.rs b/src/field/packed_avx2/goldilocks.rs index 2cea1767..186c8e0c 100644 --- a/src/field/packed_avx2/goldilocks.rs +++ b/src/field/packed_avx2/goldilocks.rs @@ -2,19 +2,21 @@ use core::arch::x86_64::*; use crate::field::goldilocks_field::GoldilocksField; use crate::field::packed_avx2::common::{ - add_no_canonicalize_64_64s_s, epsilon, sub_no_canonicalize_64s_64_s, ReducibleAVX2, + add_no_canonicalize_64_64s_s, epsilon, shift, sub_no_canonicalize_64s_64_s, ReducibleAVX2, }; /// Reduce a u128 modulo FIELD_ORDER. The input is (u64, u64), pre-shifted by 2^63. The result is /// similarly shifted. impl ReducibleAVX2 for GoldilocksField { #[inline] - unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i { - let (hi0, lo0_s) = x_s; + unsafe fn reduce128(x: (__m256i, __m256i)) -> __m256i { + let (hi0, lo0) = x; + let lo0_s = shift(lo0); let hi_hi0 = _mm256_srli_epi64(hi0, 32); let lo1_s = sub_no_canonicalize_64s_64_s::(lo0_s, hi_hi0); let t1 = _mm256_mul_epu32(hi0, epsilon::()); let lo2_s = add_no_canonicalize_64_64s_s::(t1, lo1_s); - lo2_s + let lo2 = shift(lo2_s); + lo2 } } diff --git a/src/field/packed_avx2/packed_prime_field.rs b/src/field/packed_avx2/packed_prime_field.rs index b892da4a..5800d0bd 100644 --- a/src/field/packed_avx2/packed_prime_field.rs +++ b/src/field/packed_avx2/packed_prime_field.rs @@ -6,7 +6,7 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use crate::field::field_types::PrimeField; use crate::field::packed_avx2::common::{ - add_no_canonicalize_64_64s_s, epsilon, field_order, ReducibleAVX2, + add_no_canonicalize_64_64s_s, epsilon, field_order, shift, ReducibleAVX2, }; use crate::field::packed_field::PackedField; @@ -211,13 +211,6 @@ impl Sum for PackedPrimeField { } } -const SIGN_BIT: u64 = 1 << 63; - -#[inline] -unsafe fn sign_bit() -> __m256i { - _mm256_set1_epi64x(SIGN_BIT as i64) -} - // Resources: // 1. Intel Intrinsics Guide for explanation of each intrinsic: // https://software.intel.com/sites/landingpage/IntrinsicsGuide/ @@ -267,12 +260,6 @@ unsafe fn sign_bit() -> __m256i { // Notice that the above 3-value addition still only requires two calls to shift, just like our // 2-value addition. -/// Add 2^63 with overflow. Needed to emulate unsigned comparisons (see point 3. above). -#[inline] -unsafe fn shift(x: __m256i) -> __m256i { - _mm256_xor_si256(x, sign_bit()) -} - /// Convert to canonical representation. /// The argument is assumed to be shifted by 1 << 63 (i.e. x_s = x + 1<<63, where x is the field /// value). The returned value is similarly shifted by 1 << 63 (i.e. we return y_s = y + (1<<63), @@ -311,67 +298,83 @@ unsafe fn neg(y: __m256i) -> __m256i { _mm256_sub_epi64(shift(field_order::()), canonicalize_s::(y_s)) } -/// Full 64-bit by 64-bit multiplication. This emulated multiplication is 1.5x slower than the +/// Full 64-bit by 64-bit multiplication. This emulated multiplication is 1.33x slower than the /// scalar instruction, but may be worth it if we want our data to live in vector registers. #[inline] -unsafe fn mul64_64_s(x: __m256i, y: __m256i) -> (__m256i, __m256i) { - let x_hi = _mm256_srli_epi64(x, 32); - let y_hi = _mm256_srli_epi64(y, 32); +unsafe fn mul64_64(x: __m256i, y: __m256i) -> (__m256i, __m256i) { + // We want to move the high 32 bits to the low position. The multiplication instruction ignores + // the high 32 bits, so it's ok to just duplicate it into the low position. This duplication can + // be done on port 5; bitshifts run on ports 0 and 1, competing with multiplication. + // This instruction is only provided for 32-bit floats, not integers. Idk why Intel makes the + // distinction; the casts are free and it guarantees that the exact bit pattern is preserved. + // Using a swizzle instruction of the wrong domain (float vs int) does not increase latency + // since Haswell. + let x_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x))); + let y_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(y))); + + // All four pairwise multiplications let mul_ll = _mm256_mul_epu32(x, y); let mul_lh = _mm256_mul_epu32(x, y_hi); let mul_hl = _mm256_mul_epu32(x_hi, y); let mul_hh = _mm256_mul_epu32(x_hi, y_hi); - let res_lo0_s = shift(mul_ll); - let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 32)); - let res_lo2_s = _mm256_add_epi32(res_lo1_s, _mm256_slli_epi64(mul_hl, 32)); + // Bignum addition + // Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow. + let mul_ll_hi = _mm256_srli_epi64::<32>(mul_ll); + let t0 = _mm256_add_epi64(mul_hl, mul_ll_hi); + // Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow. + // Also, extract high 32 bits of t0 and add to mul_hh. + let t0_lo = _mm256_and_si256(t0, _mm256_set1_epi64x(u32::MAX.into())); + let t0_hi = _mm256_srli_epi64::<32>(t0); + let t1 = _mm256_add_epi64(mul_lh, t0_lo); + let t2 = _mm256_add_epi64(mul_hh, t0_hi); + // Lastly, extract the high 32 bits of t1 and add to t2. + let t1_hi = _mm256_srli_epi64::<32>(t1); + let res_hi = _mm256_add_epi64(t2, t1_hi); - // cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on - // overflow and must be subtracted, not added. - let carry0 = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s); - let carry1 = _mm256_cmpgt_epi64(res_lo1_s, res_lo2_s); + // Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high + // position). + let t1_lo = _mm256_castps_si256(_mm256_moveldup_ps(_mm256_castsi256_ps(t1))); + let res_lo = _mm256_blend_epi32::<0xaa>(mul_ll, t1_lo); - let res_hi0 = mul_hh; - let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 32)); - let res_hi2 = _mm256_add_epi64(res_hi1, _mm256_srli_epi64(mul_hl, 32)); - let res_hi3 = _mm256_sub_epi64(res_hi2, carry0); - let res_hi4 = _mm256_sub_epi64(res_hi3, carry1); - - (res_hi4, res_lo2_s) + (res_hi, res_lo) } /// Full 64-bit squaring. This routine is 1.2x faster than the scalar instruction. #[inline] -unsafe fn square64_s(x: __m256i) -> (__m256i, __m256i) { - let x_hi = _mm256_srli_epi64(x, 32); +unsafe fn square64(x: __m256i) -> (__m256i, __m256i) { + // Get high 32 bits of x. See comment in mul64_64_s. + let x_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x))); + + // All pairwise multiplications. let mul_ll = _mm256_mul_epu32(x, x); let mul_lh = _mm256_mul_epu32(x, x_hi); let mul_hh = _mm256_mul_epu32(x_hi, x_hi); - let res_lo0_s = shift(mul_ll); - let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 33)); + // Bignum addition, but mul_lh is shifted by 33 bits (not 32). + let mul_ll_hi = _mm256_srli_epi64::<33>(mul_ll); + let t0 = _mm256_add_epi64(mul_lh, mul_ll_hi); + let t0_hi = _mm256_srli_epi64::<31>(t0); + let res_hi = _mm256_add_epi64(mul_hh, t0_hi); - // cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on - // overflow and must be subtracted, not added. - let carry = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s); + // Form low result by adding the mul_ll and the low 31 bits of mul_lh (shifted to the high + // position). + let mul_lh_lo = _mm256_slli_epi64::<33>(mul_lh); + let res_lo = _mm256_add_epi64(mul_ll, mul_lh_lo); - let res_hi0 = mul_hh; - let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 31)); - let res_hi2 = _mm256_sub_epi64(res_hi1, carry); - - (res_hi2, res_lo1_s) + (res_hi, res_lo) } /// Multiply two integers modulo FIELD_ORDER. #[inline] unsafe fn mul(x: __m256i, y: __m256i) -> __m256i { - shift(F::reduce128s_s(mul64_64_s(x, y))) + F::reduce128(mul64_64(x, y)) } /// Square an integer modulo FIELD_ORDER. #[inline] unsafe fn square(x: __m256i) -> __m256i { - shift(F::reduce128s_s(square64_s(x))) + F::reduce128(square64(x)) } #[inline] From 2a81ec1728964f362cde7f082accf7467048f4cd Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 3 Dec 2021 08:49:19 +0100 Subject: [PATCH 187/202] Fix recursive FRI config --- src/fri/recursive_verifier.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index 91d0580a..c2684725 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -260,7 +260,7 @@ impl, const D: usize> CircuitBuilder { common_data: &CommonCircuitData, ) -> ExtensionTarget { assert!(D > 1, "Not implemented for D=1."); - let config = self.config.clone(); + let config = &common_data.config; let degree_log = common_data.degree_bits; debug_assert_eq!( degree_log, From bb029db2a7100720ae320ff1962405de0004001f Mon Sep 17 00:00:00 2001 From: Jakub Nabaglo Date: Fri, 3 Dec 2021 13:12:19 -0800 Subject: [PATCH 188/202] Type tweaks for packed types (#387) * PackedField tweaks * AVX2 changes * FFT fixes * tests * test fixes * Lints * Rename things for clarity * Minor interleave fixes * Minor interleave fixes the sequel * Rebase fixes * Docs * Daniel PR comments --- src/field/fft.rs | 21 ++- src/field/packable.rs | 6 +- ...ked_prime_field.rs => avx2_prime_field.rs} | 166 +++++++++++------- src/field/packed_avx2/common.rs | 2 +- src/field/packed_avx2/goldilocks.rs | 4 +- src/field/packed_avx2/mod.rs | 125 +++++++------ src/field/packed_field.rs | 115 ++++++------ 7 files changed, 255 insertions(+), 184 deletions(-) rename src/field/packed_avx2/{packed_prime_field.rs => avx2_prime_field.rs} (75%) diff --git a/src/field/fft.rs b/src/field/fft.rs index 09672278..76e0fd42 100644 --- a/src/field/fft.rs +++ b/src/field/fft.rs @@ -98,12 +98,12 @@ pub fn ifft_with_options( /// Generic FFT implementation that works with both scalar and packed inputs. #[unroll_for_loops] fn fft_classic_simd( - values: &mut [P::FieldType], + values: &mut [P::Scalar], r: usize, lg_n: usize, - root_table: &FftRootTable, + root_table: &FftRootTable, ) { - let lg_packed_width = P::LOG2_WIDTH; // 0 when P is a scalar. + let lg_packed_width = log2_strict(P::WIDTH); // 0 when P is a scalar. let packed_values = P::pack_slice_mut(values); let packed_n = packed_values.len(); debug_assert!(packed_n == 1 << (lg_n - lg_packed_width)); @@ -121,19 +121,18 @@ fn fft_classic_simd( let half_m = 1 << lg_half_m; // Set omega to root_table[lg_half_m][0..half_m] but repeated. - let mut omega_vec = P::zero().to_vec(); - for (j, omega) in omega_vec.iter_mut().enumerate() { - *omega = root_table[lg_half_m][j % half_m]; + let mut omega = P::ZERO; + for (j, omega_j) in omega.as_slice_mut().iter_mut().enumerate() { + *omega_j = root_table[lg_half_m][j % half_m]; } - let omega = P::from_slice(&omega_vec[..]); for k in (0..packed_n).step_by(2) { // We have two vectors and want to do math on pairs of adjacent elements (or for // lg_half_m > 0, pairs of adjacent blocks of elements). .interleave does the // appropriate shuffling and is its own inverse. - let (u, v) = packed_values[k].interleave(packed_values[k + 1], lg_half_m); + let (u, v) = packed_values[k].interleave(packed_values[k + 1], half_m); let t = omega * v; - (packed_values[k], packed_values[k + 1]) = (u + t).interleave(u - t, lg_half_m); + (packed_values[k], packed_values[k + 1]) = (u + t).interleave(u - t, half_m); } } } @@ -197,13 +196,13 @@ pub(crate) fn fft_classic(input: &[F], r: usize, root_table: &FftRootT } } - let lg_packed_width = ::PackedType::LOG2_WIDTH; + let lg_packed_width = log2_strict(::Packing::WIDTH); if lg_n <= lg_packed_width { // Need the slice to be at least the width of two packed vectors for the vectorized version // to work. Do this tiny problem in scalar. fft_classic_simd::(&mut values[..], r, lg_n, root_table); } else { - fft_classic_simd::<::PackedType>(&mut values[..], r, lg_n, root_table); + fft_classic_simd::<::Packing>(&mut values[..], r, lg_n, root_table); } values } diff --git a/src/field/packable.rs b/src/field/packable.rs index e5fc2ac5..a3f96197 100644 --- a/src/field/packable.rs +++ b/src/field/packable.rs @@ -5,14 +5,14 @@ use crate::field::packed_field::PackedField; /// PackedField for a particular Field (e.g. every Field is also a PackedField), but this is the /// recommended one. The recommended packing varies by target_arch and target_feature. pub trait Packable: Field { - type PackedType: PackedField; + type Packing: PackedField; } impl Packable for F { - default type PackedType = Self; + default type Packing = Self; } #[cfg(target_feature = "avx2")] impl Packable for crate::field::goldilocks_field::GoldilocksField { - type PackedType = crate::field::packed_avx2::PackedGoldilocksAVX2; + type Packing = crate::field::packed_avx2::PackedGoldilocksAvx2; } diff --git a/src/field/packed_avx2/packed_prime_field.rs b/src/field/packed_avx2/avx2_prime_field.rs similarity index 75% rename from src/field/packed_avx2/packed_prime_field.rs rename to src/field/packed_avx2/avx2_prime_field.rs index 5800d0bd..b42814c2 100644 --- a/src/field/packed_avx2/packed_prime_field.rs +++ b/src/field/packed_avx2/avx2_prime_field.rs @@ -2,20 +2,20 @@ use core::arch::x86_64::*; use std::fmt; use std::fmt::{Debug, Formatter}; use std::iter::{Product, Sum}; -use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; +use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use crate::field::field_types::PrimeField; use crate::field::packed_avx2::common::{ - add_no_canonicalize_64_64s_s, epsilon, field_order, shift, ReducibleAVX2, + add_no_canonicalize_64_64s_s, epsilon, field_order, shift, ReducibleAvx2, }; use crate::field::packed_field::PackedField; -// PackedPrimeField wraps an array of four u64s, with the new and get methods to convert that +// Avx2PrimeField wraps an array of four u64s, with the new and get methods to convert that // array to and from __m256i, which is the type we actually operate on. This indirection is a -// terrible trick to change PackedPrimeField's alignment. -// We'd like to be able to cast slices of PrimeField to slices of PackedPrimeField. Rust +// terrible trick to change Avx2PrimeField's alignment. +// We'd like to be able to cast slices of PrimeField to slices of Avx2PrimeField. Rust // aligns __m256i to 32 bytes but PrimeField has a lower alignment. That alignment extends to -// PackedPrimeField and it appears that it cannot be lowered with #[repr(C, blah)]. It is +// Avx2PrimeField and it appears that it cannot be lowered with #[repr(C, blah)]. It is // important for Rust not to assume 32-byte alignment, so we cannot wrap __m256i directly. // There are two versions of vectorized load/store instructions on x86: aligned (vmovaps and // friends) and unaligned (vmovups etc.). The difference between them is that aligned loads and @@ -23,12 +23,12 @@ use crate::field::packed_field::PackedField; // were faster, and although this is no longer the case, compilers prefer the aligned versions if // they know that the address is aligned. Using aligned instructions on unaligned addresses leads to // bugs that can be frustrating to diagnose. Hence, we can't have Rust assuming alignment, and -// therefore PackedPrimeField wraps [F; 4] and not __m256i. +// therefore Avx2PrimeField wraps [F; 4] and not __m256i. #[derive(Copy, Clone)] #[repr(transparent)] -pub struct PackedPrimeField(pub [F; 4]); +pub struct Avx2PrimeField(pub [F; 4]); -impl PackedPrimeField { +impl Avx2PrimeField { #[inline] fn new(x: __m256i) -> Self { let mut obj = Self([F::ZERO; 4]); @@ -45,75 +45,109 @@ impl PackedPrimeField { } } -impl Add for PackedPrimeField { +impl Add for Avx2PrimeField { type Output = Self; #[inline] fn add(self, rhs: Self) -> Self { Self::new(unsafe { add::(self.get(), rhs.get()) }) } } -impl Add for PackedPrimeField { +impl Add for Avx2PrimeField { type Output = Self; #[inline] fn add(self, rhs: F) -> Self { - self + Self::broadcast(rhs) + self + Self::from(rhs) } } -impl AddAssign for PackedPrimeField { +impl Add> for as PackedField>::Scalar { + type Output = Avx2PrimeField; + #[inline] + fn add(self, rhs: Self::Output) -> Self::Output { + Self::Output::from(self) + rhs + } +} +impl AddAssign for Avx2PrimeField { #[inline] fn add_assign(&mut self, rhs: Self) { *self = *self + rhs; } } -impl AddAssign for PackedPrimeField { +impl AddAssign for Avx2PrimeField { #[inline] fn add_assign(&mut self, rhs: F) { *self = *self + rhs; } } -impl Debug for PackedPrimeField { +impl Debug for Avx2PrimeField { #[inline] fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "({:?})", self.get()) } } -impl Default for PackedPrimeField { +impl Default for Avx2PrimeField { #[inline] fn default() -> Self { - Self::zero() + Self::ZERO } } -impl Mul for PackedPrimeField { +impl Div for Avx2PrimeField { + type Output = Self; + #[inline] + fn div(self, rhs: F) -> Self { + self * rhs.inverse() + } +} +impl DivAssign for Avx2PrimeField { + #[inline] + fn div_assign(&mut self, rhs: F) { + *self *= rhs.inverse(); + } +} + +impl From for Avx2PrimeField { + fn from(x: F) -> Self { + Self([x; 4]) + } +} + +impl Mul for Avx2PrimeField { type Output = Self; #[inline] fn mul(self, rhs: Self) -> Self { Self::new(unsafe { mul::(self.get(), rhs.get()) }) } } -impl Mul for PackedPrimeField { +impl Mul for Avx2PrimeField { type Output = Self; #[inline] fn mul(self, rhs: F) -> Self { - self * Self::broadcast(rhs) + self * Self::from(rhs) } } -impl MulAssign for PackedPrimeField { +impl Mul> for as PackedField>::Scalar { + type Output = Avx2PrimeField; + #[inline] + fn mul(self, rhs: Avx2PrimeField) -> Self::Output { + Self::Output::from(self) * rhs + } +} +impl MulAssign for Avx2PrimeField { #[inline] fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; } } -impl MulAssign for PackedPrimeField { +impl MulAssign for Avx2PrimeField { #[inline] fn mul_assign(&mut self, rhs: F) { *self = *self * rhs; } } -impl Neg for PackedPrimeField { +impl Neg for Avx2PrimeField { type Output = Self; #[inline] fn neg(self) -> Self { @@ -121,52 +155,59 @@ impl Neg for PackedPrimeField { } } -impl Product for PackedPrimeField { +impl Product for Avx2PrimeField { #[inline] fn product>(iter: I) -> Self { - iter.reduce(|x, y| x * y).unwrap_or(Self::one()) + iter.reduce(|x, y| x * y).unwrap_or(Self::ONE) } } -impl PackedField for PackedPrimeField { - const LOG2_WIDTH: usize = 2; +unsafe impl PackedField for Avx2PrimeField { + const WIDTH: usize = 4; - type FieldType = F; + type Scalar = F; + type PackedPrimeField = Avx2PrimeField; + + const ZERO: Self = Self([F::ZERO; 4]); + const ONE: Self = Self([F::ONE; 4]); #[inline] - fn broadcast(x: F) -> Self { - Self([x; 4]) - } - - #[inline] - fn from_arr(arr: [F; Self::WIDTH]) -> Self { + fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self { Self(arr) } #[inline] - fn to_arr(&self) -> [F; Self::WIDTH] { + fn as_arr(&self) -> [Self::Scalar; Self::WIDTH] { self.0 } #[inline] - fn from_slice(slice: &[F]) -> Self { - assert!(slice.len() == 4); - Self([slice[0], slice[1], slice[2], slice[3]]) + fn from_slice(slice: &[Self::Scalar]) -> &Self { + assert_eq!(slice.len(), Self::WIDTH); + unsafe { &*slice.as_ptr().cast() } + } + #[inline] + fn from_slice_mut(slice: &mut [Self::Scalar]) -> &mut Self { + assert_eq!(slice.len(), Self::WIDTH); + unsafe { &mut *slice.as_mut_ptr().cast() } + } + #[inline] + fn as_slice(&self) -> &[Self::Scalar] { + &self.0[..] + } + #[inline] + fn as_slice_mut(&mut self) -> &mut [Self::Scalar] { + &mut self.0[..] } #[inline] - fn to_vec(&self) -> Vec { - self.0.into() - } - - #[inline] - fn interleave(&self, other: Self, r: usize) -> (Self, Self) { + fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) { let (v0, v1) = (self.get(), other.get()); - let (res0, res1) = match r { - 0 => unsafe { interleave0(v0, v1) }, + let (res0, res1) = match block_len { 1 => unsafe { interleave1(v0, v1) }, - 2 => (v0, v1), - _ => panic!("r cannot be more than LOG2_WIDTH"), + 2 => unsafe { interleave2(v0, v1) }, + 4 => (v0, v1), + _ => panic!("unsupported block_len"), }; (Self::new(res0), Self::new(res1)) } @@ -177,37 +218,44 @@ impl PackedField for PackedPrimeField { } } -impl Sub for PackedPrimeField { +impl Sub for Avx2PrimeField { type Output = Self; #[inline] fn sub(self, rhs: Self) -> Self { Self::new(unsafe { sub::(self.get(), rhs.get()) }) } } -impl Sub for PackedPrimeField { +impl Sub for Avx2PrimeField { type Output = Self; #[inline] fn sub(self, rhs: F) -> Self { - self - Self::broadcast(rhs) + self - Self::from(rhs) } } -impl SubAssign for PackedPrimeField { +impl Sub> for as PackedField>::Scalar { + type Output = Avx2PrimeField; + #[inline] + fn sub(self, rhs: Avx2PrimeField) -> Self::Output { + Self::Output::from(self) - rhs + } +} +impl SubAssign for Avx2PrimeField { #[inline] fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs; } } -impl SubAssign for PackedPrimeField { +impl SubAssign for Avx2PrimeField { #[inline] fn sub_assign(&mut self, rhs: F) { *self = *self - rhs; } } -impl Sum for PackedPrimeField { +impl Sum for Avx2PrimeField { #[inline] fn sum>(iter: I) -> Self { - iter.reduce(|x, y| x + y).unwrap_or(Self::zero()) + iter.reduce(|x, y| x + y).unwrap_or(Self::ZERO) } } @@ -367,25 +415,25 @@ unsafe fn square64(x: __m256i) -> (__m256i, __m256i) { /// Multiply two integers modulo FIELD_ORDER. #[inline] -unsafe fn mul(x: __m256i, y: __m256i) -> __m256i { +unsafe fn mul(x: __m256i, y: __m256i) -> __m256i { F::reduce128(mul64_64(x, y)) } /// Square an integer modulo FIELD_ORDER. #[inline] -unsafe fn square(x: __m256i) -> __m256i { +unsafe fn square(x: __m256i) -> __m256i { F::reduce128(square64(x)) } #[inline] -unsafe fn interleave0(x: __m256i, y: __m256i) -> (__m256i, __m256i) { +unsafe fn interleave1(x: __m256i, y: __m256i) -> (__m256i, __m256i) { let a = _mm256_unpacklo_epi64(x, y); let b = _mm256_unpackhi_epi64(x, y); (a, b) } #[inline] -unsafe fn interleave1(x: __m256i, y: __m256i) -> (__m256i, __m256i) { +unsafe fn interleave2(x: __m256i, y: __m256i) -> (__m256i, __m256i) { let y_lo = _mm256_castsi256_si128(y); // This has 0 cost. // 1 places y_lo in the high half of x; 0 would place it in the lower half. diff --git a/src/field/packed_avx2/common.rs b/src/field/packed_avx2/common.rs index c100e6dc..48f9524d 100644 --- a/src/field/packed_avx2/common.rs +++ b/src/field/packed_avx2/common.rs @@ -2,7 +2,7 @@ use core::arch::x86_64::*; use crate::field::field_types::PrimeField; -pub trait ReducibleAVX2: PrimeField { +pub trait ReducibleAvx2: PrimeField { unsafe fn reduce128(x: (__m256i, __m256i)) -> __m256i; } diff --git a/src/field/packed_avx2/goldilocks.rs b/src/field/packed_avx2/goldilocks.rs index 186c8e0c..954516b8 100644 --- a/src/field/packed_avx2/goldilocks.rs +++ b/src/field/packed_avx2/goldilocks.rs @@ -2,12 +2,12 @@ use core::arch::x86_64::*; use crate::field::goldilocks_field::GoldilocksField; use crate::field::packed_avx2::common::{ - add_no_canonicalize_64_64s_s, epsilon, shift, sub_no_canonicalize_64s_64_s, ReducibleAVX2, + add_no_canonicalize_64_64s_s, epsilon, shift, sub_no_canonicalize_64s_64_s, ReducibleAvx2, }; /// Reduce a u128 modulo FIELD_ORDER. The input is (u64, u64), pre-shifted by 2^63. The result is /// similarly shifted. -impl ReducibleAVX2 for GoldilocksField { +impl ReducibleAvx2 for GoldilocksField { #[inline] unsafe fn reduce128(x: (__m256i, __m256i)) -> __m256i { let (hi0, lo0) = x; diff --git a/src/field/packed_avx2/mod.rs b/src/field/packed_avx2/mod.rs index 20eecba7..5f6294a4 100644 --- a/src/field/packed_avx2/mod.rs +++ b/src/field/packed_avx2/mod.rs @@ -1,21 +1,21 @@ +mod avx2_prime_field; mod common; mod goldilocks; -mod packed_prime_field; -use packed_prime_field::PackedPrimeField; +use avx2_prime_field::Avx2PrimeField; use crate::field::goldilocks_field::GoldilocksField; -pub type PackedGoldilocksAVX2 = PackedPrimeField; +pub type PackedGoldilocksAvx2 = Avx2PrimeField; #[cfg(test)] mod tests { use crate::field::goldilocks_field::GoldilocksField; - use crate::field::packed_avx2::common::ReducibleAVX2; - use crate::field::packed_avx2::packed_prime_field::PackedPrimeField; + use crate::field::packed_avx2::avx2_prime_field::Avx2PrimeField; + use crate::field::packed_avx2::common::ReducibleAvx2; use crate::field::packed_field::PackedField; - fn test_vals_a() -> [F; 4] { + fn test_vals_a() -> [F; 4] { [ F::from_noncanonical_u64(14479013849828404771), F::from_noncanonical_u64(9087029921428221768), @@ -23,7 +23,7 @@ mod tests { F::from_noncanonical_u64(5646033492608483824), ] } - fn test_vals_b() -> [F; 4] { + fn test_vals_b() -> [F; 4] { [ F::from_noncanonical_u64(17891926589593242302), F::from_noncanonical_u64(11009798273260028228), @@ -32,17 +32,17 @@ mod tests { ] } - fn test_add() + fn test_add() where - [(); PackedPrimeField::::WIDTH]:, + [(); Avx2PrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); - let packed_a = PackedPrimeField::::from_arr(a_arr); - let packed_b = PackedPrimeField::::from_arr(b_arr); + let packed_a = Avx2PrimeField::::from_arr(a_arr); + let packed_b = Avx2PrimeField::::from_arr(b_arr); let packed_res = packed_a + packed_b; - let arr_res = packed_res.to_arr(); + let arr_res = packed_res.as_arr(); let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a + b); for (exp, res) in expected.zip(arr_res) { @@ -50,17 +50,17 @@ mod tests { } } - fn test_mul() + fn test_mul() where - [(); PackedPrimeField::::WIDTH]:, + [(); Avx2PrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); - let packed_a = PackedPrimeField::::from_arr(a_arr); - let packed_b = PackedPrimeField::::from_arr(b_arr); + let packed_a = Avx2PrimeField::::from_arr(a_arr); + let packed_b = Avx2PrimeField::::from_arr(b_arr); let packed_res = packed_a * packed_b; - let arr_res = packed_res.to_arr(); + let arr_res = packed_res.as_arr(); let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a * b); for (exp, res) in expected.zip(arr_res) { @@ -68,15 +68,15 @@ mod tests { } } - fn test_square() + fn test_square() where - [(); PackedPrimeField::::WIDTH]:, + [(); Avx2PrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); - let packed_a = PackedPrimeField::::from_arr(a_arr); + let packed_a = Avx2PrimeField::::from_arr(a_arr); let packed_res = packed_a.square(); - let arr_res = packed_res.to_arr(); + let arr_res = packed_res.as_arr(); let expected = a_arr.iter().map(|&a| a.square()); for (exp, res) in expected.zip(arr_res) { @@ -84,15 +84,15 @@ mod tests { } } - fn test_neg() + fn test_neg() where - [(); PackedPrimeField::::WIDTH]:, + [(); Avx2PrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); - let packed_a = PackedPrimeField::::from_arr(a_arr); + let packed_a = Avx2PrimeField::::from_arr(a_arr); let packed_res = -packed_a; - let arr_res = packed_res.to_arr(); + let arr_res = packed_res.as_arr(); let expected = a_arr.iter().map(|&a| -a); for (exp, res) in expected.zip(arr_res) { @@ -100,17 +100,17 @@ mod tests { } } - fn test_sub() + fn test_sub() where - [(); PackedPrimeField::::WIDTH]:, + [(); Avx2PrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); - let packed_a = PackedPrimeField::::from_arr(a_arr); - let packed_b = PackedPrimeField::::from_arr(b_arr); + let packed_a = Avx2PrimeField::::from_arr(a_arr); + let packed_b = Avx2PrimeField::::from_arr(b_arr); let packed_res = packed_a - packed_b; - let arr_res = packed_res.to_arr(); + let arr_res = packed_res.as_arr(); let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a - b); for (exp, res) in expected.zip(arr_res) { @@ -118,33 +118,39 @@ mod tests { } } - fn test_interleave_is_involution() + fn test_interleave_is_involution() where - [(); PackedPrimeField::::WIDTH]:, + [(); Avx2PrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); - let packed_a = PackedPrimeField::::from_arr(a_arr); - let packed_b = PackedPrimeField::::from_arr(b_arr); + let packed_a = Avx2PrimeField::::from_arr(a_arr); + let packed_b = Avx2PrimeField::::from_arr(b_arr); { // Interleave, then deinterleave. - let (x, y) = packed_a.interleave(packed_b, 0); - let (res_a, res_b) = x.interleave(y, 0); - assert_eq!(res_a.to_arr(), a_arr); - assert_eq!(res_b.to_arr(), b_arr); - } - { let (x, y) = packed_a.interleave(packed_b, 1); let (res_a, res_b) = x.interleave(y, 1); - assert_eq!(res_a.to_arr(), a_arr); - assert_eq!(res_b.to_arr(), b_arr); + assert_eq!(res_a.as_arr(), a_arr); + assert_eq!(res_b.as_arr(), b_arr); + } + { + let (x, y) = packed_a.interleave(packed_b, 2); + let (res_a, res_b) = x.interleave(y, 2); + assert_eq!(res_a.as_arr(), a_arr); + assert_eq!(res_b.as_arr(), b_arr); + } + { + let (x, y) = packed_a.interleave(packed_b, 4); + let (res_a, res_b) = x.interleave(y, 4); + assert_eq!(res_a.as_arr(), a_arr); + assert_eq!(res_b.as_arr(), b_arr); } } - fn test_interleave() + fn test_interleave() where - [(); PackedPrimeField::::WIDTH]:, + [(); Avx2PrimeField::::WIDTH]:, { let in_a: [F; 4] = [ F::from_noncanonical_u64(00), @@ -158,42 +164,47 @@ mod tests { F::from_noncanonical_u64(12), F::from_noncanonical_u64(13), ]; - let int0_a: [F; 4] = [ + let int1_a: [F; 4] = [ F::from_noncanonical_u64(00), F::from_noncanonical_u64(10), F::from_noncanonical_u64(02), F::from_noncanonical_u64(12), ]; - let int0_b: [F; 4] = [ + let int1_b: [F; 4] = [ F::from_noncanonical_u64(01), F::from_noncanonical_u64(11), F::from_noncanonical_u64(03), F::from_noncanonical_u64(13), ]; - let int1_a: [F; 4] = [ + let int2_a: [F; 4] = [ F::from_noncanonical_u64(00), F::from_noncanonical_u64(01), F::from_noncanonical_u64(10), F::from_noncanonical_u64(11), ]; - let int1_b: [F; 4] = [ + let int2_b: [F; 4] = [ F::from_noncanonical_u64(02), F::from_noncanonical_u64(03), F::from_noncanonical_u64(12), F::from_noncanonical_u64(13), ]; - let packed_a = PackedPrimeField::::from_arr(in_a); - let packed_b = PackedPrimeField::::from_arr(in_b); - { - let (x0, y0) = packed_a.interleave(packed_b, 0); - assert_eq!(x0.to_arr(), int0_a); - assert_eq!(y0.to_arr(), int0_b); - } + let packed_a = Avx2PrimeField::::from_arr(in_a); + let packed_b = Avx2PrimeField::::from_arr(in_b); { let (x1, y1) = packed_a.interleave(packed_b, 1); - assert_eq!(x1.to_arr(), int1_a); - assert_eq!(y1.to_arr(), int1_b); + assert_eq!(x1.as_arr(), int1_a); + assert_eq!(y1.as_arr(), int1_b); + } + { + let (x2, y2) = packed_a.interleave(packed_b, 2); + assert_eq!(x2.as_arr(), int2_a); + assert_eq!(y2.as_arr(), int2_b); + } + { + let (x4, y4) = packed_a.interleave(packed_b, 4); + assert_eq!(x4.as_arr(), in_a); + assert_eq!(y4.as_arr(), in_b); } } diff --git a/src/field/packed_field.rs b/src/field/packed_field.rs index 69733bca..f2b0c83e 100644 --- a/src/field/packed_field.rs +++ b/src/field/packed_field.rs @@ -1,76 +1,82 @@ use std::fmt::Debug; use std::iter::{Product, Sum}; -use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; +use std::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}; +use std::slice; use crate::field::field_types::Field; -pub trait PackedField: +/// # Safety +/// - WIDTH is assumed to be a power of 2. +/// - If P implements PackedField then P must be castable to/from [P::Scalar; P::WIDTH] without UB. +pub unsafe trait PackedField: 'static + Add - + Add + + Add + AddAssign - + AddAssign + + AddAssign + Copy + Debug + Default - // TODO: Implementing Div sounds like a pain so it's a worry for later. + + From + // TODO: Implement packed / packed division + + Div + Mul - + Mul + + Mul + MulAssign - + MulAssign + + MulAssign + Neg + Product + Send + Sub - + Sub + + Sub + SubAssign - + SubAssign + + SubAssign + Sum + Sync +where + Self::Scalar: Add, + Self::Scalar: Mul, + Self::Scalar: Sub, { - type FieldType: Field; + type Scalar: Field; + type PackedPrimeField: PackedField::PrimeField>; - const LOG2_WIDTH: usize; - const WIDTH: usize = 1 << Self::LOG2_WIDTH; + const WIDTH: usize; + const ZERO: Self; + const ONE: Self; fn square(&self) -> Self { *self * *self } - fn zero() -> Self { - Self::broadcast(Self::FieldType::ZERO) - } - fn one() -> Self { - Self::broadcast(Self::FieldType::ONE) - } + fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self; + fn as_arr(&self) -> [Self::Scalar; Self::WIDTH]; - fn broadcast(x: Self::FieldType) -> Self; + fn from_slice(slice: &[Self::Scalar]) -> &Self; + fn from_slice_mut(slice: &mut [Self::Scalar]) -> &mut Self; + fn as_slice(&self) -> &[Self::Scalar]; + fn as_slice_mut(&mut self) -> &mut [Self::Scalar]; - fn from_arr(arr: [Self::FieldType; Self::WIDTH]) -> Self; - fn to_arr(&self) -> [Self::FieldType; Self::WIDTH]; - - fn from_slice(slice: &[Self::FieldType]) -> Self; - fn to_vec(&self) -> Vec; - - /// Take interpret two vectors as chunks of (1 << r) elements. Unpack and interleave those + /// Take interpret two vectors as chunks of block_len elements. Unpack and interleave those /// chunks. This is best seen with an example. If we have: /// A = [x0, y0, x1, y1], /// B = [x2, y2, x3, y3], /// then - /// interleave(A, B, 0) = ([x0, x2, x1, x3], [y0, y2, y1, y3]). + /// interleave(A, B, 1) = ([x0, x2, x1, x3], [y0, y2, y1, y3]). /// Pairs that were adjacent in the input are at corresponding positions in the output. - /// r lets us set the size of chunks we're interleaving. If we set r = 1, then for + /// r lets us set the size of chunks we're interleaving. If we set block_len = 2, then for /// A = [x0, x1, y0, y1], /// B = [x2, x3, y2, y3], /// we obtain - /// interleave(A, B, r) = ([x0, x1, x2, x3], [y0, y1, y2, y3]). + /// interleave(A, B, block_len) = ([x0, x1, x2, x3], [y0, y1, y2, y3]). /// We can also think about this as stacking the vectors, dividing them into 2x2 matrices, and /// transposing those matrices. - /// When r = LOG2_WIDTH, this operation is a no-op. Values of r > LOG2_WIDTH are not - /// permitted. - fn interleave(&self, other: Self, r: usize) -> (Self, Self); + /// When block_len = WIDTH, this operation is a no-op. block_len must divide WIDTH. Since + /// WIDTH is specified to be a power of 2, block_len must also be a power of 2. It cannot be 0 + /// and it cannot be > WIDTH. + fn interleave(&self, other: Self, block_len: usize) -> (Self, Self); - fn pack_slice(buf: &[Self::FieldType]) -> &[Self] { + fn pack_slice(buf: &[Self::Scalar]) -> &[Self] { assert!( buf.len() % Self::WIDTH == 0, "Slice length (got {}) must be a multiple of packed field width ({}).", @@ -81,7 +87,7 @@ pub trait PackedField: let n = buf.len() / Self::WIDTH; unsafe { std::slice::from_raw_parts(buf_ptr, n) } } - fn pack_slice_mut(buf: &mut [Self::FieldType]) -> &mut [Self] { + fn pack_slice_mut(buf: &mut [Self::Scalar]) -> &mut [Self] { assert!( buf.len() % Self::WIDTH == 0, "Slice length (got {}) must be a multiple of packed field width ({}).", @@ -94,35 +100,42 @@ pub trait PackedField: } } -impl PackedField for F { - type FieldType = Self; +unsafe impl PackedField for F { + type Scalar = Self; + type PackedPrimeField = F::PrimeField; - const LOG2_WIDTH: usize = 0; + const WIDTH: usize = 1; + const ZERO: Self = ::ZERO; + const ONE: Self = ::ONE; - fn broadcast(x: Self::FieldType) -> Self { - x + fn square(&self) -> Self { + ::square(self) } - fn from_arr(arr: [Self::FieldType; Self::WIDTH]) -> Self { + fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self { arr[0] } - fn to_arr(&self) -> [Self::FieldType; Self::WIDTH] { + fn as_arr(&self) -> [Self::Scalar; Self::WIDTH] { [*self] } - fn from_slice(slice: &[Self::FieldType]) -> Self { - assert_eq!(slice.len(), 1); - slice[0] + fn from_slice(slice: &[Self::Scalar]) -> &Self { + &slice[0] } - fn to_vec(&self) -> Vec { - vec![*self] + fn from_slice_mut(slice: &mut [Self::Scalar]) -> &mut Self { + &mut slice[0] + } + fn as_slice(&self) -> &[Self::Scalar] { + slice::from_ref(self) + } + fn as_slice_mut(&mut self) -> &mut [Self::Scalar] { + slice::from_mut(self) } - fn interleave(&self, other: Self, r: usize) -> (Self, Self) { - if r == 0 { - (*self, other) - } else { - panic!("r > LOG2_WIDTH"); + fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) { + match block_len { + 1 => (*self, other), + _ => panic!("unsupported block length"), } } } From d6a0a2e77216f92489122a55748517a8df573a81 Mon Sep 17 00:00:00 2001 From: Jakub Nabaglo Date: Fri, 3 Dec 2021 13:23:43 -0800 Subject: [PATCH 189/202] Run CI on optimized build (#384) * Run CI on optimized build * Enable overflow checks --- .github/workflows/continuous-integration-workflow.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/continuous-integration-workflow.yml b/.github/workflows/continuous-integration-workflow.yml index 4cb70602..bf54ab3d 100644 --- a/.github/workflows/continuous-integration-workflow.yml +++ b/.github/workflows/continuous-integration-workflow.yml @@ -28,6 +28,8 @@ jobs: with: command: test args: --all + env: + RUSTFLAGS: -Copt-level=3 -Cdebug-assertions -Coverflow-checks=y lints: name: Formatting and Clippy From 58e1febde7deefe99a649e8bfbec5270c5de87d4 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Mon, 6 Dec 2021 00:04:01 -0800 Subject: [PATCH 190/202] Update size-optimized recursion test (#388) I think it should start with `standard_recursion_config`, since the goal of the test is to start with a regular speed-optimized recursive proof and shrink it. The final proof is a bit larger now, mainly because of the update to 100 bits, and partly (less importantly) because it starts with the now-standard arity 16. We could maybe switch from arity 16 to 8 somewhere in the chain, but I think that might require another proof layer, and didn't want to complicate it too much. --- src/plonk/circuit_data.rs | 13 ------------- src/plonk/recursive_verifier.rs | 14 +++++++------- 2 files changed, 7 insertions(+), 20 deletions(-) diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index bf8024df..f0451131 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -76,19 +76,6 @@ impl CircuitConfig { } } - pub fn size_optimized_recursion_config() -> Self { - Self { - security_bits: 93, - cap_height: 3, - fri_config: FriConfig { - reduction_strategy: FriReductionStrategy::ConstantArityBits(3, 5), - num_query_rounds: 26, - ..CircuitConfig::standard_recursion_config().fri_config - }, - ..CircuitConfig::standard_recursion_config() - } - } - pub fn standard_recursion_zk_config() -> Self { CircuitConfig { zero_knowledge: true, diff --git a/src/plonk/recursive_verifier.rs b/src/plonk/recursive_verifier.rs index 46ca5e6a..0a03fab0 100644 --- a/src/plonk/recursive_verifier.rs +++ b/src/plonk/recursive_verifier.rs @@ -408,7 +408,7 @@ mod tests { type F = GoldilocksField; const D: usize = 2; - let standard_config = CircuitConfig::size_optimized_recursion_config(); + let standard_config = CircuitConfig::standard_recursion_config(); // An initial dummy proof. let (proof, vd, cd) = dummy_proof::(&standard_config, 4_000)?; @@ -432,7 +432,7 @@ mod tests { rate_bits: 7, fri_config: FriConfig { proof_of_work_bits: 16, - num_query_rounds: 11, + num_query_rounds: 12, ..standard_config.fri_config.clone() }, ..standard_config @@ -453,11 +453,11 @@ mod tests { let final_config = CircuitConfig { cap_height: 0, rate_bits: 8, - num_routed_wires: 25, + num_routed_wires: 37, fri_config: FriConfig { - proof_of_work_bits: 21, - reduction_strategy: FriReductionStrategy::MinSize(Some(3)), - num_query_rounds: 9, + proof_of_work_bits: 20, + reduction_strategy: FriReductionStrategy::MinSize(None), + num_query_rounds: 10, }, ..high_rate_config }; @@ -471,7 +471,7 @@ mod tests { true, true, )?; - assert_eq!(cd.degree_bits, 12); + assert_eq!(cd.degree_bits, 12, "final proof too large"); test_serialization(&proof, &cd)?; From dad35ae6215e6118486a0e92152586b15fab50fd Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 6 Dec 2021 16:00:22 +0100 Subject: [PATCH 191/202] Fix tests --- src/gadgets/arithmetic_extension.rs | 29 +++++++++++++++++------------ src/gadgets/biguint.rs | 27 +++++++++++++++++---------- src/gadgets/split_base.rs | 4 ++-- src/iop/witness.rs | 11 +++++++++++ 4 files changed, 47 insertions(+), 24 deletions(-) diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index d81943ab..3262ab7d 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -558,6 +558,7 @@ mod tests { use crate::field::extension_field::algebra::ExtensionAlgebra; use crate::field::extension_field::quartic::QuarticExtension; + use crate::field::extension_field::target::ExtensionAlgebraTarget; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; use crate::iop::witness::{PartialWitness, Witness}; @@ -618,9 +619,7 @@ mod tests { let yt = builder.constant_extension(y); let zt = builder.constant_extension(z); let comp_zt = builder.div_extension(xt, yt); - let comp_zt_unsafe = builder.div_extension(xt, yt); builder.connect_extension(zt, comp_zt); - builder.connect_extension(zt, comp_zt_unsafe); let data = builder.build(); let proof = data.prove(pw)?; @@ -636,23 +635,29 @@ mod tests { let config = CircuitConfig::standard_recursion_config(); - let pw = PartialWitness::new(); + let mut pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let x = FF::rand_vec(4); - let y = FF::rand_vec(4); - let xa = ExtensionAlgebra(x.try_into().unwrap()); - let ya = ExtensionAlgebra(y.try_into().unwrap()); - let za = xa * ya; - - let xt = builder.constant_ext_algebra(xa); - let yt = builder.constant_ext_algebra(ya); - let zt = builder.constant_ext_algebra(za); + let xt = + ExtensionAlgebraTarget(builder.add_virtual_extension_targets(D).try_into().unwrap()); + let yt = + ExtensionAlgebraTarget(builder.add_virtual_extension_targets(D).try_into().unwrap()); + let zt = + ExtensionAlgebraTarget(builder.add_virtual_extension_targets(D).try_into().unwrap()); let comp_zt = builder.mul_ext_algebra(xt, yt); for i in 0..D { builder.connect_extension(zt.0[i], comp_zt.0[i]); } + let x = ExtensionAlgebra::(FF::rand_vec(D).try_into().unwrap()); + let y = ExtensionAlgebra::(FF::rand_vec(D).try_into().unwrap()); + let z = x * y; + for i in 0..D { + pw.set_extension_target(xt.0[i], x.0[i]); + pw.set_extension_target(yt.0[i], y.0[i]); + pw.set_extension_target(zt.0[i], z.0[i]); + } + let data = builder.build(); let proof = data.prove(pw)?; diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 9e14cdb7..e037c402 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -247,6 +247,7 @@ mod tests { use num::{BigUint, FromPrimitive, Integer}; use rand::Rng; + use crate::iop::witness::Witness; use crate::{ field::goldilocks_field::GoldilocksField, iop::witness::PartialWitness, @@ -263,16 +264,19 @@ mod tests { type F = GoldilocksField; let config = CircuitConfig::standard_recursion_config(); - let pw = PartialWitness::new(); + let mut pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let x = builder.constant_biguint(&x_value); - let y = builder.constant_biguint(&y_value); + let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len()); + let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len()); let z = builder.add_biguint(&x, &y); - let expected_z = builder.constant_biguint(&expected_z_value); - + let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); builder.connect_biguint(&z, &expected_z); + pw.set_biguint_target(&x, &x_value); + pw.set_biguint_target(&y, &y_value); + pw.set_biguint_target(&expected_z, &expected_z_value); + let data = builder.build(); let proof = data.prove(pw).unwrap(); verify(proof, &data.verifier_only, &data.common) @@ -316,16 +320,19 @@ mod tests { type F = GoldilocksField; let config = CircuitConfig::standard_recursion_config(); - let pw = PartialWitness::new(); + let mut pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let x = builder.constant_biguint(&x_value); - let y = builder.constant_biguint(&y_value); + let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len()); + let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len()); let z = builder.mul_biguint(&x, &y); - let expected_z = builder.constant_biguint(&expected_z_value); - + let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); builder.connect_biguint(&z, &expected_z); + pw.set_biguint_target(&x, &x_value); + pw.set_biguint_target(&y, &y_value); + pw.set_biguint_target(&expected_z, &expected_z_value); + let data = builder.build(); let proof = data.prove(pw).unwrap(); verify(proof, &data.verifier_only, &data.common) diff --git a/src/gadgets/split_base.rs b/src/gadgets/split_base.rs index ade2ab0c..59879e4d 100644 --- a/src/gadgets/split_base.rs +++ b/src/gadgets/split_base.rs @@ -146,14 +146,14 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let n = thread_rng().gen_range(0..(1 << 10)); + let n = thread_rng().gen_range(0..(1 << 30)); let x = builder.constant(F::from_canonical_usize(n)); let zero = builder._false(); let one = builder._true(); let y = builder.le_sum( - (0..10) + (0..30) .scan(n, |acc, _| { let tmp = *acc % 2; *acc /= 2; diff --git a/src/iop/witness.rs b/src/iop/witness.rs index 8b6df90a..6ace4411 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -5,6 +5,7 @@ use num::{BigUint, FromPrimitive, Zero}; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::Field; +use crate::gadgets::arithmetic_u32::U32Target; use crate::gadgets::biguint::BigUintTarget; use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::HashOutTarget; @@ -136,6 +137,16 @@ pub trait Witness { self.set_target(target.target, F::from_bool(value)) } + fn set_u32_target(&mut self, target: U32Target, value: u32) { + self.set_target(target.0, F::from_canonical_u32(value)) + } + + fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) { + for (<, &l) in target.limbs.iter().zip(&value.to_u32_digits()) { + self.set_u32_target(lt, l); + } + } + fn set_wire(&mut self, wire: Wire, value: F) { self.set_target(Target::Wire(wire), value) } From 5061b2d11024a4c1775825e2b1b2aa0eb3fa4092 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 7 Dec 2021 08:13:39 +0100 Subject: [PATCH 192/202] Use rand_arr instead of rand_vec --- src/gadgets/arithmetic_extension.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 3262ab7d..7891f305 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -649,8 +649,8 @@ mod tests { builder.connect_extension(zt.0[i], comp_zt.0[i]); } - let x = ExtensionAlgebra::(FF::rand_vec(D).try_into().unwrap()); - let y = ExtensionAlgebra::(FF::rand_vec(D).try_into().unwrap()); + let x = ExtensionAlgebra::(FF::rand_arr()); + let y = ExtensionAlgebra::(FF::rand_arr()); let z = x * y; for i in 0..D { pw.set_extension_target(xt.0[i], x.0[i]); From 6a50c0fc4ef258669cd664b0eb7d86e031d31399 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 7 Dec 2021 08:56:27 +0100 Subject: [PATCH 193/202] Clippy --- src/gadgets/arithmetic.rs | 4 +--- src/gadgets/arithmetic_extension.rs | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 3d54bdc3..6a516399 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -1,7 +1,5 @@ use std::borrow::Borrow; -use itertools::Itertools; - use crate::field::extension_field::Extendable; use crate::field::field_types::{PrimeField, RichField}; use crate::gates::arithmetic_base::ArithmeticGate; @@ -206,7 +204,7 @@ impl, const D: usize> CircuitBuilder { terms .iter() .copied() - .fold1(|acc, t| self.mul(acc, t)) + .reduce(|acc, t| self.mul(acc, t)) .unwrap_or_else(|| self.one()) } diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 7891f305..7c73a09b 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -1,5 +1,3 @@ -use itertools::Itertools; - use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::field::extension_field::FieldExtension; use crate::field::extension_field::{Extendable, OEF}; @@ -301,7 +299,7 @@ impl, const D: usize> CircuitBuilder { terms .iter() .copied() - .fold1(|acc, t| self.mul_extension(acc, t)) + .reduce(|acc, t| self.mul_extension(acc, t)) .unwrap_or_else(|| self.one_extension()) } From e6c3f354313b57e1b85e85d36635a49da235bc40 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 13 Dec 2021 14:35:05 +0100 Subject: [PATCH 194/202] working --- src/field/extension_field/quadratic.rs | 3 +-- src/field/extension_field/quartic.rs | 3 +-- src/field/field_types.rs | 26 ++++++++++++++++++-------- src/field/goldilocks_field.rs | 3 +-- src/field/packed_field.rs | 2 -- src/field/secp256k1_base.rs | 4 +--- src/field/secp256k1_scalar.rs | 4 +--- 7 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/field/extension_field/quadratic.rs b/src/field/extension_field/quadratic.rs index b724095a..dfb861c2 100644 --- a/src/field/extension_field/quadratic.rs +++ b/src/field/extension_field/quadratic.rs @@ -50,8 +50,6 @@ impl> From for QuadraticExtension { } impl> Field for QuadraticExtension { - type PrimeField = F; - const ZERO: Self = Self([F::ZERO; 2]); const ONE: Self = Self([F::ONE, F::ZERO]); const TWO: Self = Self([F::TWO, F::ZERO]); @@ -63,6 +61,7 @@ impl> Field for QuadraticExtension { // long as `F::TWO_ADICITY >= 2`, `p` can be written as `4n + 1`, so `p + 1` can be written as // `2(2n + 1)`, which has a 2-adicity of 1. const TWO_ADICITY: usize = F::TWO_ADICITY + 1; + const CHARACTERISTIC_TWO_ADICITY: usize = F::CHARACTERISTIC_TWO_ADICITY; const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(F::EXT_MULTIPLICATIVE_GROUP_GENERATOR); const POWER_OF_TWO_GENERATOR: Self = Self(F::EXT_POWER_OF_TWO_GENERATOR); diff --git a/src/field/extension_field/quartic.rs b/src/field/extension_field/quartic.rs index 0d221401..1a34d40a 100644 --- a/src/field/extension_field/quartic.rs +++ b/src/field/extension_field/quartic.rs @@ -51,8 +51,6 @@ impl> From for QuarticExtension { } impl> Field for QuarticExtension { - type PrimeField = F; - const ZERO: Self = Self([F::ZERO; 4]); const ONE: Self = Self([F::ONE, F::ZERO, F::ZERO, F::ZERO]); const TWO: Self = Self([F::TWO, F::ZERO, F::ZERO, F::ZERO]); @@ -65,6 +63,7 @@ impl> Field for QuarticExtension { // `2(2n + 1)`, which has a 2-adicity of 1. A similar argument can show that `p^2 + 1` also has // a 2-adicity of 1. const TWO_ADICITY: usize = F::TWO_ADICITY + 2; + const CHARACTERISTIC_TWO_ADICITY: usize = F::CHARACTERISTIC_TWO_ADICITY; const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(F::EXT_MULTIPLICATIVE_GROUP_GENERATOR); const POWER_OF_TWO_GENERATOR: Self = Self(F::EXT_POWER_OF_TWO_GENERATOR); diff --git a/src/field/field_types.rs b/src/field/field_types.rs index a3affc13..dec22c9d 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -42,8 +42,6 @@ pub trait Field: + Serialize + DeserializeOwned { - type PrimeField: PrimeField; - const ZERO: Self; const ONE: Self; const TWO: Self; @@ -54,6 +52,9 @@ pub trait Field: /// The 2-adicity of this field's multiplicative group. const TWO_ADICITY: usize; + /// The 2-adicity of this field's multiplicative group. + const CHARACTERISTIC_TWO_ADICITY: usize; + /// Generator of the entire multiplicative group, i.e. all non-zero elements. const MULTIPLICATIVE_GROUP_GENERATOR: Self; /// Generator of a multiplicative subgroup of order `2^TWO_ADICITY`. @@ -212,17 +213,17 @@ pub trait Field: // TWO_ADICITY. Can remove the branch and simplify if that // saving isn't worth it. - if exp > Self::PrimeField::TWO_ADICITY { + if exp > Self::CHARACTERISTIC_TWO_ADICITY { // NB: This should be a compile-time constant let inverse_2_pow_adicity: Self = - Self::from_canonical_u64(p - ((p - 1) >> Self::PrimeField::TWO_ADICITY)); + Self::from_canonical_u64(p - ((p - 1) >> Self::CHARACTERISTIC_TWO_ADICITY)); let mut res = inverse_2_pow_adicity; - let mut e = exp - Self::PrimeField::TWO_ADICITY; + let mut e = exp - Self::CHARACTERISTIC_TWO_ADICITY; - while e > Self::PrimeField::TWO_ADICITY { + while e > Self::CHARACTERISTIC_TWO_ADICITY { res *= inverse_2_pow_adicity; - e -= Self::PrimeField::TWO_ADICITY; + e -= Self::CHARACTERISTIC_TWO_ADICITY; } res * Self::from_canonical_u64(p - ((p - 1) >> e)) } else { @@ -404,7 +405,7 @@ pub trait Field: } /// A finite field of prime order less than 2^64. -pub trait PrimeField: Field { +pub trait PrimeField: Field { const ORDER: u64; /// The number of bits required to encode any field element. @@ -449,6 +450,15 @@ pub trait PrimeField: Field { } } +pub trait SmallCharacteristicField: Field { + const SMALLCHAR: u64; + + #[inline] + fn inverse_2exp(exp: usize) -> Self { + todo!() + } +} + /// An iterator over the powers of a certain base element `b`: `b^0, b^1, b^2, ...`. #[derive(Clone)] pub struct Powers { diff --git a/src/field/goldilocks_field.rs b/src/field/goldilocks_field.rs index 058b6db8..7bdd3c77 100644 --- a/src/field/goldilocks_field.rs +++ b/src/field/goldilocks_field.rs @@ -62,8 +62,6 @@ impl Debug for GoldilocksField { } impl Field for GoldilocksField { - type PrimeField = Self; - const ZERO: Self = Self(0); const ONE: Self = Self(1); const TWO: Self = Self(2); @@ -71,6 +69,7 @@ impl Field for GoldilocksField { const CHARACTERISTIC: u64 = Self::ORDER; const TWO_ADICITY: usize = 32; + const CHARACTERISTIC_TWO_ADICITY: usize = Self::TWO_ADICITY; // Sage: `g = GF(p).multiplicative_generator()` const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(7); diff --git a/src/field/packed_field.rs b/src/field/packed_field.rs index f2b0c83e..00b99d6c 100644 --- a/src/field/packed_field.rs +++ b/src/field/packed_field.rs @@ -39,7 +39,6 @@ where Self::Scalar: Sub, { type Scalar: Field; - type PackedPrimeField: PackedField::PrimeField>; const WIDTH: usize; const ZERO: Self; @@ -102,7 +101,6 @@ where unsafe impl PackedField for F { type Scalar = Self; - type PackedPrimeField = F::PrimeField; const WIDTH: usize = 1; const ZERO: Self = ::ZERO; diff --git a/src/field/secp256k1_base.rs b/src/field/secp256k1_base.rs index b3fb0148..32615187 100644 --- a/src/field/secp256k1_base.rs +++ b/src/field/secp256k1_base.rs @@ -68,9 +68,6 @@ impl Debug for Secp256K1Base { } impl Field for Secp256K1Base { - // TODO: fix - type PrimeField = GoldilocksField; - const ZERO: Self = Self([0; 4]); const ONE: Self = Self([1, 0, 0, 0]); const TWO: Self = Self([2, 0, 0, 0]); @@ -84,6 +81,7 @@ impl Field for Secp256K1Base { // TODO: fix const CHARACTERISTIC: u64 = 0; const TWO_ADICITY: usize = 1; + const CHARACTERISTIC_TWO_ADICITY: usize = Self::TWO_ADICITY; // Sage: `g = GF(p).multiplicative_generator()` const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([5, 0, 0, 0]); diff --git a/src/field/secp256k1_scalar.rs b/src/field/secp256k1_scalar.rs index f4f2e6ab..44907b7a 100644 --- a/src/field/secp256k1_scalar.rs +++ b/src/field/secp256k1_scalar.rs @@ -71,9 +71,6 @@ impl Debug for Secp256K1Scalar { } impl Field for Secp256K1Scalar { - // TODO: fix - type PrimeField = GoldilocksField; - const ZERO: Self = Self([0; 4]); const ONE: Self = Self([1, 0, 0, 0]); const TWO: Self = Self([2, 0, 0, 0]); @@ -88,6 +85,7 @@ impl Field for Secp256K1Scalar { const CHARACTERISTIC: u64 = 0; const TWO_ADICITY: usize = 6; + const CHARACTERISTIC_TWO_ADICITY: usize = 6; // Sage: `g = GF(p).multiplicative_generator()` const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([7, 0, 0, 0]); From fb168b5d93fffd9514ca59f473a84d1537e829dc Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 13 Dec 2021 16:20:39 +0100 Subject: [PATCH 195/202] Replace characteristic with option --- src/field/extension_field/quadratic.rs | 5 +-- src/field/extension_field/quartic.rs | 5 +-- src/field/field_types.rs | 55 +++++++++++--------------- src/field/goldilocks_field.rs | 4 +- src/field/prime_field_testing.rs | 2 +- src/field/secp256k1_base.rs | 5 +-- src/field/secp256k1_scalar.rs | 6 +-- 7 files changed, 33 insertions(+), 49 deletions(-) diff --git a/src/field/extension_field/quadratic.rs b/src/field/extension_field/quadratic.rs index dfb861c2..16743f12 100644 --- a/src/field/extension_field/quadratic.rs +++ b/src/field/extension_field/quadratic.rs @@ -55,13 +55,12 @@ impl> Field for QuadraticExtension { const TWO: Self = Self([F::TWO, F::ZERO]); const NEG_ONE: Self = Self([F::NEG_ONE, F::ZERO]); - const CHARACTERISTIC: u64 = F::CHARACTERISTIC; - // `p^2 - 1 = (p - 1)(p + 1)`. The `p - 1` term has a two-adicity of `F::TWO_ADICITY`. As // long as `F::TWO_ADICITY >= 2`, `p` can be written as `4n + 1`, so `p + 1` can be written as // `2(2n + 1)`, which has a 2-adicity of 1. const TWO_ADICITY: usize = F::TWO_ADICITY + 1; - const CHARACTERISTIC_TWO_ADICITY: usize = F::CHARACTERISTIC_TWO_ADICITY; + const CHARACTERISTIC_WITH_TWO_ADICITY: Option<(u64, usize)> = + F::CHARACTERISTIC_WITH_TWO_ADICITY; const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(F::EXT_MULTIPLICATIVE_GROUP_GENERATOR); const POWER_OF_TWO_GENERATOR: Self = Self(F::EXT_POWER_OF_TWO_GENERATOR); diff --git a/src/field/extension_field/quartic.rs b/src/field/extension_field/quartic.rs index 1a34d40a..77329c94 100644 --- a/src/field/extension_field/quartic.rs +++ b/src/field/extension_field/quartic.rs @@ -56,14 +56,13 @@ impl> Field for QuarticExtension { const TWO: Self = Self([F::TWO, F::ZERO, F::ZERO, F::ZERO]); const NEG_ONE: Self = Self([F::NEG_ONE, F::ZERO, F::ZERO, F::ZERO]); - const CHARACTERISTIC: u64 = F::ORDER; - // `p^4 - 1 = (p - 1)(p + 1)(p^2 + 1)`. The `p - 1` term has a two-adicity of `F::TWO_ADICITY`. // As long as `F::TWO_ADICITY >= 2`, `p` can be written as `4n + 1`, so `p + 1` can be written as // `2(2n + 1)`, which has a 2-adicity of 1. A similar argument can show that `p^2 + 1` also has // a 2-adicity of 1. const TWO_ADICITY: usize = F::TWO_ADICITY + 2; - const CHARACTERISTIC_TWO_ADICITY: usize = F::CHARACTERISTIC_TWO_ADICITY; + const CHARACTERISTIC_WITH_TWO_ADICITY: Option<(u64, usize)> = + F::CHARACTERISTIC_WITH_TWO_ADICITY; const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(F::EXT_MULTIPLICATIVE_GROUP_GENERATOR); const POWER_OF_TWO_GENERATOR: Self = Self(F::EXT_POWER_OF_TWO_GENERATOR); diff --git a/src/field/field_types.rs b/src/field/field_types.rs index dec22c9d..14426212 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -47,13 +47,12 @@ pub trait Field: const TWO: Self; const NEG_ONE: Self; - const CHARACTERISTIC: u64; - /// The 2-adicity of this field's multiplicative group. const TWO_ADICITY: usize; - /// The 2-adicity of this field's multiplicative group. - const CHARACTERISTIC_TWO_ADICITY: usize; + /// The field's characteristic and it's 2-adicity. + /// Set to `None` when the characteristic doesn't fit in a u64. + const CHARACTERISTIC_WITH_TWO_ADICITY: Option<(u64, usize)>; /// Generator of the entire multiplicative group, i.e. all non-zero elements. const MULTIPLICATIVE_GROUP_GENERATOR: Self; @@ -205,29 +204,32 @@ pub trait Field: // exp exceeds t, we repeatedly multiply by 2^-t and reduce // exp until it's in the right range. - let p = Self::CHARACTERISTIC; + if let Some((p, two_adicity)) = Self::CHARACTERISTIC_WITH_TWO_ADICITY { + // NB: The only reason this is split into two cases is to save + // the multiplication (and possible calculation of + // inverse_2_pow_adicity) in the usual case that exp <= + // TWO_ADICITY. Can remove the branch and simplify if that + // saving isn't worth it. - // NB: The only reason this is split into two cases is to save - // the multiplication (and possible calculation of - // inverse_2_pow_adicity) in the usual case that exp <= - // TWO_ADICITY. Can remove the branch and simplify if that - // saving isn't worth it. + if exp > two_adicity { + // NB: This should be a compile-time constant + let inverse_2_pow_adicity: Self = + Self::from_canonical_u64(p - ((p - 1) >> two_adicity)); - if exp > Self::CHARACTERISTIC_TWO_ADICITY { - // NB: This should be a compile-time constant - let inverse_2_pow_adicity: Self = - Self::from_canonical_u64(p - ((p - 1) >> Self::CHARACTERISTIC_TWO_ADICITY)); + let mut res = inverse_2_pow_adicity; + let mut e = exp - two_adicity; - let mut res = inverse_2_pow_adicity; - let mut e = exp - Self::CHARACTERISTIC_TWO_ADICITY; - - while e > Self::CHARACTERISTIC_TWO_ADICITY { - res *= inverse_2_pow_adicity; - e -= Self::CHARACTERISTIC_TWO_ADICITY; + while e > two_adicity { + res *= inverse_2_pow_adicity; + e -= two_adicity; + } + res * Self::from_canonical_u64(p - ((p - 1) >> e)) + } else { + Self::from_canonical_u64(p - ((p - 1) >> exp)) } - res * Self::from_canonical_u64(p - ((p - 1) >> e)) } else { - Self::from_canonical_u64(p - ((p - 1) >> exp)) + dbg!("yo"); + Self::TWO.inverse().exp_u64(exp as u64) } } @@ -450,15 +452,6 @@ pub trait PrimeField: Field { } } -pub trait SmallCharacteristicField: Field { - const SMALLCHAR: u64; - - #[inline] - fn inverse_2exp(exp: usize) -> Self { - todo!() - } -} - /// An iterator over the powers of a certain base element `b`: `b^0, b^1, b^2, ...`. #[derive(Clone)] pub struct Powers { diff --git a/src/field/goldilocks_field.rs b/src/field/goldilocks_field.rs index 7bdd3c77..14c0a281 100644 --- a/src/field/goldilocks_field.rs +++ b/src/field/goldilocks_field.rs @@ -66,10 +66,10 @@ impl Field for GoldilocksField { const ONE: Self = Self(1); const TWO: Self = Self(2); const NEG_ONE: Self = Self(Self::ORDER - 1); - const CHARACTERISTIC: u64 = Self::ORDER; const TWO_ADICITY: usize = 32; - const CHARACTERISTIC_TWO_ADICITY: usize = Self::TWO_ADICITY; + const CHARACTERISTIC_WITH_TWO_ADICITY: Option<(u64, usize)> = + Some((Self::ORDER, Self::TWO_ADICITY)); // Sage: `g = GF(p).multiplicative_generator()` const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(7); diff --git a/src/field/prime_field_testing.rs b/src/field/prime_field_testing.rs index 9dae4896..1b7b97eb 100644 --- a/src/field/prime_field_testing.rs +++ b/src/field/prime_field_testing.rs @@ -144,7 +144,7 @@ macro_rules! test_prime_field_arithmetic { fn inverse_2exp() { type F = $field; - let v = ::PrimeField::TWO_ADICITY; + let v = ::TWO_ADICITY; for e in [0, 1, 2, 3, 4, v - 2, v - 1, v, v + 1, v + 2, 123 * v] { let x = F::TWO.exp_u64(e as u64); diff --git a/src/field/secp256k1_base.rs b/src/field/secp256k1_base.rs index 32615187..3e0d0ef0 100644 --- a/src/field/secp256k1_base.rs +++ b/src/field/secp256k1_base.rs @@ -11,7 +11,6 @@ use rand::Rng; use serde::{Deserialize, Serialize}; use crate::field::field_types::Field; -use crate::field::goldilocks_field::GoldilocksField; /// The base field of the secp256k1 elliptic curve. /// @@ -78,10 +77,8 @@ impl Field for Secp256K1Base { 0xFFFFFFFFFFFFFFFF, ]); - // TODO: fix - const CHARACTERISTIC: u64 = 0; const TWO_ADICITY: usize = 1; - const CHARACTERISTIC_TWO_ADICITY: usize = Self::TWO_ADICITY; + const CHARACTERISTIC_WITH_TWO_ADICITY: Option<(u64, usize)> = None; // Sage: `g = GF(p).multiplicative_generator()` const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([5, 0, 0, 0]); diff --git a/src/field/secp256k1_scalar.rs b/src/field/secp256k1_scalar.rs index 44907b7a..595a27a3 100644 --- a/src/field/secp256k1_scalar.rs +++ b/src/field/secp256k1_scalar.rs @@ -12,7 +12,6 @@ use rand::Rng; use serde::{Deserialize, Serialize}; use crate::field::field_types::Field; -use crate::field::goldilocks_field::GoldilocksField; /// The base field of the secp256k1 elliptic curve. /// @@ -81,11 +80,8 @@ impl Field for Secp256K1Scalar { 0xFFFFFFFFFFFFFFFF, ]); - // TODO: fix - const CHARACTERISTIC: u64 = 0; - const TWO_ADICITY: usize = 6; - const CHARACTERISTIC_TWO_ADICITY: usize = 6; + const CHARACTERISTIC_WITH_TWO_ADICITY: Option<(u64, usize)> = None; // Sage: `g = GF(p).multiplicative_generator()` const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([7, 0, 0, 0]); From 1d215d5d59c1004858063734975bc7f11157f198 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 13 Dec 2021 16:23:39 +0100 Subject: [PATCH 196/202] Remove dbg --- src/field/field_types.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/field/field_types.rs b/src/field/field_types.rs index 14426212..cce96bec 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -228,7 +228,6 @@ pub trait Field: Self::from_canonical_u64(p - ((p - 1) >> exp)) } } else { - dbg!("yo"); Self::TWO.inverse().exp_u64(exp as u64) } } From c1698bb99d59a328ea3f3efb66f4f217557e9ece Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 13 Dec 2021 16:39:07 +0100 Subject: [PATCH 197/202] Remove polynomial.rs (+clippy lints) --- src/field/fft.rs | 2 +- src/gadgets/sorting.rs | 3 +- src/polynomial/polynomial.rs | 614 ----------------------------------- 3 files changed, 2 insertions(+), 617 deletions(-) delete mode 100644 src/polynomial/polynomial.rs diff --git a/src/field/fft.rs b/src/field/fft.rs index 76e0fd42..ba94f6a7 100644 --- a/src/field/fft.rs +++ b/src/field/fft.rs @@ -43,7 +43,7 @@ fn fft_dispatch( } else { Some(fft_root_table(input.len())) }; - let used_root_table = root_table.or_else(|| computed_root_table.as_ref()).unwrap(); + let used_root_table = root_table.or(computed_root_table.as_ref()).unwrap(); fft_classic(input, zero_factor.unwrap_or(0), used_root_table) } diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 2059a888..c4378ab9 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -128,8 +128,7 @@ impl, const D: usize> SimpleGenerator fn dependencies(&self) -> Vec { self.input_ops .iter() - .map(|op| vec![op.is_write.target, op.address, op.timestamp, op.value]) - .flatten() + .flat_map(|op| vec![op.is_write.target, op.address, op.timestamp, op.value]) .collect() } diff --git a/src/polynomial/polynomial.rs b/src/polynomial/polynomial.rs deleted file mode 100644 index 43e17823..00000000 --- a/src/polynomial/polynomial.rs +++ /dev/null @@ -1,614 +0,0 @@ -use std::cmp::max; -use std::iter::Sum; -use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; - -use anyhow::{ensure, Result}; -use serde::{Deserialize, Serialize}; - -use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::field::fft::{fft, fft_with_options, ifft, FftRootTable}; -use crate::field::field_types::Field; -use crate::util::log2_strict; - -/// A polynomial in point-value form. -/// -/// The points are implicitly `g^i`, where `g` generates the subgroup whose size equals the number -/// of points. -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct PolynomialValues { - pub values: Vec, -} - -impl PolynomialValues { - pub fn new(values: Vec) -> Self { - PolynomialValues { values } - } - - /// The number of values stored. - pub(crate) fn len(&self) -> usize { - self.values.len() - } - - pub fn ifft(&self) -> PolynomialCoeffs { - ifft(self) - } - - /// Returns the polynomial whose evaluation on the coset `shift*H` is `self`. - pub fn coset_ifft(&self, shift: F) -> PolynomialCoeffs { - let mut shifted_coeffs = self.ifft(); - shifted_coeffs - .coeffs - .iter_mut() - .zip(shift.inverse().powers()) - .for_each(|(c, r)| { - *c *= r; - }); - shifted_coeffs - } - - pub fn lde_multiple(polys: Vec, rate_bits: usize) -> Vec { - polys.into_iter().map(|p| p.lde(rate_bits)).collect() - } - - pub fn lde(&self, rate_bits: usize) -> Self { - let coeffs = ifft(self).lde(rate_bits); - fft_with_options(&coeffs, Some(rate_bits), None) - } - - pub fn degree(&self) -> usize { - self.degree_plus_one() - .checked_sub(1) - .expect("deg(0) is undefined") - } - - pub fn degree_plus_one(&self) -> usize { - self.ifft().degree_plus_one() - } -} - -impl From> for PolynomialValues { - fn from(values: Vec) -> Self { - Self::new(values) - } -} - -/// A polynomial in coefficient form. -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "")] -pub struct PolynomialCoeffs { - pub(crate) coeffs: Vec, -} - -impl PolynomialCoeffs { - pub fn new(coeffs: Vec) -> Self { - PolynomialCoeffs { coeffs } - } - - pub(crate) fn empty() -> Self { - Self::new(Vec::new()) - } - - pub(crate) fn zero(len: usize) -> Self { - Self::new(vec![F::ZERO; len]) - } - - pub(crate) fn is_zero(&self) -> bool { - self.coeffs.iter().all(|x| x.is_zero()) - } - - /// The number of coefficients. This does not filter out any zero coefficients, so it is not - /// necessarily related to the degree. - pub fn len(&self) -> usize { - self.coeffs.len() - } - - pub fn log_len(&self) -> usize { - log2_strict(self.len()) - } - - pub(crate) fn chunks(&self, chunk_size: usize) -> Vec { - self.coeffs - .chunks(chunk_size) - .map(|chunk| PolynomialCoeffs::new(chunk.to_vec())) - .collect() - } - - pub fn eval(&self, x: F) -> F { - self.coeffs - .iter() - .rev() - .fold(F::ZERO, |acc, &c| acc * x + c) - } - - /// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1. - pub fn eval_with_powers(&self, powers: &[F]) -> F { - debug_assert_eq!(self.coeffs.len(), powers.len() + 1); - let acc = self.coeffs[0]; - self.coeffs[1..] - .iter() - .zip(powers) - .fold(acc, |acc, (&x, &c)| acc + c * x) - } - - pub fn eval_base(&self, x: F::BaseField) -> F - where - F: FieldExtension, - { - self.coeffs - .iter() - .rev() - .fold(F::ZERO, |acc, &c| acc.scalar_mul(x) + c) - } - - /// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1. - pub fn eval_base_with_powers(&self, powers: &[F::BaseField]) -> F - where - F: FieldExtension, - { - debug_assert_eq!(self.coeffs.len(), powers.len() + 1); - let acc = self.coeffs[0]; - self.coeffs[1..] - .iter() - .zip(powers) - .fold(acc, |acc, (&x, &c)| acc + x.scalar_mul(c)) - } - - pub fn lde_multiple(polys: Vec<&Self>, rate_bits: usize) -> Vec { - polys.into_iter().map(|p| p.lde(rate_bits)).collect() - } - - pub fn lde(&self, rate_bits: usize) -> Self { - self.padded(self.len() << rate_bits) - } - - pub(crate) fn pad(&mut self, new_len: usize) -> Result<()> { - ensure!( - new_len >= self.len(), - "Trying to pad a polynomial of length {} to a length of {}.", - self.len(), - new_len - ); - self.coeffs.resize(new_len, F::ZERO); - Ok(()) - } - - pub(crate) fn padded(&self, new_len: usize) -> Self { - let mut poly = self.clone(); - poly.pad(new_len).unwrap(); - poly - } - - /// Removes leading zero coefficients. - pub fn trim(&mut self) { - self.coeffs.truncate(self.degree_plus_one()); - } - - /// Removes leading zero coefficients. - pub fn trimmed(&self) -> Self { - let coeffs = self.coeffs[..self.degree_plus_one()].to_vec(); - Self { coeffs } - } - - /// Degree of the polynomial + 1, or 0 for a polynomial with no non-zero coefficients. - pub(crate) fn degree_plus_one(&self) -> usize { - (0usize..self.len()) - .rev() - .find(|&i| self.coeffs[i].is_nonzero()) - .map_or(0, |i| i + 1) - } - - /// Leading coefficient. - pub fn lead(&self) -> F { - self.coeffs - .iter() - .rev() - .find(|x| x.is_nonzero()) - .map_or(F::ZERO, |x| *x) - } - - /// Reverse the order of the coefficients, not taking into account the leading zero coefficients. - pub(crate) fn rev(&self) -> Self { - Self::new(self.trimmed().coeffs.into_iter().rev().collect()) - } - - pub fn fft(&self) -> PolynomialValues { - fft(self) - } - - pub fn fft_with_options( - &self, - zero_factor: Option, - root_table: Option<&FftRootTable>, - ) -> PolynomialValues { - fft_with_options(self, zero_factor, root_table) - } - - /// Returns the evaluation of the polynomial on the coset `shift*H`. - pub fn coset_fft(&self, shift: F) -> PolynomialValues { - self.coset_fft_with_options(shift, None, None) - } - - /// Returns the evaluation of the polynomial on the coset `shift*H`. - pub fn coset_fft_with_options( - &self, - shift: F, - zero_factor: Option, - root_table: Option<&FftRootTable>, - ) -> PolynomialValues { - let modified_poly: Self = shift - .powers() - .zip(&self.coeffs) - .map(|(r, &c)| r * c) - .collect::>() - .into(); - modified_poly.fft_with_options(zero_factor, root_table) - } - - pub fn to_extension(&self) -> PolynomialCoeffs - where - F: Extendable, - { - PolynomialCoeffs::new(self.coeffs.iter().map(|&c| c.into()).collect()) - } - - pub fn mul_extension(&self, rhs: F::Extension) -> PolynomialCoeffs - where - F: Extendable, - { - PolynomialCoeffs::new(self.coeffs.iter().map(|&c| rhs.scalar_mul(c)).collect()) - } -} - -impl PartialEq for PolynomialCoeffs { - fn eq(&self, other: &Self) -> bool { - let max_terms = self.coeffs.len().max(other.coeffs.len()); - for i in 0..max_terms { - let self_i = self.coeffs.get(i).cloned().unwrap_or(F::ZERO); - let other_i = other.coeffs.get(i).cloned().unwrap_or(F::ZERO); - if self_i != other_i { - return false; - } - } - true - } -} - -impl Eq for PolynomialCoeffs {} - -impl From> for PolynomialCoeffs { - fn from(coeffs: Vec) -> Self { - Self::new(coeffs) - } -} - -impl Add for &PolynomialCoeffs { - type Output = PolynomialCoeffs; - - fn add(self, rhs: Self) -> Self::Output { - let len = max(self.len(), rhs.len()); - let a = self.padded(len).coeffs; - let b = rhs.padded(len).coeffs; - let coeffs = a.into_iter().zip(b).map(|(x, y)| x + y).collect(); - PolynomialCoeffs::new(coeffs) - } -} - -impl Sum for PolynomialCoeffs { - fn sum>(iter: I) -> Self { - iter.fold(Self::empty(), |acc, p| &acc + &p) - } -} - -impl Sub for &PolynomialCoeffs { - type Output = PolynomialCoeffs; - - fn sub(self, rhs: Self) -> Self::Output { - let len = max(self.len(), rhs.len()); - let mut coeffs = self.padded(len).coeffs; - for (i, &c) in rhs.coeffs.iter().enumerate() { - coeffs[i] -= c; - } - PolynomialCoeffs::new(coeffs) - } -} - -impl AddAssign for PolynomialCoeffs { - fn add_assign(&mut self, rhs: Self) { - let len = max(self.len(), rhs.len()); - self.coeffs.resize(len, F::ZERO); - for (l, r) in self.coeffs.iter_mut().zip(rhs.coeffs) { - *l += r; - } - } -} - -impl AddAssign<&Self> for PolynomialCoeffs { - fn add_assign(&mut self, rhs: &Self) { - let len = max(self.len(), rhs.len()); - self.coeffs.resize(len, F::ZERO); - for (l, &r) in self.coeffs.iter_mut().zip(&rhs.coeffs) { - *l += r; - } - } -} - -impl SubAssign for PolynomialCoeffs { - fn sub_assign(&mut self, rhs: Self) { - let len = max(self.len(), rhs.len()); - self.coeffs.resize(len, F::ZERO); - for (l, r) in self.coeffs.iter_mut().zip(rhs.coeffs) { - *l -= r; - } - } -} - -impl SubAssign<&Self> for PolynomialCoeffs { - fn sub_assign(&mut self, rhs: &Self) { - let len = max(self.len(), rhs.len()); - self.coeffs.resize(len, F::ZERO); - for (l, &r) in self.coeffs.iter_mut().zip(&rhs.coeffs) { - *l -= r; - } - } -} - -impl Mul for &PolynomialCoeffs { - type Output = PolynomialCoeffs; - - fn mul(self, rhs: F) -> Self::Output { - let coeffs = self.coeffs.iter().map(|&x| rhs * x).collect(); - PolynomialCoeffs::new(coeffs) - } -} - -impl MulAssign for PolynomialCoeffs { - fn mul_assign(&mut self, rhs: F) { - self.coeffs.iter_mut().for_each(|x| *x *= rhs); - } -} - -impl Mul for &PolynomialCoeffs { - type Output = PolynomialCoeffs; - - #[allow(clippy::suspicious_arithmetic_impl)] - fn mul(self, rhs: Self) -> Self::Output { - let new_len = (self.len() + rhs.len()).next_power_of_two(); - let a = self.padded(new_len); - let b = rhs.padded(new_len); - let a_evals = a.fft(); - let b_evals = b.fft(); - - let mul_evals: Vec = a_evals - .values - .into_iter() - .zip(b_evals.values) - .map(|(pa, pb)| pa * pb) - .collect(); - ifft(&mul_evals.into()) - } -} - -#[cfg(test)] -mod tests { - use std::time::Instant; - - use rand::{thread_rng, Rng}; - - use super::*; - use crate::field::goldilocks_field::GoldilocksField; - - #[test] - fn test_trimmed() { - type F = GoldilocksField; - - assert_eq!( - PolynomialCoeffs:: { coeffs: vec![] }.trimmed(), - PolynomialCoeffs:: { coeffs: vec![] } - ); - assert_eq!( - PolynomialCoeffs:: { - coeffs: vec![F::ZERO] - } - .trimmed(), - PolynomialCoeffs:: { coeffs: vec![] } - ); - assert_eq!( - PolynomialCoeffs:: { - coeffs: vec![F::ONE, F::TWO, F::ZERO, F::ZERO] - } - .trimmed(), - PolynomialCoeffs:: { - coeffs: vec![F::ONE, F::TWO] - } - ); - } - - #[test] - fn test_coset_fft() { - type F = GoldilocksField; - - let k = 8; - let n = 1 << k; - let poly = PolynomialCoeffs::new(F::rand_vec(n)); - let shift = F::rand(); - let coset_evals = poly.coset_fft(shift).values; - - let generator = F::primitive_root_of_unity(k); - let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n) - .into_iter() - .map(|x| poly.eval(x)) - .collect::>(); - assert_eq!(coset_evals, naive_coset_evals); - - let ifft_coeffs = PolynomialValues::new(coset_evals).coset_ifft(shift); - assert_eq!(poly, ifft_coeffs); - } - - #[test] - fn test_coset_ifft() { - type F = GoldilocksField; - - let k = 8; - let n = 1 << k; - let evals = PolynomialValues::new(F::rand_vec(n)); - let shift = F::rand(); - let coeffs = evals.coset_ifft(shift); - - let generator = F::primitive_root_of_unity(k); - let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n) - .into_iter() - .map(|x| coeffs.eval(x)) - .collect::>(); - assert_eq!(evals, naive_coset_evals.into()); - - let fft_evals = coeffs.coset_fft(shift); - assert_eq!(evals, fft_evals); - } - - #[test] - fn test_polynomial_multiplication() { - type F = GoldilocksField; - let mut rng = thread_rng(); - let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000)); - let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); - let b = PolynomialCoeffs::new(F::rand_vec(b_deg)); - let m1 = &a * &b; - let m2 = &a * &b; - for _ in 0..1000 { - let x = F::rand(); - assert_eq!(m1.eval(x), a.eval(x) * b.eval(x)); - assert_eq!(m2.eval(x), a.eval(x) * b.eval(x)); - } - } - - #[test] - fn test_inv_mod_xn() { - type F = GoldilocksField; - let mut rng = thread_rng(); - let a_deg = rng.gen_range(1..1_000); - let n = rng.gen_range(1..1_000); - let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); - let b = a.inv_mod_xn(n); - let mut m = &a * &b; - m.coeffs.drain(n..); - m.trim(); - assert_eq!( - m, - PolynomialCoeffs::new(vec![F::ONE]), - "a: {:#?}, b:{:#?}, n:{:#?}, m:{:#?}", - a, - b, - n, - m - ); - } - - #[test] - fn test_polynomial_long_division() { - type F = GoldilocksField; - let mut rng = thread_rng(); - let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000)); - let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); - let b = PolynomialCoeffs::new(F::rand_vec(b_deg)); - let (q, r) = a.div_rem_long_division(&b); - for _ in 0..1000 { - let x = F::rand(); - assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x)); - } - } - - #[test] - fn test_polynomial_division() { - type F = GoldilocksField; - let mut rng = thread_rng(); - let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000)); - let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); - let b = PolynomialCoeffs::new(F::rand_vec(b_deg)); - let (q, r) = a.div_rem(&b); - for _ in 0..1000 { - let x = F::rand(); - assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x)); - } - } - - #[test] - fn test_polynomial_division_by_constant() { - type F = GoldilocksField; - let mut rng = thread_rng(); - let a_deg = rng.gen_range(1..10_000); - let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); - let b = PolynomialCoeffs::from(vec![F::rand()]); - let (q, r) = a.div_rem(&b); - for _ in 0..1000 { - let x = F::rand(); - assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x)); - } - } - - // Test to see which polynomial division method is faster for divisions of the type - // `(X^n - 1)/(X - a) - #[test] - fn test_division_linear() { - type F = GoldilocksField; - let mut rng = thread_rng(); - let l = 14; - let n = 1 << l; - let g = F::primitive_root_of_unity(l); - let xn_minus_one = { - let mut xn_min_one_vec = vec![F::ZERO; n + 1]; - xn_min_one_vec[n] = F::ONE; - xn_min_one_vec[0] = F::NEG_ONE; - PolynomialCoeffs::new(xn_min_one_vec) - }; - - let a = g.exp_u64(rng.gen_range(0..(n as u64))); - let denom = PolynomialCoeffs::new(vec![-a, F::ONE]); - let now = Instant::now(); - xn_minus_one.div_rem(&denom); - println!("Division time: {:?}", now.elapsed()); - let now = Instant::now(); - xn_minus_one.div_rem_long_division(&denom); - println!("Division time: {:?}", now.elapsed()); - } - - #[test] - fn eq() { - type F = GoldilocksField; - assert_eq!( - PolynomialCoeffs::::new(vec![]), - PolynomialCoeffs::new(vec![]) - ); - assert_eq!( - PolynomialCoeffs::::new(vec![F::ZERO]), - PolynomialCoeffs::new(vec![F::ZERO]) - ); - assert_eq!( - PolynomialCoeffs::::new(vec![]), - PolynomialCoeffs::new(vec![F::ZERO]) - ); - assert_eq!( - PolynomialCoeffs::::new(vec![F::ZERO]), - PolynomialCoeffs::new(vec![]) - ); - assert_eq!( - PolynomialCoeffs::::new(vec![F::ZERO]), - PolynomialCoeffs::new(vec![F::ZERO, F::ZERO]) - ); - assert_eq!( - PolynomialCoeffs::::new(vec![F::ONE]), - PolynomialCoeffs::new(vec![F::ONE, F::ZERO]) - ); - assert_ne!( - PolynomialCoeffs::::new(vec![]), - PolynomialCoeffs::new(vec![F::ONE]) - ); - assert_ne!( - PolynomialCoeffs::::new(vec![F::ZERO]), - PolynomialCoeffs::new(vec![F::ZERO, F::ONE]) - ); - assert_ne!( - PolynomialCoeffs::::new(vec![F::ZERO]), - PolynomialCoeffs::new(vec![F::ONE, F::ZERO]) - ); - } -} From 073fe7a6d9375b17235b7ace6af5cfb2ad0b11db Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 13 Dec 2021 16:40:00 +0100 Subject: [PATCH 198/202] New clippy lints --- src/field/fft.rs | 2 +- src/gadgets/sorting.rs | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/field/fft.rs b/src/field/fft.rs index 76e0fd42..ba94f6a7 100644 --- a/src/field/fft.rs +++ b/src/field/fft.rs @@ -43,7 +43,7 @@ fn fft_dispatch( } else { Some(fft_root_table(input.len())) }; - let used_root_table = root_table.or_else(|| computed_root_table.as_ref()).unwrap(); + let used_root_table = root_table.or(computed_root_table.as_ref()).unwrap(); fft_classic(input, zero_factor.unwrap_or(0), used_root_table) } diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 2059a888..c4378ab9 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -128,8 +128,7 @@ impl, const D: usize> SimpleGenerator fn dependencies(&self) -> Vec { self.input_ops .iter() - .map(|op| vec![op.is_write.target, op.address, op.timestamp, op.value]) - .flatten() + .flat_map(|op| vec![op.is_write.target, op.address, op.timestamp, op.value]) .collect() } From 920d5995c7eb04e8b8174a17a1829ecda4e9188d Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 13 Dec 2021 16:46:49 +0100 Subject: [PATCH 199/202] Replace `bits()` fn with `BITS` const --- src/field/field_types.rs | 6 +----- src/fri/recursive_verifier.rs | 2 +- src/gates/base_sum.rs | 2 +- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/field/field_types.rs b/src/field/field_types.rs index a3affc13..b6d9e700 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -59,6 +59,7 @@ pub trait Field: /// Generator of a multiplicative subgroup of order `2^TWO_ADICITY`. const POWER_OF_TWO_GENERATOR: Self; + /// The bit length of the field order. const BITS: usize; fn order() -> BigUint; @@ -407,11 +408,6 @@ pub trait Field: pub trait PrimeField: Field { const ORDER: u64; - /// The number of bits required to encode any field element. - fn bits() -> usize { - bits_u64(Self::NEG_ONE.to_canonical_u64()) - } - fn to_canonical_u64(&self) -> u64; fn to_noncanonical_u64(&self) -> u64; diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index c2684725..dc4214b5 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -338,7 +338,7 @@ impl, const D: usize> CircuitBuilder { // verify that this has a negligible impact on soundness error. Self::assert_noncanonical_indices_ok(&common_data.config); let x_index = challenger.get_challenge(self); - let mut x_index_bits = self.low_bits(x_index, n_log, F::bits()); + let mut x_index_bits = self.low_bits(x_index, n_log, F::BITS); let cap_index = self.le_sum(x_index_bits[x_index_bits.len() - common_data.config.cap_height..].iter()); diff --git a/src/gates/base_sum.rs b/src/gates/base_sum.rs index 2ab5345b..7ea235fa 100644 --- a/src/gates/base_sum.rs +++ b/src/gates/base_sum.rs @@ -24,7 +24,7 @@ impl BaseSumGate { } pub fn new_from_config(config: &CircuitConfig) -> Self { - let num_limbs = F::bits().min(config.num_routed_wires - Self::START_LIMBS); + let num_limbs = F::BITS.min(config.num_routed_wires - Self::START_LIMBS); Self::new(num_limbs) } From 6863eea74eb17107f580c38a542a45f3b895fcb8 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 13 Dec 2021 16:51:36 +0100 Subject: [PATCH 200/202] New clippy lints --- src/field/fft.rs | 2 +- src/gadgets/sorting.rs | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/field/fft.rs b/src/field/fft.rs index 76e0fd42..ba94f6a7 100644 --- a/src/field/fft.rs +++ b/src/field/fft.rs @@ -43,7 +43,7 @@ fn fft_dispatch( } else { Some(fft_root_table(input.len())) }; - let used_root_table = root_table.or_else(|| computed_root_table.as_ref()).unwrap(); + let used_root_table = root_table.or(computed_root_table.as_ref()).unwrap(); fft_classic(input, zero_factor.unwrap_or(0), used_root_table) } diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 2059a888..c4378ab9 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -128,8 +128,7 @@ impl, const D: usize> SimpleGenerator fn dependencies(&self) -> Vec { self.input_ops .iter() - .map(|op| vec![op.is_write.target, op.address, op.timestamp, op.value]) - .flatten() + .flat_map(|op| vec![op.is_write.target, op.address, op.timestamp, op.value]) .collect() } From 9211bcfed50c405381a3892e30010732ba60d138 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 14 Dec 2021 17:12:14 +0100 Subject: [PATCH 201/202] Move characteristic to its own fn --- src/field/extension_field/quadratic.rs | 6 ++++-- src/field/extension_field/quartic.rs | 6 ++++-- src/field/field_types.rs | 17 +++++++++-------- src/field/goldilocks_field.rs | 6 ++++-- src/field/secp256k1_base.rs | 5 ++++- src/field/secp256k1_scalar.rs | 5 ++++- 6 files changed, 29 insertions(+), 16 deletions(-) diff --git a/src/field/extension_field/quadratic.rs b/src/field/extension_field/quadratic.rs index 16743f12..2243612e 100644 --- a/src/field/extension_field/quadratic.rs +++ b/src/field/extension_field/quadratic.rs @@ -59,8 +59,7 @@ impl> Field for QuadraticExtension { // long as `F::TWO_ADICITY >= 2`, `p` can be written as `4n + 1`, so `p + 1` can be written as // `2(2n + 1)`, which has a 2-adicity of 1. const TWO_ADICITY: usize = F::TWO_ADICITY + 1; - const CHARACTERISTIC_WITH_TWO_ADICITY: Option<(u64, usize)> = - F::CHARACTERISTIC_WITH_TWO_ADICITY; + const CHARACTERISTIC_TWO_ADICITY: usize = F::CHARACTERISTIC_TWO_ADICITY; const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(F::EXT_MULTIPLICATIVE_GROUP_GENERATOR); const POWER_OF_TWO_GENERATOR: Self = Self(F::EXT_POWER_OF_TWO_GENERATOR); @@ -70,6 +69,9 @@ impl> Field for QuadraticExtension { fn order() -> BigUint { F::order() * F::order() } + fn characteristic() -> BigUint { + F::characteristic() + } #[inline(always)] fn square(&self) -> Self { diff --git a/src/field/extension_field/quartic.rs b/src/field/extension_field/quartic.rs index 77329c94..781f79f5 100644 --- a/src/field/extension_field/quartic.rs +++ b/src/field/extension_field/quartic.rs @@ -61,8 +61,7 @@ impl> Field for QuarticExtension { // `2(2n + 1)`, which has a 2-adicity of 1. A similar argument can show that `p^2 + 1` also has // a 2-adicity of 1. const TWO_ADICITY: usize = F::TWO_ADICITY + 2; - const CHARACTERISTIC_WITH_TWO_ADICITY: Option<(u64, usize)> = - F::CHARACTERISTIC_WITH_TWO_ADICITY; + const CHARACTERISTIC_TWO_ADICITY: usize = F::CHARACTERISTIC_TWO_ADICITY; const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(F::EXT_MULTIPLICATIVE_GROUP_GENERATOR); const POWER_OF_TWO_GENERATOR: Self = Self(F::EXT_POWER_OF_TWO_GENERATOR); @@ -72,6 +71,9 @@ impl> Field for QuarticExtension { fn order() -> BigUint { F::order().pow(4u32) } + fn characteristic() -> BigUint { + F::characteristic() + } #[inline(always)] fn square(&self) -> Self { diff --git a/src/field/field_types.rs b/src/field/field_types.rs index cce96bec..f3d1c946 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -4,7 +4,7 @@ use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use num::bigint::BigUint; -use num::{Integer, One, Zero}; +use num::{Integer, One, ToPrimitive, Zero}; use rand::Rng; use serde::de::DeserializeOwned; use serde::Serialize; @@ -52,7 +52,7 @@ pub trait Field: /// The field's characteristic and it's 2-adicity. /// Set to `None` when the characteristic doesn't fit in a u64. - const CHARACTERISTIC_WITH_TWO_ADICITY: Option<(u64, usize)>; + const CHARACTERISTIC_TWO_ADICITY: usize; /// Generator of the entire multiplicative group, i.e. all non-zero elements. const MULTIPLICATIVE_GROUP_GENERATOR: Self; @@ -62,6 +62,7 @@ pub trait Field: const BITS: usize; fn order() -> BigUint; + fn characteristic() -> BigUint; #[inline] fn is_zero(&self) -> bool { @@ -204,24 +205,24 @@ pub trait Field: // exp exceeds t, we repeatedly multiply by 2^-t and reduce // exp until it's in the right range. - if let Some((p, two_adicity)) = Self::CHARACTERISTIC_WITH_TWO_ADICITY { + if let Some(p) = Self::characteristic().to_u64() { // NB: The only reason this is split into two cases is to save // the multiplication (and possible calculation of // inverse_2_pow_adicity) in the usual case that exp <= // TWO_ADICITY. Can remove the branch and simplify if that // saving isn't worth it. - if exp > two_adicity { + if exp > Self::CHARACTERISTIC_TWO_ADICITY { // NB: This should be a compile-time constant let inverse_2_pow_adicity: Self = - Self::from_canonical_u64(p - ((p - 1) >> two_adicity)); + Self::from_canonical_u64(p - ((p - 1) >> Self::CHARACTERISTIC_TWO_ADICITY)); let mut res = inverse_2_pow_adicity; - let mut e = exp - two_adicity; + let mut e = exp - Self::CHARACTERISTIC_TWO_ADICITY; - while e > two_adicity { + while e > Self::CHARACTERISTIC_TWO_ADICITY { res *= inverse_2_pow_adicity; - e -= two_adicity; + e -= Self::CHARACTERISTIC_TWO_ADICITY; } res * Self::from_canonical_u64(p - ((p - 1) >> e)) } else { diff --git a/src/field/goldilocks_field.rs b/src/field/goldilocks_field.rs index 14c0a281..d963fb9e 100644 --- a/src/field/goldilocks_field.rs +++ b/src/field/goldilocks_field.rs @@ -68,8 +68,7 @@ impl Field for GoldilocksField { const NEG_ONE: Self = Self(Self::ORDER - 1); const TWO_ADICITY: usize = 32; - const CHARACTERISTIC_WITH_TWO_ADICITY: Option<(u64, usize)> = - Some((Self::ORDER, Self::TWO_ADICITY)); + const CHARACTERISTIC_TWO_ADICITY: usize = Self::TWO_ADICITY; // Sage: `g = GF(p).multiplicative_generator()` const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(7); @@ -86,6 +85,9 @@ impl Field for GoldilocksField { fn order() -> BigUint { Self::ORDER.into() } + fn characteristic() -> BigUint { + Self::order() + } #[inline(always)] fn try_inverse(&self) -> Option { diff --git a/src/field/secp256k1_base.rs b/src/field/secp256k1_base.rs index 3e0d0ef0..0d79000f 100644 --- a/src/field/secp256k1_base.rs +++ b/src/field/secp256k1_base.rs @@ -78,7 +78,7 @@ impl Field for Secp256K1Base { ]); const TWO_ADICITY: usize = 1; - const CHARACTERISTIC_WITH_TWO_ADICITY: Option<(u64, usize)> = None; + const CHARACTERISTIC_TWO_ADICITY: usize = Self::TWO_ADICITY; // Sage: `g = GF(p).multiplicative_generator()` const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([5, 0, 0, 0]); @@ -94,6 +94,9 @@ impl Field for Secp256K1Base { 0xFFFFFFFF, ]) } + fn characteristic() -> BigUint { + Self::order() + } fn try_inverse(&self) -> Option { if self.is_zero() { diff --git a/src/field/secp256k1_scalar.rs b/src/field/secp256k1_scalar.rs index 595a27a3..a5b7a315 100644 --- a/src/field/secp256k1_scalar.rs +++ b/src/field/secp256k1_scalar.rs @@ -81,7 +81,7 @@ impl Field for Secp256K1Scalar { ]); const TWO_ADICITY: usize = 6; - const CHARACTERISTIC_WITH_TWO_ADICITY: Option<(u64, usize)> = None; + const CHARACTERISTIC_TWO_ADICITY: usize = Self::TWO_ADICITY; // Sage: `g = GF(p).multiplicative_generator()` const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([7, 0, 0, 0]); @@ -103,6 +103,9 @@ impl Field for Secp256K1Scalar { 0xFFFFFFFF, ]) } + fn characteristic() -> BigUint { + Self::order() + } fn try_inverse(&self) -> Option { if self.is_zero() { From 357eea8df584ae90fccff7110109c97ce07d1e7e Mon Sep 17 00:00:00 2001 From: Jakub Nabaglo Date: Wed, 15 Dec 2021 21:59:16 -0800 Subject: [PATCH 202/202] Fix build on main (#396) --- benches/field_arithmetic.rs | 2 -- benches/hashing.rs | 1 - src/field/goldilocks_field.rs | 2 ++ src/hash/arch/aarch64/poseidon_goldilocks_neon.rs | 1 + src/lib.rs | 2 -- src/util/mod.rs | 3 ++- 6 files changed, 5 insertions(+), 6 deletions(-) diff --git a/benches/field_arithmetic.rs b/benches/field_arithmetic.rs index 2fb4a24b..8308e427 100644 --- a/benches/field_arithmetic.rs +++ b/benches/field_arithmetic.rs @@ -1,5 +1,3 @@ -#![feature(destructuring_assignment)] - use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; use plonky2::field::extension_field::quartic::QuarticExtension; use plonky2::field::field_types::Field; diff --git a/benches/hashing.rs b/benches/hashing.rs index c229972e..583c36b6 100644 --- a/benches/hashing.rs +++ b/benches/hashing.rs @@ -1,4 +1,3 @@ -#![feature(destructuring_assignment)] #![feature(generic_const_exprs)] use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; diff --git a/src/field/goldilocks_field.rs b/src/field/goldilocks_field.rs index d963fb9e..9e93d1f1 100644 --- a/src/field/goldilocks_field.rs +++ b/src/field/goldilocks_field.rs @@ -323,6 +323,7 @@ impl RichField for GoldilocksField {} #[inline(always)] #[cfg(target_arch = "x86_64")] unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { + use std::arch::asm; let res_wrapped: u64; let adjustment: u64; asm!( @@ -363,6 +364,7 @@ unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { #[inline(always)] #[cfg(target_arch = "x86_64")] unsafe fn sub_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { + use std::arch::asm; let res_wrapped: u64; let adjustment: u64; asm!( diff --git a/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs b/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs index 0aaa13a6..6437818b 100644 --- a/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs +++ b/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs @@ -1,6 +1,7 @@ #![allow(clippy::assertions_on_constants)] use std::arch::aarch64::*; +use std::arch::asm; use static_assertions::const_assert; use unroll::unroll_for_loops; diff --git a/src/lib.rs b/src/lib.rs index b0158d7a..291e6422 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,9 +4,7 @@ #![allow(clippy::too_many_arguments)] #![allow(clippy::len_without_is_empty)] #![allow(clippy::needless_range_loop)] -#![feature(asm)] #![feature(asm_sym)] -#![feature(destructuring_assignment)] #![feature(generic_const_exprs)] #![feature(specialization)] #![feature(stdsimd)] diff --git a/src/util/mod.rs b/src/util/mod.rs index fca6b728..3f7c5dd1 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1,4 +1,5 @@ -use core::hint::unreachable_unchecked; +use std::arch::asm; +use std::hint::unreachable_unchecked; use crate::field::field_types::Field; use crate::polynomial::PolynomialValues;