diff --git a/.github/workflows/continuous-integration-workflow.yml b/.github/workflows/continuous-integration-workflow.yml index dd79f33e..bf54ab3d 100644 --- a/.github/workflows/continuous-integration-workflow.yml +++ b/.github/workflows/continuous-integration-workflow.yml @@ -28,9 +28,11 @@ jobs: with: command: test args: --all + env: + RUSTFLAGS: -Copt-level=3 -Cdebug-assertions -Coverflow-checks=y lints: - name: Formatting + name: Formatting and Clippy runs-on: ubuntu-latest if: "! contains(toJSON(github.event.commits.*.message), '[skip-ci]')" steps: @@ -43,10 +45,17 @@ jobs: profile: minimal toolchain: nightly override: true - components: rustfmt + components: rustfmt, clippy - name: Run cargo fmt uses: actions-rs/cargo@v1 with: command: fmt args: --all -- --check + + - name: Run cargo clippy + uses: actions-rs/cargo@v1 + with: + command: clippy + args: --all-features --all-targets -- -D warnings -A incomplete-features + diff --git a/Cargo.toml b/Cargo.toml index a55bd899..ac383b4b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ license = "MIT OR Apache-2.0" repository = "https://github.com/mir-protocol/plonky2" keywords = ["cryptography", "SNARK", "FRI"] categories = ["cryptography"] -edition = "2018" +edition = "2021" default-run = "bench_recursion" [dependencies] @@ -28,6 +28,9 @@ serde_cbor = "0.11.1" keccak-hash = "0.8.0" static_assertions = "1.1.0" +[target.'cfg(not(target_env = "msvc"))'.dependencies] +jemallocator = "0.3.2" + [dev-dependencies] criterion = "0.3.5" tynm = "0.1.6" diff --git a/benches/ffts.rs b/benches/ffts.rs index 8492cfe9..745d53a8 100644 --- a/benches/ffts.rs +++ b/benches/ffts.rs @@ -1,7 +1,7 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use plonky2::field::field_types::Field; use plonky2::field::goldilocks_field::GoldilocksField; -use plonky2::polynomial::polynomial::PolynomialCoeffs; +use plonky2::polynomial::PolynomialCoeffs; use tynm::type_name; pub(crate) fn bench_ffts(c: &mut Criterion) 
{ diff --git a/benches/field_arithmetic.rs b/benches/field_arithmetic.rs index ebecb871..8308e427 100644 --- a/benches/field_arithmetic.rs +++ b/benches/field_arithmetic.rs @@ -1,5 +1,3 @@ -#![feature(destructuring_assignment)] - use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; use plonky2::field::extension_field::quartic::QuarticExtension; use plonky2::field::field_types::Field; @@ -112,6 +110,66 @@ pub(crate) fn bench_field(c: &mut Criterion) { c.bench_function(&format!("try_inverse<{}>", type_name::()), |b| { b.iter_batched(|| F::rand(), |x| x.try_inverse(), BatchSize::SmallInput) }); + + c.bench_function( + &format!("batch_multiplicative_inverse-tiny<{}>", type_name::()), + |b| { + b.iter_batched( + || (0..2).into_iter().map(|_| F::rand()).collect::>(), + |x| F::batch_multiplicative_inverse(&x), + BatchSize::SmallInput, + ) + }, + ); + + c.bench_function( + &format!("batch_multiplicative_inverse-small<{}>", type_name::()), + |b| { + b.iter_batched( + || (0..4).into_iter().map(|_| F::rand()).collect::>(), + |x| F::batch_multiplicative_inverse(&x), + BatchSize::SmallInput, + ) + }, + ); + + c.bench_function( + &format!("batch_multiplicative_inverse-medium<{}>", type_name::()), + |b| { + b.iter_batched( + || (0..16).into_iter().map(|_| F::rand()).collect::>(), + |x| F::batch_multiplicative_inverse(&x), + BatchSize::SmallInput, + ) + }, + ); + + c.bench_function( + &format!("batch_multiplicative_inverse-large<{}>", type_name::()), + |b| { + b.iter_batched( + || (0..256).into_iter().map(|_| F::rand()).collect::>(), + |x| F::batch_multiplicative_inverse(&x), + BatchSize::LargeInput, + ) + }, + ); + + c.bench_function( + &format!("batch_multiplicative_inverse-huge<{}>", type_name::()), + |b| { + b.iter_batched( + || { + (0..65536) + .into_iter() + .map(|_| F::rand()) + .collect::>() + }, + |x| F::batch_multiplicative_inverse(&x), + BatchSize::LargeInput, + ) + }, + ); } fn criterion_benchmark(c: &mut Criterion) { diff --git a/benches/hashing.rs 
b/benches/hashing.rs index 5669e50b..583c36b6 100644 --- a/benches/hashing.rs +++ b/benches/hashing.rs @@ -1,4 +1,3 @@ -#![feature(destructuring_assignment)] #![feature(generic_const_exprs)] use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; @@ -19,7 +18,7 @@ pub(crate) fn bench_gmimc, const WIDTH: usize>(c: &mut Criterion pub(crate) fn bench_poseidon, const WIDTH: usize>(c: &mut Criterion) where - [(); WIDTH - 1]: , + [(); WIDTH - 1]:, { c.bench_function(&format!("poseidon<{}, {}>", type_name::(), WIDTH), |b| { b.iter_batched( diff --git a/src/bin/bench_ldes.rs b/src/bin/bench_ldes.rs index d121831b..dbcfa6df 100644 --- a/src/bin/bench_ldes.rs +++ b/src/bin/bench_ldes.rs @@ -2,7 +2,7 @@ use std::time::Instant; use plonky2::field::field_types::Field; use plonky2::field::goldilocks_field::GoldilocksField; -use plonky2::polynomial::polynomial::PolynomialValues; +use plonky2::polynomial::PolynomialValues; use rayon::prelude::*; type F = GoldilocksField; diff --git a/src/bin/bench_recursion.rs b/src/bin/bench_recursion.rs index 30da9687..e75ea135 100644 --- a/src/bin/bench_recursion.rs +++ b/src/bin/bench_recursion.rs @@ -24,6 +24,7 @@ fn bench_prove, const D: usize>() -> Result<()> { num_wires: 126, num_routed_wires: 33, constant_gate_size: 6, + use_base_arithmetic_gate: false, security_bits: 128, rate_bits: 3, num_challenges: 3, diff --git a/src/bin/generate_constants.rs b/src/bin/generate_constants.rs index 60028741..89630fc7 100644 --- a/src/bin/generate_constants.rs +++ b/src/bin/generate_constants.rs @@ -1,5 +1,7 @@ //! Generates random constants using ChaCha20, seeded with zero. 
+#![allow(clippy::needless_range_loop)] + use plonky2::field::field_types::PrimeField; use plonky2::field::goldilocks_field::GoldilocksField; use rand::{Rng, SeedableRng}; diff --git a/src/curve/curve_adds.rs b/src/curve/curve_adds.rs new file mode 100644 index 00000000..f25d3847 --- /dev/null +++ b/src/curve/curve_adds.rs @@ -0,0 +1,156 @@ +use std::ops::Add; + +use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; +use crate::field::field_types::Field; + +impl Add> for ProjectivePoint { + type Output = ProjectivePoint; + + fn add(self, rhs: ProjectivePoint) -> Self::Output { + let ProjectivePoint { + x: x1, + y: y1, + z: z1, + } = self; + let ProjectivePoint { + x: x2, + y: y2, + z: z2, + } = rhs; + + if z1 == C::BaseField::ZERO { + return rhs; + } + if z2 == C::BaseField::ZERO { + return self; + } + + let x1z2 = x1 * z2; + let y1z2 = y1 * z2; + let x2z1 = x2 * z1; + let y2z1 = y2 * z1; + + // Check if we're doubling or adding inverses. + if x1z2 == x2z1 { + if y1z2 == y2z1 { + // TODO: inline to avoid redundant muls. 
+ return self.double(); + } + if y1z2 == -y2z1 { + return ProjectivePoint::ZERO; + } + } + + // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/add-1998-cmo-2 + let z1z2 = z1 * z2; + let u = y2z1 - y1z2; + let uu = u.square(); + let v = x2z1 - x1z2; + let vv = v.square(); + let vvv = v * vv; + let r = vv * x1z2; + let a = uu * z1z2 - vvv - r.double(); + let x3 = v * a; + let y3 = u * (r - a) - vvv * y1z2; + let z3 = vvv * z1z2; + ProjectivePoint::nonzero(x3, y3, z3) + } +} + +impl Add> for ProjectivePoint { + type Output = ProjectivePoint; + + fn add(self, rhs: AffinePoint) -> Self::Output { + let ProjectivePoint { + x: x1, + y: y1, + z: z1, + } = self; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = rhs; + + if z1 == C::BaseField::ZERO { + return rhs.to_projective(); + } + if zero2 { + return self; + } + + let x2z1 = x2 * z1; + let y2z1 = y2 * z1; + + // Check if we're doubling or adding inverses. + if x1 == x2z1 { + if y1 == y2z1 { + // TODO: inline to avoid redundant muls. + return self.double(); + } + if y1 == -y2z1 { + return ProjectivePoint::ZERO; + } + } + + // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/madd-1998-cmo + let u = y2z1 - y1; + let uu = u.square(); + let v = x2z1 - x1; + let vv = v.square(); + let vvv = v * vv; + let r = vv * x1; + let a = uu * z1 - vvv - r.double(); + let x3 = v * a; + let y3 = u * (r - a) - vvv * y1; + let z3 = vvv * z1; + ProjectivePoint::nonzero(x3, y3, z3) + } +} + +impl Add> for AffinePoint { + type Output = ProjectivePoint; + + fn add(self, rhs: AffinePoint) -> Self::Output { + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = self; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = rhs; + + if zero1 { + return rhs.to_projective(); + } + if zero2 { + return self.to_projective(); + } + + // Check if we're doubling or adding inverses. 
+ if x1 == x2 { + if y1 == y2 { + return self.to_projective().double(); + } + if y1 == -y2 { + return ProjectivePoint::ZERO; + } + } + + // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/mmadd-1998-cmo + let u = y2 - y1; + let uu = u.square(); + let v = x2 - x1; + let vv = v.square(); + let vvv = v * vv; + let r = vv * x1; + let a = uu - vvv - r.double(); + let x3 = v * a; + let y3 = u * (r - a) - vvv * y1; + let z3 = vvv; + ProjectivePoint::nonzero(x3, y3, z3) + } +} diff --git a/src/curve/curve_msm.rs b/src/curve/curve_msm.rs new file mode 100644 index 00000000..d2cb8049 --- /dev/null +++ b/src/curve/curve_msm.rs @@ -0,0 +1,263 @@ +use itertools::Itertools; +use rayon::prelude::*; + +use crate::curve::curve_summation::affine_multisummation_best; +use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; +use crate::field::field_types::Field; + +/// In Yao's method, we compute an affine summation for each digit. In a parallel setting, it would +/// be easiest to assign individual summations to threads, but this would be sub-optimal because +/// multi-summations can be more efficient than repeating individual summations (see +/// `affine_multisummation_best`). Thus we divide digits into large chunks, and assign chunks of +/// digits to threads. Note that there is a delicate balance here, as large chunks can result in +/// uneven distributions of work among threads. +const DIGITS_PER_CHUNK: usize = 80; + +#[derive(Clone, Debug)] +pub struct MsmPrecomputation { + /// For each generator (in the order they were passed to `msm_precompute`), contains a vector + /// of powers, i.e. [(2^w)^i] for i < DIGITS. + // TODO: Use compressed coordinates here. + powers_per_generator: Vec>>, + + /// The window size. 
+ w: usize, +} + +pub fn msm_precompute( + generators: &[ProjectivePoint], + w: usize, +) -> MsmPrecomputation { + MsmPrecomputation { + powers_per_generator: generators + .into_par_iter() + .map(|&g| precompute_single_generator(g, w)) + .collect(), + w, + } +} + +fn precompute_single_generator(g: ProjectivePoint, w: usize) -> Vec> { + let digits = (C::ScalarField::BITS + w - 1) / w; + let mut powers: Vec> = Vec::with_capacity(digits); + powers.push(g); + for i in 1..digits { + let mut power_i_proj = powers[i - 1]; + for _j in 0..w { + power_i_proj = power_i_proj.double(); + } + powers.push(power_i_proj); + } + ProjectivePoint::batch_to_affine(&powers) +} + +pub fn msm_parallel( + scalars: &[C::ScalarField], + generators: &[ProjectivePoint], + w: usize, +) -> ProjectivePoint { + let precomputation = msm_precompute(generators, w); + msm_execute_parallel(&precomputation, scalars) +} + +pub fn msm_execute( + precomputation: &MsmPrecomputation, + scalars: &[C::ScalarField], +) -> ProjectivePoint { + assert_eq!(precomputation.powers_per_generator.len(), scalars.len()); + let w = precomputation.w; + let digits = (C::ScalarField::BITS + w - 1) / w; + let base = 1 << w; + + // This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use + // extremely large windows, the repeated scans in Yao's method could be more expensive than the + // actual group operations. To avoid this, we store a multimap from each possible digit to the + // positions in which that digit occurs in the scalars. These positions have the form (i, j), + // where i is the index of the generator and j is an index into the digits of the scalar + // associated with that generator. 
+ let mut digit_occurrences: Vec> = Vec::with_capacity(digits); + for _i in 0..base { + digit_occurrences.push(Vec::new()); + } + for (i, scalar) in scalars.iter().enumerate() { + let digits = to_digits::(scalar, w); + for (j, &digit) in digits.iter().enumerate() { + digit_occurrences[digit].push((i, j)); + } + } + + let mut y = ProjectivePoint::ZERO; + let mut u = ProjectivePoint::ZERO; + + for digit in (1..base).rev() { + for &(i, j) in &digit_occurrences[digit] { + u = u + precomputation.powers_per_generator[i][j]; + } + y = y + u; + } + + y +} + +pub fn msm_execute_parallel( + precomputation: &MsmPrecomputation, + scalars: &[C::ScalarField], +) -> ProjectivePoint { + assert_eq!(precomputation.powers_per_generator.len(), scalars.len()); + let w = precomputation.w; + let digits = (C::ScalarField::BITS + w - 1) / w; + let base = 1 << w; + + // This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use + // extremely large windows, the repeated scans in Yao's method could be more expensive than the + // actual group operations. To avoid this, we store a multimap from each possible digit to the + // positions in which that digit occurs in the scalars. These positions have the form (i, j), + // where i is the index of the generator and j is an index into the digits of the scalar + // associated with that generator. + let mut digit_occurrences: Vec> = Vec::with_capacity(digits); + for _i in 0..base { + digit_occurrences.push(Vec::new()); + } + for (i, scalar) in scalars.iter().enumerate() { + let digits = to_digits::(scalar, w); + for (j, &digit) in digits.iter().enumerate() { + digit_occurrences[digit].push((i, j)); + } + } + + // For each digit, we add up the powers associated with all occurrences of that digit. 
+ let digits: Vec = (0..base).collect(); + let digit_acc: Vec> = digits + .par_chunks(DIGITS_PER_CHUNK) + .flat_map(|chunk| { + let summations: Vec>> = chunk + .iter() + .map(|&digit| { + digit_occurrences[digit] + .iter() + .map(|&(i, j)| precomputation.powers_per_generator[i][j]) + .collect() + }) + .collect(); + affine_multisummation_best(summations) + }) + .collect(); + // println!("Computing the per-digit summations (in parallel) took {}s", start.elapsed().as_secs_f64()); + + let mut y = ProjectivePoint::ZERO; + let mut u = ProjectivePoint::ZERO; + for digit in (1..base).rev() { + u = u + digit_acc[digit]; + y = y + u; + } + // println!("Final summation (sequential) {}s", start.elapsed().as_secs_f64()); + y +} + +pub(crate) fn to_digits(x: &C::ScalarField, w: usize) -> Vec { + let scalar_bits = C::ScalarField::BITS; + let num_digits = (scalar_bits + w - 1) / w; + + // Convert x to a bool array. + let x_canonical: Vec<_> = x + .to_biguint() + .to_u64_digits() + .iter() + .cloned() + .pad_using(scalar_bits / 64, |_| 0) + .collect(); + let mut x_bits = Vec::with_capacity(scalar_bits); + for i in 0..scalar_bits { + x_bits.push((x_canonical[i / 64] >> (i as u64 % 64) & 1) != 0); + } + + let mut digits = Vec::with_capacity(num_digits); + for i in 0..num_digits { + let mut digit = 0; + for j in ((i * w)..((i + 1) * w).min(scalar_bits)).rev() { + digit <<= 1; + digit |= x_bits[j] as usize; + } + digits.push(digit); + } + digits +} + +#[cfg(test)] +mod tests { + use num::BigUint; + + use crate::curve::curve_msm::{msm_execute, msm_precompute, to_digits}; + use crate::curve::curve_types::Curve; + use crate::curve::secp256k1::Secp256K1; + use crate::field::field_types::Field; + use crate::field::secp256k1_scalar::Secp256K1Scalar; + + #[test] + fn test_to_digits() { + let x_canonical = [ + 0b10101010101010101010101010101010, + 0b10101010101010101010101010101010, + 0b11001100110011001100110011001100, + 0b11001100110011001100110011001100, + 
0b11110000111100001111000011110000, + 0b11110000111100001111000011110000, + 0b00001111111111111111111111111111, + 0b11111111111111111111111111111111, + ]; + let x = Secp256K1Scalar::from_biguint(BigUint::from_slice(&x_canonical)); + assert_eq!(x.to_biguint().to_u32_digits(), x_canonical); + assert_eq!( + to_digits::(&x, 17), + vec![ + 0b01010101010101010, + 0b10101010101010101, + 0b01010101010101010, + 0b11001010101010101, + 0b01100110011001100, + 0b00110011001100110, + 0b10011001100110011, + 0b11110000110011001, + 0b01111000011110000, + 0b00111100001111000, + 0b00011110000111100, + 0b11111111111111110, + 0b01111111111111111, + 0b11111111111111000, + 0b11111111111111111, + 0b1, + ] + ); + } + + #[test] + fn test_msm() { + let w = 5; + + let generator_1 = Secp256K1::GENERATOR_PROJECTIVE; + let generator_2 = generator_1 + generator_1; + let generator_3 = generator_1 + generator_2; + + let scalar_1 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ + 11111111, 22222222, 33333333, 44444444, + ])); + let scalar_2 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ + 22222222, 22222222, 33333333, 44444444, + ])); + let scalar_3 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ + 33333333, 22222222, 33333333, 44444444, + ])); + + let generators = vec![generator_1, generator_2, generator_3]; + let scalars = vec![scalar_1, scalar_2, scalar_3]; + + let precomputation = msm_precompute(&generators, w); + let result_msm = msm_execute(&precomputation, &scalars); + + let result_naive = Secp256K1::convert(scalar_1) * generator_1 + + Secp256K1::convert(scalar_2) * generator_2 + + Secp256K1::convert(scalar_3) * generator_3; + + assert_eq!(result_msm, result_naive); + } +} diff --git a/src/curve/curve_multiplication.rs b/src/curve/curve_multiplication.rs new file mode 100644 index 00000000..eb5bade1 --- /dev/null +++ b/src/curve/curve_multiplication.rs @@ -0,0 +1,97 @@ +use std::ops::Mul; + +use crate::curve::curve_types::{Curve, CurveScalar, ProjectivePoint}; +use 
crate::field::field_types::Field; + +const WINDOW_BITS: usize = 4; +const BASE: usize = 1 << WINDOW_BITS; + +fn digits_per_scalar() -> usize { + (C::ScalarField::BITS + WINDOW_BITS - 1) / WINDOW_BITS +} + +/// Precomputed state used for scalar x ProjectivePoint multiplications, +/// specific to a particular generator. +#[derive(Clone)] +pub struct MultiplicationPrecomputation { + /// [(2^w)^i] g for each i < digits_per_scalar. + powers: Vec>, +} + +impl ProjectivePoint { + pub fn mul_precompute(&self) -> MultiplicationPrecomputation { + let num_digits = digits_per_scalar::(); + let mut powers = Vec::with_capacity(num_digits); + powers.push(*self); + for i in 1..num_digits { + let mut power_i = powers[i - 1]; + for _j in 0..WINDOW_BITS { + power_i = power_i.double(); + } + powers.push(power_i); + } + + MultiplicationPrecomputation { powers } + } + + pub fn mul_with_precomputation( + &self, + scalar: C::ScalarField, + precomputation: MultiplicationPrecomputation, + ) -> Self { + // Yao's method; see https://koclab.cs.ucsb.edu/teaching/ecc/eccPapers/Doche-ch09.pdf + let precomputed_powers = precomputation.powers; + + let digits = to_digits::(&scalar); + + let mut y = ProjectivePoint::ZERO; + let mut u = ProjectivePoint::ZERO; + let mut all_summands = Vec::new(); + for j in (1..BASE).rev() { + let mut u_summands = Vec::new(); + for (i, &digit) in digits.iter().enumerate() { + if digit == j as u64 { + u_summands.push(precomputed_powers[i]); + } + } + all_summands.push(u_summands); + } + + let all_sums: Vec> = all_summands + .iter() + .cloned() + .map(|vec| vec.iter().fold(ProjectivePoint::ZERO, |a, &b| a + b)) + .collect(); + for i in 0..all_sums.len() { + u = u + all_sums[i]; + y = y + u; + } + y + } +} + +impl Mul> for CurveScalar { + type Output = ProjectivePoint; + + fn mul(self, rhs: ProjectivePoint) -> Self::Output { + let precomputation = rhs.mul_precompute(); + rhs.mul_with_precomputation(self.0, precomputation) + } +} + 
+#[allow(clippy::assertions_on_constants)] +fn to_digits(x: &C::ScalarField) -> Vec { + debug_assert!( + 64 % WINDOW_BITS == 0, + "For simplicity, only power-of-two window sizes are handled for now" + ); + let digits_per_u64 = 64 / WINDOW_BITS; + let mut digits = Vec::with_capacity(digits_per_scalar::()); + for limb in x.to_biguint().to_u64_digits() { + for j in 0..digits_per_u64 { + digits.push((limb >> (j * WINDOW_BITS) as u64) % BASE as u64); + } + } + + digits +} diff --git a/src/curve/curve_summation.rs b/src/curve/curve_summation.rs new file mode 100644 index 00000000..c67bc026 --- /dev/null +++ b/src/curve/curve_summation.rs @@ -0,0 +1,237 @@ +use std::iter::Sum; + +use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; +use crate::field::field_types::Field; + +impl Sum> for ProjectivePoint { + fn sum>>(iter: I) -> ProjectivePoint { + let points: Vec<_> = iter.collect(); + affine_summation_best(points) + } +} + +impl Sum for ProjectivePoint { + fn sum>>(iter: I) -> ProjectivePoint { + iter.fold(ProjectivePoint::ZERO, |acc, x| acc + x) + } +} + +pub fn affine_summation_best(summation: Vec>) -> ProjectivePoint { + let result = affine_multisummation_best(vec![summation]); + debug_assert_eq!(result.len(), 1); + result[0] +} + +pub fn affine_multisummation_best( + summations: Vec>>, +) -> Vec> { + let pairwise_sums: usize = summations.iter().map(|summation| summation.len() / 2).sum(); + + // This threshold is chosen based on data from the summation benchmarks. + if pairwise_sums < 70 { + affine_multisummation_pairwise(summations) + } else { + affine_multisummation_batch_inversion(summations) + } +} + +/// Adds each pair of points using an affine + affine = projective formula, then adds up the +/// intermediate sums using a projective formula. 
+pub fn affine_multisummation_pairwise( + summations: Vec>>, +) -> Vec> { + summations + .into_iter() + .map(affine_summation_pairwise) + .collect() +} + +/// Adds each pair of points using an affine + affine = projective formula, then adds up the +/// intermediate sums using a projective formula. +pub fn affine_summation_pairwise(points: Vec>) -> ProjectivePoint { + let mut reduced_points: Vec> = Vec::new(); + for chunk in points.chunks(2) { + match chunk.len() { + 1 => reduced_points.push(chunk[0].to_projective()), + 2 => reduced_points.push(chunk[0] + chunk[1]), + _ => panic!(), + } + } + // TODO: Avoid copying (deref) + reduced_points + .iter() + .fold(ProjectivePoint::ZERO, |sum, x| sum + *x) +} + +/// Computes a single summation of affine points by applying an affine group law, except that the +/// divisions are batched via Montgomery's trick. +pub fn affine_summation_batch_inversion( + summation: Vec>, +) -> ProjectivePoint { + let result = affine_multisummation_batch_inversion(vec![summation]); + debug_assert_eq!(result.len(), 1); + result[0] +} + +/// Computes several summations of affine points by applying an affine group law, except that the +/// divisions are batched via Montgomery's trick. +pub fn affine_multisummation_batch_inversion( + summations: Vec>>, +) -> Vec> { + let mut elements_to_invert = Vec::new(); + + // For each pair of points, (x1, y1) and (x2, y2), that we're going to add later, we want to + // invert either 2 * y1 (if the points are equal) or x1 - x2 (otherwise). We will use these later. + for summation in &summations { + let n = summation.len(); + // The special case for n=0 is to avoid underflow. 
+ let range_end = if n == 0 { 0 } else { n - 1 }; + + for i in (0..range_end).step_by(2) { + let p1 = summation[i]; + let p2 = summation[i + 1]; + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = p1; + let AffinePoint { + x: x2, + y: _y2, + zero: zero2, + } = p2; + + if zero1 || zero2 || p1 == -p2 { + // These are trivial cases where we won't need any inverse. + } else if p1 == p2 { + elements_to_invert.push(y1.double()); + } else { + elements_to_invert.push(x1 - x2); + } + } + } + + let inverses: Vec = + C::BaseField::batch_multiplicative_inverse(&elements_to_invert); + + let mut all_reduced_points = Vec::with_capacity(summations.len()); + let mut inverse_index = 0; + for summation in summations { + let n = summation.len(); + let mut reduced_points = Vec::with_capacity((n + 1) / 2); + + // The special case for n=0 is to avoid underflow. + let range_end = if n == 0 { 0 } else { n - 1 }; + + for i in (0..range_end).step_by(2) { + let p1 = summation[i]; + let p2 = summation[i + 1]; + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = p1; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = p2; + + let sum = if zero1 { + p2 + } else if zero2 { + p1 + } else if p1 == -p2 { + AffinePoint::ZERO + } else { + // It's a non-trivial case where we need one of the inverses we computed earlier. + let inverse = inverses[inverse_index]; + inverse_index += 1; + + if p1 == p2 { + // This is the doubling case. + let mut numerator = x1.square().triple(); + if C::A.is_nonzero() { + numerator += C::A; + } + let quotient = numerator * inverse; + let x3 = quotient.square() - x1.double(); + let y3 = quotient * (x1 - x3) - y1; + AffinePoint::nonzero(x3, y3) + } else { + // This is the general case. We use the incomplete addition formulas 4.3 and 4.4. 
+ let quotient = (y1 - y2) * inverse; + let x3 = quotient.square() - x1 - x2; + let y3 = quotient * (x1 - x3) - y1; + AffinePoint::nonzero(x3, y3) + } + }; + reduced_points.push(sum); + } + + // If n is odd, the last point was not part of a pair. + if n % 2 == 1 { + reduced_points.push(summation[n - 1]); + } + + all_reduced_points.push(reduced_points); + } + + // We should have consumed all of the inverses from the batch computation. + debug_assert_eq!(inverse_index, inverses.len()); + + // Recurse with our smaller set of points. + affine_multisummation_best(all_reduced_points) +} + +#[cfg(test)] +mod tests { + use crate::curve::curve_summation::{ + affine_summation_batch_inversion, affine_summation_pairwise, + }; + use crate::curve::curve_types::{Curve, ProjectivePoint}; + use crate::curve::secp256k1::Secp256K1; + + #[test] + fn test_pairwise_affine_summation() { + let g_affine = Secp256K1::GENERATOR_AFFINE; + let g2_affine = (g_affine + g_affine).to_affine(); + let g3_affine = (g_affine + g_affine + g_affine).to_affine(); + let g2_proj = g2_affine.to_projective(); + let g3_proj = g3_affine.to_projective(); + assert_eq!( + affine_summation_pairwise::(vec![g_affine, g_affine]), + g2_proj + ); + assert_eq!( + affine_summation_pairwise::(vec![g_affine, g2_affine]), + g3_proj + ); + assert_eq!( + affine_summation_pairwise::(vec![g_affine, g_affine, g_affine]), + g3_proj + ); + assert_eq!( + affine_summation_pairwise::(vec![]), + ProjectivePoint::ZERO + ); + } + + #[test] + fn test_pairwise_affine_summation_batch_inversion() { + let g = Secp256K1::GENERATOR_AFFINE; + let g_proj = g.to_projective(); + assert_eq!( + affine_summation_batch_inversion::(vec![g, g]), + g_proj + g_proj + ); + assert_eq!( + affine_summation_batch_inversion::(vec![g, g, g]), + g_proj + g_proj + g_proj + ); + assert_eq!( + affine_summation_batch_inversion::(vec![]), + ProjectivePoint::ZERO + ); + } +} diff --git a/src/curve/curve_types.rs b/src/curve/curve_types.rs new file mode 100644 index 
00000000..ef1f6186 --- /dev/null +++ b/src/curve/curve_types.rs @@ -0,0 +1,260 @@ +use std::fmt::Debug; +use std::ops::Neg; + +use crate::field::field_types::Field; + +// To avoid implementation conflicts from associated types, +// see https://github.com/rust-lang/rust/issues/20400 +pub struct CurveScalar(pub ::ScalarField); + +/// A short Weierstrass curve. +pub trait Curve: 'static + Sync + Sized + Copy + Debug { + type BaseField: Field; + type ScalarField: Field; + + const A: Self::BaseField; + const B: Self::BaseField; + + const GENERATOR_AFFINE: AffinePoint; + + const GENERATOR_PROJECTIVE: ProjectivePoint = ProjectivePoint { + x: Self::GENERATOR_AFFINE.x, + y: Self::GENERATOR_AFFINE.y, + z: Self::BaseField::ONE, + }; + + fn convert(x: Self::ScalarField) -> CurveScalar { + CurveScalar(x) + } + + fn is_safe_curve() -> bool { + // Additional check to guard against vulnerabilities that arise when the discriminant (4*A^3 + 27*B^2) is equal to 0. + (Self::A.cube().double().double() + Self::B.square().triple().triple().triple()) + .is_nonzero() + } +} + +/// A point on a short Weierstrass curve, represented in affine coordinates. 
+#[derive(Copy, Clone, Debug)] +pub struct AffinePoint { + pub x: C::BaseField, + pub y: C::BaseField, + pub zero: bool, +} + +impl AffinePoint { + pub const ZERO: Self = Self { + x: C::BaseField::ZERO, + y: C::BaseField::ZERO, + zero: true, + }; + + pub fn nonzero(x: C::BaseField, y: C::BaseField) -> Self { + let point = Self { x, y, zero: false }; + debug_assert!(point.is_valid()); + point + } + + pub fn is_valid(&self) -> bool { + let Self { x, y, zero } = *self; + zero || y.square() == x.cube() + C::A * x + C::B + } + + pub fn to_projective(&self) -> ProjectivePoint { + let Self { x, y, zero } = *self; + let z = if zero { + C::BaseField::ZERO + } else { + C::BaseField::ONE + }; + + ProjectivePoint { x, y, z } + } + + pub fn batch_to_projective(affine_points: &[Self]) -> Vec> { + affine_points.iter().map(Self::to_projective).collect() + } + + pub fn double(&self) -> Self { + let AffinePoint { x: x1, y: y1, zero } = *self; + + if zero { + return AffinePoint::ZERO; + } + + let double_y = y1.double(); + let inv_double_y = double_y.inverse(); // (2y)^(-1) + let triple_xx = x1.square().triple(); // 3x^2 + let lambda = (triple_xx + C::A) * inv_double_y; + let x3 = lambda.square() - self.x.double(); + let y3 = lambda * (x1 - x3) - y1; + + Self { + x: x3, + y: y3, + zero: false, + } + } +} + +impl PartialEq for AffinePoint { + fn eq(&self, other: &Self) -> bool { + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = *self; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = *other; + if zero1 || zero2 { + return zero1 == zero2; + } + x1 == x2 && y1 == y2 + } +} + +impl Eq for AffinePoint {} + +/// A point on a short Weierstrass curve, represented in projective coordinates. 
+#[derive(Copy, Clone, Debug)] +pub struct ProjectivePoint { + pub x: C::BaseField, + pub y: C::BaseField, + pub z: C::BaseField, +} + +impl ProjectivePoint { + pub const ZERO: Self = Self { + x: C::BaseField::ZERO, + y: C::BaseField::ONE, + z: C::BaseField::ZERO, + }; + + pub fn nonzero(x: C::BaseField, y: C::BaseField, z: C::BaseField) -> Self { + let point = Self { x, y, z }; + debug_assert!(point.is_valid()); + point + } + + pub fn is_valid(&self) -> bool { + let Self { x, y, z } = *self; + z.is_zero() || y.square() * z == x.cube() + C::A * x * z.square() + C::B * z.cube() + } + + pub fn to_affine(&self) -> AffinePoint { + let Self { x, y, z } = *self; + if z == C::BaseField::ZERO { + AffinePoint::ZERO + } else { + let z_inv = z.inverse(); + AffinePoint::nonzero(x * z_inv, y * z_inv) + } + } + + pub fn batch_to_affine(proj_points: &[Self]) -> Vec> { + let n = proj_points.len(); + let zs: Vec = proj_points.iter().map(|pp| pp.z).collect(); + let z_invs = C::BaseField::batch_multiplicative_inverse(&zs); + + let mut result = Vec::with_capacity(n); + for i in 0..n { + let Self { x, y, z } = proj_points[i]; + result.push(if z == C::BaseField::ZERO { + AffinePoint::ZERO + } else { + let z_inv = z_invs[i]; + AffinePoint::nonzero(x * z_inv, y * z_inv) + }); + } + result + } + + // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/doubling/dbl-2007-bl + pub fn double(&self) -> Self { + let Self { x, y, z } = *self; + if z == C::BaseField::ZERO { + return ProjectivePoint::ZERO; + } + + let xx = x.square(); + let zz = z.square(); + let mut w = xx.triple(); + if C::A.is_nonzero() { + w += C::A * zz; + } + let s = y.double() * z; + let r = y * s; + let rr = r.square(); + let b = (x + r).square() - (xx + rr); + let h = w.square() - b.double(); + let x3 = h * s; + let y3 = w * (b - h) - rr.double(); + let z3 = s.cube(); + Self { + x: x3, + y: y3, + z: z3, + } + } + + pub fn add_slices(a: &[Self], b: &[Self]) -> Vec { + assert_eq!(a.len(), b.len()); + a.iter() + 
.zip(b.iter()) + .map(|(&a_i, &b_i)| a_i + b_i) + .collect() + } + + pub fn neg(&self) -> Self { + Self { + x: self.x, + y: -self.y, + z: self.z, + } + } +} + +impl PartialEq for ProjectivePoint { + fn eq(&self, other: &Self) -> bool { + let ProjectivePoint { + x: x1, + y: y1, + z: z1, + } = *self; + let ProjectivePoint { + x: x2, + y: y2, + z: z2, + } = *other; + if z1 == C::BaseField::ZERO || z2 == C::BaseField::ZERO { + return z1 == z2; + } + + // We want to compare (x1/z1, y1/z1) == (x2/z2, y2/z2). + // But to avoid field division, it is better to compare (x1*z2, y1*z2) == (x2*z1, y2*z1). + x1 * z2 == x2 * z1 && y1 * z2 == y2 * z1 + } +} + +impl Eq for ProjectivePoint {} + +impl Neg for AffinePoint { + type Output = AffinePoint; + + fn neg(self) -> Self::Output { + let AffinePoint { x, y, zero } = self; + AffinePoint { x, y: -y, zero } + } +} + +impl Neg for ProjectivePoint { + type Output = ProjectivePoint; + + fn neg(self) -> Self::Output { + let ProjectivePoint { x, y, z } = self; + ProjectivePoint { x, y: -y, z } + } +} diff --git a/src/curve/mod.rs b/src/curve/mod.rs new file mode 100644 index 00000000..d31e373e --- /dev/null +++ b/src/curve/mod.rs @@ -0,0 +1,6 @@ +pub mod curve_adds; +pub mod curve_msm; +pub mod curve_multiplication; +pub mod curve_summation; +pub mod curve_types; +pub mod secp256k1; diff --git a/src/curve/secp256k1.rs b/src/curve/secp256k1.rs new file mode 100644 index 00000000..58472eb4 --- /dev/null +++ b/src/curve/secp256k1.rs @@ -0,0 +1,98 @@ +use crate::curve::curve_types::{AffinePoint, Curve}; +use crate::field::field_types::Field; +use crate::field::secp256k1_base::Secp256K1Base; +use crate::field::secp256k1_scalar::Secp256K1Scalar; + +#[derive(Debug, Copy, Clone)] +pub struct Secp256K1; + +impl Curve for Secp256K1 { + type BaseField = Secp256K1Base; + type ScalarField = Secp256K1Scalar; + + const A: Secp256K1Base = Secp256K1Base::ZERO; + const B: Secp256K1Base = Secp256K1Base([7, 0, 0, 0]); + const GENERATOR_AFFINE: AffinePoint = 
AffinePoint { + x: SECP256K1_GENERATOR_X, + y: SECP256K1_GENERATOR_Y, + zero: false, + }; +} + +// 55066263022277343669578718895168534326250603453777594175500187360389116729240 +const SECP256K1_GENERATOR_X: Secp256K1Base = Secp256K1Base([ + 0x59F2815B16F81798, + 0x029BFCDB2DCE28D9, + 0x55A06295CE870B07, + 0x79BE667EF9DCBBAC, +]); + +/// 32670510020758816978083085130507043184471273380659243275938904335757337482424 +const SECP256K1_GENERATOR_Y: Secp256K1Base = Secp256K1Base([ + 0x9C47D08FFB10D4B8, + 0xFD17B448A6855419, + 0x5DA4FBFC0E1108A8, + 0x483ADA7726A3C465, +]); + +#[cfg(test)] +mod tests { + use num::BigUint; + + use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; + use crate::curve::secp256k1::Secp256K1; + use crate::field::field_types::Field; + use crate::field::secp256k1_scalar::Secp256K1Scalar; + + #[test] + fn test_generator() { + let g = Secp256K1::GENERATOR_AFFINE; + assert!(g.is_valid()); + + let neg_g = AffinePoint:: { + x: g.x, + y: -g.y, + zero: g.zero, + }; + assert!(neg_g.is_valid()); + } + + #[test] + fn test_naive_multiplication() { + let g = Secp256K1::GENERATOR_PROJECTIVE; + let ten = Secp256K1Scalar::from_canonical_u64(10); + let product = mul_naive(ten, g); + let sum = g + g + g + g + g + g + g + g + g + g; + assert_eq!(product, sum); + } + + #[test] + fn test_g1_multiplication() { + let lhs = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ + 1111, 2222, 3333, 4444, 5555, 6666, 7777, 8888, + ])); + assert_eq!( + Secp256K1::convert(lhs) * Secp256K1::GENERATOR_PROJECTIVE, + mul_naive(lhs, Secp256K1::GENERATOR_PROJECTIVE) + ); + } + + /// A simple, somewhat inefficient implementation of multiplication which is used as a reference + /// for correctness. 
+ fn mul_naive( + lhs: Secp256K1Scalar, + rhs: ProjectivePoint, + ) -> ProjectivePoint { + let mut g = rhs; + let mut sum = ProjectivePoint::ZERO; + for limb in lhs.to_biguint().to_u64_digits().iter() { + for j in 0..64 { + if (limb >> j & 1u64) != 0u64 { + sum = sum + g; + } + g = g.double(); + } + } + sum + } +} diff --git a/src/field/cosets.rs b/src/field/cosets.rs index 4ad9ba38..62be67dc 100644 --- a/src/field/cosets.rs +++ b/src/field/cosets.rs @@ -31,8 +31,6 @@ mod tests { #[test] fn distinct_cosets() { - // TODO: Switch to a smaller test field so that collision rejection is likely to occur. - type F = GoldilocksField; const SUBGROUP_BITS: usize = 5; const NUM_SHIFTS: usize = 50; diff --git a/src/field/extension_field/algebra.rs b/src/field/extension_field/algebra.rs index 21438262..93d25de4 100644 --- a/src/field/extension_field/algebra.rs +++ b/src/field/extension_field/algebra.rs @@ -160,12 +160,32 @@ impl, const D: usize> PolynomialCoeffsAlgebra { .fold(ExtensionAlgebra::ZERO, |acc, &c| acc * x + c) } + /// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1. + pub fn eval_with_powers(&self, powers: &[ExtensionAlgebra]) -> ExtensionAlgebra { + debug_assert_eq!(self.coeffs.len(), powers.len() + 1); + let acc = self.coeffs[0]; + self.coeffs[1..] + .iter() + .zip(powers) + .fold(acc, |acc, (&x, &c)| acc + c * x) + } + pub fn eval_base(&self, x: F) -> ExtensionAlgebra { self.coeffs .iter() .rev() .fold(ExtensionAlgebra::ZERO, |acc, &c| acc.scalar_mul(x) + c) } + + /// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1. + pub fn eval_base_with_powers(&self, powers: &[F]) -> ExtensionAlgebra { + debug_assert_eq!(self.coeffs.len(), powers.len() + 1); + let acc = self.coeffs[0]; + self.coeffs[1..] 
+ .iter() + .zip(powers) + .fold(acc, |acc, (&x, &c)| acc + x.scalar_mul(c)) + } } #[cfg(test)] diff --git a/src/field/extension_field/mod.rs b/src/field/extension_field/mod.rs index f8322e6d..2ddea4ee 100644 --- a/src/field/extension_field/mod.rs +++ b/src/field/extension_field/mod.rs @@ -1,3 +1,4 @@ +use crate::field::field_types::{Field, PrimeField}; use std::convert::TryInto; use crate::field::field_types::{Field, RichField}; diff --git a/src/field/extension_field/quadratic.rs b/src/field/extension_field/quadratic.rs index ebad5025..2243612e 100644 --- a/src/field/extension_field/quadratic.rs +++ b/src/field/extension_field/quadratic.rs @@ -3,6 +3,7 @@ use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use num::bigint::BigUint; +use num::Integer; use rand::Rng; use serde::{Deserialize, Serialize}; @@ -49,26 +50,28 @@ impl> From for QuadraticExtension { } impl> Field for QuadraticExtension { - type PrimeField = F; - const ZERO: Self = Self([F::ZERO; 2]); const ONE: Self = Self([F::ONE, F::ZERO]); const TWO: Self = Self([F::TWO, F::ZERO]); const NEG_ONE: Self = Self([F::NEG_ONE, F::ZERO]); - const CHARACTERISTIC: u64 = F::CHARACTERISTIC; - // `p^2 - 1 = (p - 1)(p + 1)`. The `p - 1` term has a two-adicity of `F::TWO_ADICITY`. As // long as `F::TWO_ADICITY >= 2`, `p` can be written as `4n + 1`, so `p + 1` can be written as // `2(2n + 1)`, which has a 2-adicity of 1. 
const TWO_ADICITY: usize = F::TWO_ADICITY + 1; + const CHARACTERISTIC_TWO_ADICITY: usize = F::CHARACTERISTIC_TWO_ADICITY; const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(F::EXT_MULTIPLICATIVE_GROUP_GENERATOR); const POWER_OF_TWO_GENERATOR: Self = Self(F::EXT_POWER_OF_TWO_GENERATOR); + const BITS: usize = F::BITS * 2; + fn order() -> BigUint { F::order() * F::order() } + fn characteristic() -> BigUint { + F::characteristic() + } #[inline(always)] fn square(&self) -> Self { @@ -99,6 +102,15 @@ impl> Field for QuadraticExtension { )) } + fn from_biguint(n: BigUint) -> Self { + let (high, low) = n.div_rem(&F::order()); + Self([F::from_biguint(low), F::from_biguint(high)]) + } + + fn to_biguint(&self) -> BigUint { + self.0[0].to_biguint() + F::order() * self.0[1].to_biguint() + } + fn from_canonical_u64(n: u64) -> Self { F::from_canonical_u64(n).into() } diff --git a/src/field/extension_field/quartic.rs b/src/field/extension_field/quartic.rs index 001da821..781f79f5 100644 --- a/src/field/extension_field/quartic.rs +++ b/src/field/extension_field/quartic.rs @@ -4,6 +4,7 @@ use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssi use num::bigint::BigUint; use num::traits::Pow; +use num::Integer; use rand::Rng; use serde::{Deserialize, Serialize}; @@ -50,27 +51,29 @@ impl> From for QuarticExtension { } impl> Field for QuarticExtension { - type PrimeField = F; - const ZERO: Self = Self([F::ZERO; 4]); const ONE: Self = Self([F::ONE, F::ZERO, F::ZERO, F::ZERO]); const TWO: Self = Self([F::TWO, F::ZERO, F::ZERO, F::ZERO]); const NEG_ONE: Self = Self([F::NEG_ONE, F::ZERO, F::ZERO, F::ZERO]); - const CHARACTERISTIC: u64 = F::ORDER; - // `p^4 - 1 = (p - 1)(p + 1)(p^2 + 1)`. The `p - 1` term has a two-adicity of `F::TWO_ADICITY`. // As long as `F::TWO_ADICITY >= 2`, `p` can be written as `4n + 1`, so `p + 1` can be written as // `2(2n + 1)`, which has a 2-adicity of 1. A similar argument can show that `p^2 + 1` also has // a 2-adicity of 1. 
const TWO_ADICITY: usize = F::TWO_ADICITY + 2; + const CHARACTERISTIC_TWO_ADICITY: usize = F::CHARACTERISTIC_TWO_ADICITY; const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(F::EXT_MULTIPLICATIVE_GROUP_GENERATOR); const POWER_OF_TWO_GENERATOR: Self = Self(F::EXT_POWER_OF_TWO_GENERATOR); + const BITS: usize = F::BITS * 4; + fn order() -> BigUint { F::order().pow(4u32) } + fn characteristic() -> BigUint { + F::characteristic() + } #[inline(always)] fn square(&self) -> Self { @@ -104,6 +107,26 @@ impl> Field for QuarticExtension { )) } + fn from_biguint(n: BigUint) -> Self { + let (rest, first) = n.div_rem(&F::order()); + let (rest, second) = rest.div_rem(&F::order()); + let (rest, third) = rest.div_rem(&F::order()); + Self([ + F::from_biguint(first), + F::from_biguint(second), + F::from_biguint(third), + F::from_biguint(rest), + ]) + } + + fn to_biguint(&self) -> BigUint { + let mut result = self.0[3].to_biguint(); + result = result * F::order() + self.0[2].to_biguint(); + result = result * F::order() + self.0[1].to_biguint(); + result = result * F::order() + self.0[0].to_biguint(); + result + } + fn from_canonical_u64(n: u64) -> Self { F::from_canonical_u64(n).into() } diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index d92d3c3f..5517779b 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -1,4 +1,3 @@ -use std::convert::{TryFrom, TryInto}; use std::ops::Range; use crate::field::extension_field::algebra::ExtensionAlgebra; @@ -33,6 +32,7 @@ impl ExtensionTarget { let arr = self.to_target_array(); let k = (F::order() - 1u32) / (D as u64); let z0 = F::Extension::W.exp_biguint(&(k * count as u64)); + #[allow(clippy::needless_collect)] let zs = z0 .powers() .take(D) diff --git a/src/field/fft.rs b/src/field/fft.rs index 96f19857..ba94f6a7 100644 --- a/src/field/fft.rs +++ b/src/field/fft.rs @@ -5,8 +5,8 @@ use unroll::unroll_for_loops; use crate::field::field_types::Field; use 
crate::field::packable::Packable; -use crate::field::packed_field::{PackedField, Singleton}; -use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; +use crate::field::packed_field::PackedField; +use crate::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::util::{log2_strict, reverse_index_bits}; pub(crate) type FftRootTable = Vec>; @@ -38,7 +38,7 @@ fn fft_dispatch( zero_factor: Option, root_table: Option<&FftRootTable>, ) -> Vec { - let computed_root_table = if let Some(_) = root_table { + let computed_root_table = if root_table.is_some() { None } else { Some(fft_root_table(input.len())) @@ -98,12 +98,12 @@ pub fn ifft_with_options( /// Generic FFT implementation that works with both scalar and packed inputs. #[unroll_for_loops] fn fft_classic_simd( - values: &mut [P::FieldType], + values: &mut [P::Scalar], r: usize, lg_n: usize, - root_table: &FftRootTable, + root_table: &FftRootTable, ) { - let lg_packed_width = P::LOG2_WIDTH; // 0 when P is a scalar. + let lg_packed_width = log2_strict(P::WIDTH); // 0 when P is a scalar. let packed_values = P::pack_slice_mut(values); let packed_n = packed_values.len(); debug_assert!(packed_n == 1 << (lg_n - lg_packed_width)); @@ -121,19 +121,18 @@ fn fft_classic_simd( let half_m = 1 << lg_half_m; // Set omega to root_table[lg_half_m][0..half_m] but repeated. - let mut omega_vec = P::zero().to_vec(); - for j in 0..omega_vec.len() { - omega_vec[j] = root_table[lg_half_m][j % half_m]; + let mut omega = P::ZERO; + for (j, omega_j) in omega.as_slice_mut().iter_mut().enumerate() { + *omega_j = root_table[lg_half_m][j % half_m]; } - let omega = P::from_slice(&omega_vec[..]); for k in (0..packed_n).step_by(2) { // We have two vectors and want to do math on pairs of adjacent elements (or for // lg_half_m > 0, pairs of adjacent blocks of elements). .interleave does the // appropriate shuffling and is its own inverse. 
- let (u, v) = packed_values[k].interleave(packed_values[k + 1], lg_half_m); + let (u, v) = packed_values[k].interleave(packed_values[k + 1], half_m); let t = omega * v; - (packed_values[k], packed_values[k + 1]) = (u + t).interleave(u - t, lg_half_m); + (packed_values[k], packed_values[k + 1]) = (u + t).interleave(u - t, half_m); } } } @@ -197,13 +196,13 @@ pub(crate) fn fft_classic(input: &[F], r: usize, root_table: &FftRootT } } - let lg_packed_width = ::PackedType::LOG2_WIDTH; + let lg_packed_width = log2_strict(::Packing::WIDTH); if lg_n <= lg_packed_width { // Need the slice to be at least the width of two packed vectors for the vectorized version // to work. Do this tiny problem in scalar. - fft_classic_simd::>(&mut values[..], r, lg_n, &root_table); + fft_classic_simd::(&mut values[..], r, lg_n, root_table); } else { - fft_classic_simd::<::PackedType>(&mut values[..], r, lg_n, &root_table); + fft_classic_simd::<::Packing>(&mut values[..], r, lg_n, root_table); } values } @@ -213,19 +212,23 @@ mod tests { use crate::field::fft::{fft, fft_with_options, ifft}; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; - use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; + use crate::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::util::{log2_ceil, log2_strict}; #[test] fn fft_and_ifft() { type F = GoldilocksField; - let degree = 200; - let degree_padded = log2_ceil(degree); - let mut coefficients = Vec::new(); - for i in 0..degree { - coefficients.push(F::from_canonical_usize(i * 1337 % 100)); - } - let coefficients = PolynomialCoeffs::new_padded(coefficients); + let degree = 200usize; + let degree_padded = degree.next_power_of_two(); + + // Create a vector of coeffs; the first degree of them are + // "random", the last degree_padded-degree of them are zero. 
+ let coeffs = (0..degree) + .map(|i| F::from_canonical_usize(i * 1337 % 100)) + .chain(std::iter::repeat(F::ZERO).take(degree_padded - degree)) + .collect::>(); + assert_eq!(coeffs.len(), degree_padded); + let coefficients = PolynomialCoeffs { coeffs }; let points = fft(&coefficients); assert_eq!(points, evaluate_naive(&coefficients)); @@ -263,7 +266,7 @@ mod tests { let values = subgroup .into_iter() - .map(|x| evaluate_at_naive(&coefficients, x)) + .map(|x| evaluate_at_naive(coefficients, x)) .collect(); PolynomialValues::new(values) } @@ -272,8 +275,8 @@ mod tests { let mut sum = F::ZERO; let mut point_power = F::ONE; for &c in &coefficients.coeffs { - sum = sum + c * point_power; - point_power = point_power * point; + sum += c * point_power; + point_power *= point; } sum } diff --git a/src/field/field_testing.rs b/src/field/field_testing.rs index f422d810..b4ee0595 100644 --- a/src/field/field_testing.rs +++ b/src/field/field_testing.rs @@ -13,12 +13,15 @@ macro_rules! test_field_arithmetic { #[test] fn batch_inversion() { - let xs = (1..=3) - .map(|i| <$field>::from_canonical_u64(i)) - .collect::>(); - let invs = <$field>::batch_multiplicative_inverse(&xs); - for (x, inv) in xs.into_iter().zip(invs) { - assert_eq!(x * inv, <$field>::ONE); + for n in 0..20 { + let xs = (1..=n as u64) + .map(|i| <$field>::from_canonical_u64(i)) + .collect::>(); + let invs = <$field>::batch_multiplicative_inverse(&xs); + assert_eq!(invs.len(), n); + for (x, inv) in xs.into_iter().zip(invs) { + assert_eq!(x * inv, <$field>::ONE); + } } } @@ -81,10 +84,24 @@ macro_rules! 
test_field_arithmetic { assert_eq!(base.exp_biguint(&pow), base.exp_biguint(&big_pow)); assert_ne!(base.exp_biguint(&pow), base.exp_biguint(&big_pow_wrong)); } + + #[test] + fn inverses() { + type F = $field; + + let x = F::rand(); + let x1 = x.inverse(); + let x2 = x1.inverse(); + let x3 = x2.inverse(); + + assert_eq!(x, x2); + assert_eq!(x1, x3); + } } }; } +#[allow(clippy::eq_op)] pub(crate) fn test_add_neg_sub_mul, const D: usize>() { let x = BF::Extension::rand(); let y = BF::Extension::rand(); diff --git a/src/field/field_types.rs b/src/field/field_types.rs index 28aa1e97..68c42dc5 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -1,11 +1,10 @@ -use std::convert::TryInto; use std::fmt::{Debug, Display}; use std::hash::Hash; use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use num::bigint::BigUint; -use num::{Integer, One, Zero}; +use num::{Integer, One, ToPrimitive, Zero}; use rand::Rng; use serde::de::DeserializeOwned; use serde::Serialize; @@ -43,24 +42,28 @@ pub trait Field: + Serialize + DeserializeOwned { - type PrimeField: PrimeField; - const ZERO: Self; const ONE: Self; const TWO: Self; const NEG_ONE: Self; - const CHARACTERISTIC: u64; - /// The 2-adicity of this field's multiplicative group. const TWO_ADICITY: usize; + /// The field's characteristic and it's 2-adicity. + /// Set to `None` when the characteristic doesn't fit in a u64. + const CHARACTERISTIC_TWO_ADICITY: usize; + /// Generator of the entire multiplicative group, i.e. all non-zero elements. const MULTIPLICATIVE_GROUP_GENERATOR: Self; /// Generator of a multiplicative subgroup of order `2^TWO_ADICITY`. const POWER_OF_TWO_GENERATOR: Self; + /// The bit length of the field order. 
+ const BITS: usize; + fn order() -> BigUint; + fn characteristic() -> BigUint; #[inline] fn is_zero(&self) -> bool { @@ -92,6 +95,10 @@ pub trait Field: self.square() * *self } + fn triple(&self) -> Self { + *self * (Self::ONE + Self::TWO) + } + /// Compute the multiplicative inverse of this field element. fn try_inverse(&self) -> Option; @@ -103,34 +110,91 @@ pub trait Field: // This is Montgomery's trick. At a high level, we invert the product of the given field // elements, then derive the individual inverses from that via multiplication. + // The usual Montgomery trick involves calculating an array of cumulative products, + // resulting in a long dependency chain. To increase instruction-level parallelism, we + // compute WIDTH separate cumulative product arrays that only meet at the end. + + // Higher WIDTH increases instruction-level parallelism, but too high a value will cause us + // to run out of registers. + const WIDTH: usize = 4; + // JN note: WIDTH is 4. The code is specialized to this value and will need + // modification if it is changed. I tried to make it more generic, but Rust's const + // generics are not yet good enough. + + // Handle special cases. Paradoxically, below is repetitive but concise. + // The branches should be very predictable. let n = x.len(); if n == 0 { return Vec::new(); - } - if n == 1 { + } else if n == 1 { return vec![x[0].inverse()]; + } else if n == 2 { + let x01 = x[0] * x[1]; + let x01inv = x01.inverse(); + return vec![x01inv * x[1], x01inv * x[0]]; + } else if n == 3 { + let x01 = x[0] * x[1]; + let x012 = x01 * x[2]; + let x012inv = x012.inverse(); + let x01inv = x012inv * x[2]; + return vec![x01inv * x[1], x01inv * x[0], x012inv * x01]; } + debug_assert!(n >= WIDTH); - // Fill buf with cumulative product of x. 
- let mut buf = Vec::with_capacity(n); - let mut cumul_prod = x[0]; - buf.push(cumul_prod); - for i in 1..n { - cumul_prod *= x[i]; - buf.push(cumul_prod); + // Buf is reused for a few things to save allocations. + // Fill buf with cumulative product of x, only taking every 4th value. Concretely, buf will + // be [ + // x[0], x[1], x[2], x[3], + // x[0] * x[4], x[1] * x[5], x[2] * x[6], x[3] * x[7], + // x[0] * x[4] * x[8], x[1] * x[5] * x[9], x[2] * x[6] * x[10], x[3] * x[7] * x[11], + // ... + // ]. + // If n is not a multiple of WIDTH, the result is truncated from the end. For example, + // for n == 5, we get [x[0], x[1], x[2], x[3], x[0] * x[4]]. + let mut buf: Vec = Vec::with_capacity(n); + // cumul_prod holds the last WIDTH elements of buf. This is redundant, but it's how we + // convince LLVM to keep the values in the registers. + let mut cumul_prod: [Self; WIDTH] = x[..WIDTH].try_into().unwrap(); + buf.extend(cumul_prod); + for (i, &xi) in x[WIDTH..].iter().enumerate() { + cumul_prod[i % WIDTH] *= xi; + buf.push(cumul_prod[i % WIDTH]); } + debug_assert_eq!(buf.len(), n); - // At this stage buf contains the the cumulative product of x. We reuse the buffer for - // efficiency. At the end of the loop, it is filled with inverses of x. - let mut a_inv = cumul_prod.inverse(); - buf[n - 1] = buf[n - 2] * a_inv; - for i in (1..n - 1).rev() { - a_inv = x[i + 1] * a_inv; - // buf[i - 1] has not been written to by this loop, so it equals x[0] * ... x[n - 1]. - buf[i] = buf[i - 1] * a_inv; + let mut a_inv = { + // This is where the four dependency chains meet. + // Take the last four elements of buf and invert them all. 
+ let c01 = cumul_prod[0] * cumul_prod[1]; + let c23 = cumul_prod[2] * cumul_prod[3]; + let c0123 = c01 * c23; + let c0123inv = c0123.inverse(); + let c01inv = c0123inv * c23; + let c23inv = c0123inv * c01; + [ + c01inv * cumul_prod[1], + c01inv * cumul_prod[0], + c23inv * cumul_prod[3], + c23inv * cumul_prod[2], + ] + }; + + for i in (WIDTH..n).rev() { + // buf[i - WIDTH] has not been written to by this loop, so it equals + // x[i % WIDTH] * x[i % WIDTH + WIDTH] * ... * x[i - WIDTH]. + buf[i] = buf[i - WIDTH] * a_inv[i % WIDTH]; // buf[i] now holds the inverse of x[i]. + a_inv[i % WIDTH] *= x[i]; } - buf[0] = x[1] * a_inv; + for i in (0..WIDTH).rev() { + buf[i] = a_inv[i]; + } + + for (&bi, &xi) in buf.iter().zip(x) { + // Sanity check only. + debug_assert_eq!(bi * xi, Self::ONE); + } + buf } @@ -142,29 +206,31 @@ pub trait Field: // exp exceeds t, we repeatedly multiply by 2^-t and reduce // exp until it's in the right range. - let p = Self::CHARACTERISTIC; + if let Some(p) = Self::characteristic().to_u64() { + // NB: The only reason this is split into two cases is to save + // the multiplication (and possible calculation of + // inverse_2_pow_adicity) in the usual case that exp <= + // TWO_ADICITY. Can remove the branch and simplify if that + // saving isn't worth it. - // NB: The only reason this is split into two cases is to save - // the multiplication (and possible calculation of - // inverse_2_pow_adicity) in the usual case that exp <= - // TWO_ADICITY. Can remove the branch and simplify if that - // saving isn't worth it. 
+ if exp > Self::CHARACTERISTIC_TWO_ADICITY { + // NB: This should be a compile-time constant + let inverse_2_pow_adicity: Self = + Self::from_canonical_u64(p - ((p - 1) >> Self::CHARACTERISTIC_TWO_ADICITY)); - if exp > Self::PrimeField::TWO_ADICITY { - // NB: This should be a compile-time constant - let inverse_2_pow_adicity: Self = - Self::from_canonical_u64(p - ((p - 1) >> Self::PrimeField::TWO_ADICITY)); + let mut res = inverse_2_pow_adicity; + let mut e = exp - Self::CHARACTERISTIC_TWO_ADICITY; - let mut res = inverse_2_pow_adicity; - let mut e = exp - Self::PrimeField::TWO_ADICITY; - - while e > Self::PrimeField::TWO_ADICITY { - res *= inverse_2_pow_adicity; - e -= Self::PrimeField::TWO_ADICITY; + while e > Self::CHARACTERISTIC_TWO_ADICITY { + res *= inverse_2_pow_adicity; + e -= Self::CHARACTERISTIC_TWO_ADICITY; + } + res * Self::from_canonical_u64(p - ((p - 1) >> e)) + } else { + Self::from_canonical_u64(p - ((p - 1) >> exp)) } - res * Self::from_canonical_u64(p - ((p - 1) >> e)) } else { - Self::from_canonical_u64(p - ((p - 1) >> exp)) + Self::TWO.inverse().exp_u64(exp as u64) } } @@ -206,6 +272,11 @@ pub trait Field: subgroup.into_iter().map(|x| x * shift).collect() } + // TODO: move these to a new `PrimeField` trait (for all prime fields, not just 64-bit ones) + fn from_biguint(n: BigUint) -> Self; + + fn to_biguint(&self) -> BigUint; + fn from_canonical_u64(n: u64) -> Self; fn from_canonical_u32(n: u32) -> Self { @@ -274,7 +345,7 @@ pub trait Field: } fn kth_root_u64(&self, k: u64) -> Self { - let p = Self::order().clone(); + let p = Self::order(); let p_minus_1 = &p - 1u32; debug_assert!( Self::is_monomial_permutation_u64(k), @@ -356,6 +427,7 @@ pub trait PrimeField: Field { unsafe { self.sub_canonical_u64(1) } } + /// # Safety /// Equivalent to *self + Self::from_canonical_u64(rhs), but may be cheaper. The caller must /// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this /// precondition is not met. 
It is marked unsafe for this reason. @@ -365,6 +437,7 @@ pub trait PrimeField: Field { *self + Self::from_canonical_u64(rhs) } + /// # Safety /// Equivalent to *self - Self::from_canonical_u64(rhs), but may be cheaper. The caller must /// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this /// precondition is not met. It is marked unsafe for this reason. diff --git a/src/field/goldilocks_field.rs b/src/field/goldilocks_field.rs index 45164506..9e93d1f1 100644 --- a/src/field/goldilocks_field.rs +++ b/src/field/goldilocks_field.rs @@ -4,7 +4,7 @@ use std::hash::{Hash, Hasher}; use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -use num::BigUint; +use num::{BigUint, Integer}; use rand::Rng; use serde::{Deserialize, Serialize}; @@ -62,15 +62,13 @@ impl Debug for GoldilocksField { } impl Field for GoldilocksField { - type PrimeField = Self; - const ZERO: Self = Self(0); const ONE: Self = Self(1); const TWO: Self = Self(2); const NEG_ONE: Self = Self(Self::ORDER - 1); - const CHARACTERISTIC: u64 = Self::ORDER; const TWO_ADICITY: usize = 32; + const CHARACTERISTIC_TWO_ADICITY: usize = Self::TWO_ADICITY; // Sage: `g = GF(p).multiplicative_generator()` const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(7); @@ -82,15 +80,28 @@ impl Field for GoldilocksField { // ``` const POWER_OF_TWO_GENERATOR: Self = Self(1753635133440165772); + const BITS: usize = 64; + fn order() -> BigUint { Self::ORDER.into() } + fn characteristic() -> BigUint { + Self::order() + } #[inline(always)] fn try_inverse(&self) -> Option { try_inverse_u64(self) } + fn from_biguint(n: BigUint) -> Self { + Self(n.mod_floor(&Self::order()).to_u64_digits()[0]) + } + + fn to_biguint(&self) -> BigUint { + self.to_canonical_u64().into() + } + #[inline] fn from_canonical_u64(n: u64) -> Self { debug_assert!(n < Self::ORDER); @@ -312,6 +323,7 @@ impl RichField for GoldilocksField {} #[inline(always)] #[cfg(target_arch = 
"x86_64")] unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { + use std::arch::asm; let res_wrapped: u64; let adjustment: u64; asm!( @@ -352,6 +364,7 @@ unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { #[inline(always)] #[cfg(target_arch = "x86_64")] unsafe fn sub_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { + use std::arch::asm; let res_wrapped: u64; let adjustment: u64; asm!( diff --git a/src/field/interpolation.rs b/src/field/interpolation.rs index c4a49fe1..ad3ddf72 100644 --- a/src/field/interpolation.rs +++ b/src/field/interpolation.rs @@ -1,6 +1,6 @@ use crate::field::fft::ifft; use crate::field::field_types::Field; -use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; +use crate::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::util::log2_ceil; /// Computes the unique degree < n interpolant of an arbitrary list of n (point, value) pairs. @@ -80,7 +80,7 @@ mod tests { use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; - use crate::polynomial::polynomial::PolynomialCoeffs; + use crate::polynomial::PolynomialCoeffs; #[test] fn interpolant_random() { diff --git a/src/field/mod.rs b/src/field/mod.rs index 5ed64a54..74e0fbf4 100644 --- a/src/field/mod.rs +++ b/src/field/mod.rs @@ -7,7 +7,8 @@ pub(crate) mod interpolation; mod inversion; pub(crate) mod packable; pub(crate) mod packed_field; -pub mod secp256k1; +pub mod secp256k1_base; +pub mod secp256k1_scalar; #[cfg(target_feature = "avx2")] pub(crate) mod packed_avx2; diff --git a/src/field/packable.rs b/src/field/packable.rs index 94a9c056..a3f96197 100644 --- a/src/field/packable.rs +++ b/src/field/packable.rs @@ -1,18 +1,18 @@ use crate::field::field_types::Field; -use crate::field::packed_field::{PackedField, Singleton}; +use crate::field::packed_field::PackedField; /// Points us to the default packing for a particular field. 
There may me multiple choices of -/// PackedField for a particular Field (e.g. Singleton works for all fields), but this is the +/// PackedField for a particular Field (e.g. every Field is also a PackedField), but this is the /// recommended one. The recommended packing varies by target_arch and target_feature. pub trait Packable: Field { - type PackedType: PackedField; + type Packing: PackedField; } impl Packable for F { - default type PackedType = Singleton; + default type Packing = Self; } #[cfg(target_feature = "avx2")] impl Packable for crate::field::goldilocks_field::GoldilocksField { - type PackedType = crate::field::packed_avx2::PackedGoldilocksAVX2; + type Packing = crate::field::packed_avx2::PackedGoldilocksAvx2; } diff --git a/src/field/packed_avx2/packed_prime_field.rs b/src/field/packed_avx2/avx2_prime_field.rs similarity index 58% rename from src/field/packed_avx2/packed_prime_field.rs rename to src/field/packed_avx2/avx2_prime_field.rs index ed87f347..b42814c2 100644 --- a/src/field/packed_avx2/packed_prime_field.rs +++ b/src/field/packed_avx2/avx2_prime_field.rs @@ -2,20 +2,20 @@ use core::arch::x86_64::*; use std::fmt; use std::fmt::{Debug, Formatter}; use std::iter::{Product, Sum}; -use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; +use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use crate::field::field_types::PrimeField; use crate::field::packed_avx2::common::{ - add_no_canonicalize_64_64s_s, epsilon, field_order, ReducibleAVX2, + add_no_canonicalize_64_64s_s, epsilon, field_order, shift, ReducibleAvx2, }; use crate::field::packed_field::PackedField; -// PackedPrimeField wraps an array of four u64s, with the new and get methods to convert that +// Avx2PrimeField wraps an array of four u64s, with the new and get methods to convert that // array to and from __m256i, which is the type we actually operate on. This indirection is a -// terrible trick to change PackedPrimeField's alignment. 
-// We'd like to be able to cast slices of PrimeField to slices of PackedPrimeField. Rust +// terrible trick to change Avx2PrimeField's alignment. +// We'd like to be able to cast slices of PrimeField to slices of Avx2PrimeField. Rust // aligns __m256i to 32 bytes but PrimeField has a lower alignment. That alignment extends to -// PackedPrimeField and it appears that it cannot be lowered with #[repr(C, blah)]. It is +// Avx2PrimeField and it appears that it cannot be lowered with #[repr(C, blah)]. It is // important for Rust not to assume 32-byte alignment, so we cannot wrap __m256i directly. // There are two versions of vectorized load/store instructions on x86: aligned (vmovaps and // friends) and unaligned (vmovups etc.). The difference between them is that aligned loads and @@ -23,12 +23,12 @@ use crate::field::packed_field::PackedField; // were faster, and although this is no longer the case, compilers prefer the aligned versions if // they know that the address is aligned. Using aligned instructions on unaligned addresses leads to // bugs that can be frustrating to diagnose. Hence, we can't have Rust assuming alignment, and -// therefore PackedPrimeField wraps [F; 4] and not __m256i. +// therefore Avx2PrimeField wraps [F; 4] and not __m256i. #[derive(Copy, Clone)] #[repr(transparent)] -pub struct PackedPrimeField(pub [F; 4]); +pub struct Avx2PrimeField(pub [F; 4]); -impl PackedPrimeField { +impl Avx2PrimeField { #[inline] fn new(x: __m256i) -> Self { let mut obj = Self([F::ZERO; 4]); @@ -43,84 +43,111 @@ impl PackedPrimeField { let ptr = (&self.0).as_ptr().cast::<__m256i>(); unsafe { _mm256_loadu_si256(ptr) } } - - /// Addition that assumes x + y < 2^64 + F::ORDER. May return incorrect results if this - /// condition is not met, hence it is marked unsafe. 
- #[inline] - pub unsafe fn add_canonical_u64(&self, rhs: __m256i) -> Self { - Self::new(add_canonical_u64::(self.get(), rhs)) - } } -impl Add for PackedPrimeField { +impl Add for Avx2PrimeField { type Output = Self; #[inline] fn add(self, rhs: Self) -> Self { Self::new(unsafe { add::(self.get(), rhs.get()) }) } } -impl Add for PackedPrimeField { +impl Add for Avx2PrimeField { type Output = Self; #[inline] fn add(self, rhs: F) -> Self { - self + Self::broadcast(rhs) + self + Self::from(rhs) } } -impl AddAssign for PackedPrimeField { +impl Add> for as PackedField>::Scalar { + type Output = Avx2PrimeField; + #[inline] + fn add(self, rhs: Self::Output) -> Self::Output { + Self::Output::from(self) + rhs + } +} +impl AddAssign for Avx2PrimeField { #[inline] fn add_assign(&mut self, rhs: Self) { *self = *self + rhs; } } -impl AddAssign for PackedPrimeField { +impl AddAssign for Avx2PrimeField { #[inline] fn add_assign(&mut self, rhs: F) { *self = *self + rhs; } } -impl Debug for PackedPrimeField { +impl Debug for Avx2PrimeField { #[inline] fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "({:?})", self.get()) } } -impl Default for PackedPrimeField { +impl Default for Avx2PrimeField { #[inline] fn default() -> Self { - Self::zero() + Self::ZERO } } -impl Mul for PackedPrimeField { +impl Div for Avx2PrimeField { + type Output = Self; + #[inline] + fn div(self, rhs: F) -> Self { + self * rhs.inverse() + } +} +impl DivAssign for Avx2PrimeField { + #[inline] + fn div_assign(&mut self, rhs: F) { + *self *= rhs.inverse(); + } +} + +impl From for Avx2PrimeField { + fn from(x: F) -> Self { + Self([x; 4]) + } +} + +impl Mul for Avx2PrimeField { type Output = Self; #[inline] fn mul(self, rhs: Self) -> Self { Self::new(unsafe { mul::(self.get(), rhs.get()) }) } } -impl Mul for PackedPrimeField { +impl Mul for Avx2PrimeField { type Output = Self; #[inline] fn mul(self, rhs: F) -> Self { - self * Self::broadcast(rhs) + self * Self::from(rhs) } } -impl MulAssign for 
PackedPrimeField { +impl Mul> for as PackedField>::Scalar { + type Output = Avx2PrimeField; + #[inline] + fn mul(self, rhs: Avx2PrimeField) -> Self::Output { + Self::Output::from(self) * rhs + } +} +impl MulAssign for Avx2PrimeField { #[inline] fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; } } -impl MulAssign for PackedPrimeField { +impl MulAssign for Avx2PrimeField { #[inline] fn mul_assign(&mut self, rhs: F) { *self = *self * rhs; } } -impl Neg for PackedPrimeField { +impl Neg for Avx2PrimeField { type Output = Self; #[inline] fn neg(self) -> Self { @@ -128,52 +155,59 @@ impl Neg for PackedPrimeField { } } -impl Product for PackedPrimeField { +impl Product for Avx2PrimeField { #[inline] fn product>(iter: I) -> Self { - iter.reduce(|x, y| x * y).unwrap_or(Self::one()) + iter.reduce(|x, y| x * y).unwrap_or(Self::ONE) } } -impl PackedField for PackedPrimeField { - const LOG2_WIDTH: usize = 2; +unsafe impl PackedField for Avx2PrimeField { + const WIDTH: usize = 4; - type FieldType = F; + type Scalar = F; + type PackedPrimeField = Avx2PrimeField; + + const ZERO: Self = Self([F::ZERO; 4]); + const ONE: Self = Self([F::ONE; 4]); #[inline] - fn broadcast(x: F) -> Self { - Self([x; 4]) - } - - #[inline] - fn from_arr(arr: [F; Self::WIDTH]) -> Self { + fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self { Self(arr) } #[inline] - fn to_arr(&self) -> [F; Self::WIDTH] { + fn as_arr(&self) -> [Self::Scalar; Self::WIDTH] { self.0 } #[inline] - fn from_slice(slice: &[F]) -> Self { - assert!(slice.len() == 4); - Self([slice[0], slice[1], slice[2], slice[3]]) + fn from_slice(slice: &[Self::Scalar]) -> &Self { + assert_eq!(slice.len(), Self::WIDTH); + unsafe { &*slice.as_ptr().cast() } + } + #[inline] + fn from_slice_mut(slice: &mut [Self::Scalar]) -> &mut Self { + assert_eq!(slice.len(), Self::WIDTH); + unsafe { &mut *slice.as_mut_ptr().cast() } + } + #[inline] + fn as_slice(&self) -> &[Self::Scalar] { + &self.0[..] 
+ } + #[inline] + fn as_slice_mut(&mut self) -> &mut [Self::Scalar] { + &mut self.0[..] } #[inline] - fn to_vec(&self) -> Vec { - self.0.into() - } - - #[inline] - fn interleave(&self, other: Self, r: usize) -> (Self, Self) { + fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) { let (v0, v1) = (self.get(), other.get()); - let (res0, res1) = match r { - 0 => unsafe { interleave0(v0, v1) }, + let (res0, res1) = match block_len { 1 => unsafe { interleave1(v0, v1) }, - 2 => (v0, v1), - _ => panic!("r cannot be more than LOG2_WIDTH"), + 2 => unsafe { interleave2(v0, v1) }, + 4 => (v0, v1), + _ => panic!("unsupported block_len"), }; (Self::new(res0), Self::new(res1)) } @@ -184,47 +218,47 @@ impl PackedField for PackedPrimeField { } } -impl Sub for PackedPrimeField { +impl Sub for Avx2PrimeField { type Output = Self; #[inline] fn sub(self, rhs: Self) -> Self { Self::new(unsafe { sub::(self.get(), rhs.get()) }) } } -impl Sub for PackedPrimeField { +impl Sub for Avx2PrimeField { type Output = Self; #[inline] fn sub(self, rhs: F) -> Self { - self - Self::broadcast(rhs) + self - Self::from(rhs) } } -impl SubAssign for PackedPrimeField { +impl Sub> for as PackedField>::Scalar { + type Output = Avx2PrimeField; + #[inline] + fn sub(self, rhs: Avx2PrimeField) -> Self::Output { + Self::Output::from(self) - rhs + } +} +impl SubAssign for Avx2PrimeField { #[inline] fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs; } } -impl SubAssign for PackedPrimeField { +impl SubAssign for Avx2PrimeField { #[inline] fn sub_assign(&mut self, rhs: F) { *self = *self - rhs; } } -impl Sum for PackedPrimeField { +impl Sum for Avx2PrimeField { #[inline] fn sum>(iter: I) -> Self { - iter.reduce(|x, y| x + y).unwrap_or(Self::zero()) + iter.reduce(|x, y| x + y).unwrap_or(Self::ZERO) } } -const SIGN_BIT: u64 = 1 << 63; - -#[inline] -unsafe fn sign_bit() -> __m256i { - _mm256_set1_epi64x(SIGN_BIT as i64) -} - // Resources: // 1. 
Intel Intrinsics Guide for explanation of each intrinsic: // https://software.intel.com/sites/landingpage/IntrinsicsGuide/ @@ -274,12 +308,6 @@ unsafe fn sign_bit() -> __m256i { // Notice that the above 3-value addition still only requires two calls to shift, just like our // 2-value addition. -/// Add 2^63 with overflow. Needed to emulate unsigned comparisons (see point 3. above). -#[inline] -unsafe fn shift(x: __m256i) -> __m256i { - _mm256_xor_si256(x, sign_bit()) -} - /// Convert to canonical representation. /// The argument is assumed to be shifted by 1 << 63 (i.e. x_s = x + 1<<63, where x is the field /// value). The returned value is similarly shifted by 1 << 63 (i.e. we return y_s = y + (1<<63), @@ -293,14 +321,6 @@ unsafe fn canonicalize_s(x_s: __m256i) -> __m256i { _mm256_add_epi64(x_s, wrapback_amt) } -/// Addition that assumes x + y < 2^64 + F::ORDER. -#[inline] -unsafe fn add_canonical_u64(x: __m256i, y: __m256i) -> __m256i { - let y_s = shift(y); - let res_s = add_no_canonicalize_64_64s_s::(x, y_s); - shift(res_s) -} - #[inline] unsafe fn add(x: __m256i, y: __m256i) -> __m256i { let y_s = shift(y); @@ -326,78 +346,94 @@ unsafe fn neg(y: __m256i) -> __m256i { _mm256_sub_epi64(shift(field_order::()), canonicalize_s::(y_s)) } -/// Full 64-bit by 64-bit multiplication. This emulated multiplication is 1.5x slower than the +/// Full 64-bit by 64-bit multiplication. This emulated multiplication is 1.33x slower than the /// scalar instruction, but may be worth it if we want our data to live in vector registers. #[inline] -unsafe fn mul64_64_s(x: __m256i, y: __m256i) -> (__m256i, __m256i) { - let x_hi = _mm256_srli_epi64(x, 32); - let y_hi = _mm256_srli_epi64(y, 32); +unsafe fn mul64_64(x: __m256i, y: __m256i) -> (__m256i, __m256i) { + // We want to move the high 32 bits to the low position. The multiplication instruction ignores + // the high 32 bits, so it's ok to just duplicate it into the low position. 
This duplication can + // be done on port 5; bitshifts run on ports 0 and 1, competing with multiplication. + // This instruction is only provided for 32-bit floats, not integers. Idk why Intel makes the + // distinction; the casts are free and it guarantees that the exact bit pattern is preserved. + // Using a swizzle instruction of the wrong domain (float vs int) does not increase latency + // since Haswell. + let x_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x))); + let y_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(y))); + + // All four pairwise multiplications let mul_ll = _mm256_mul_epu32(x, y); let mul_lh = _mm256_mul_epu32(x, y_hi); let mul_hl = _mm256_mul_epu32(x_hi, y); let mul_hh = _mm256_mul_epu32(x_hi, y_hi); - let res_lo0_s = shift(mul_ll); - let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 32)); - let res_lo2_s = _mm256_add_epi32(res_lo1_s, _mm256_slli_epi64(mul_hl, 32)); + // Bignum addition + // Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow. + let mul_ll_hi = _mm256_srli_epi64::<32>(mul_ll); + let t0 = _mm256_add_epi64(mul_hl, mul_ll_hi); + // Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow. + // Also, extract high 32 bits of t0 and add to mul_hh. + let t0_lo = _mm256_and_si256(t0, _mm256_set1_epi64x(u32::MAX.into())); + let t0_hi = _mm256_srli_epi64::<32>(t0); + let t1 = _mm256_add_epi64(mul_lh, t0_lo); + let t2 = _mm256_add_epi64(mul_hh, t0_hi); + // Lastly, extract the high 32 bits of t1 and add to t2. + let t1_hi = _mm256_srli_epi64::<32>(t1); + let res_hi = _mm256_add_epi64(t2, t1_hi); - // cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on - // overflow and must be subtracted, not added. 
- let carry0 = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s); - let carry1 = _mm256_cmpgt_epi64(res_lo1_s, res_lo2_s); + // Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high + // position). + let t1_lo = _mm256_castps_si256(_mm256_moveldup_ps(_mm256_castsi256_ps(t1))); + let res_lo = _mm256_blend_epi32::<0xaa>(mul_ll, t1_lo); - let res_hi0 = mul_hh; - let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 32)); - let res_hi2 = _mm256_add_epi64(res_hi1, _mm256_srli_epi64(mul_hl, 32)); - let res_hi3 = _mm256_sub_epi64(res_hi2, carry0); - let res_hi4 = _mm256_sub_epi64(res_hi3, carry1); - - (res_hi4, res_lo2_s) + (res_hi, res_lo) } /// Full 64-bit squaring. This routine is 1.2x faster than the scalar instruction. #[inline] -unsafe fn square64_s(x: __m256i) -> (__m256i, __m256i) { - let x_hi = _mm256_srli_epi64(x, 32); +unsafe fn square64(x: __m256i) -> (__m256i, __m256i) { + // Get high 32 bits of x. See comment in mul64_64_s. + let x_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x))); + + // All pairwise multiplications. let mul_ll = _mm256_mul_epu32(x, x); let mul_lh = _mm256_mul_epu32(x, x_hi); let mul_hh = _mm256_mul_epu32(x_hi, x_hi); - let res_lo0_s = shift(mul_ll); - let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 33)); + // Bignum addition, but mul_lh is shifted by 33 bits (not 32). + let mul_ll_hi = _mm256_srli_epi64::<33>(mul_ll); + let t0 = _mm256_add_epi64(mul_lh, mul_ll_hi); + let t0_hi = _mm256_srli_epi64::<31>(t0); + let res_hi = _mm256_add_epi64(mul_hh, t0_hi); - // cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on - // overflow and must be subtracted, not added. - let carry = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s); + // Form low result by adding the mul_ll and the low 31 bits of mul_lh (shifted to the high + // position). 
+ let mul_lh_lo = _mm256_slli_epi64::<33>(mul_lh); + let res_lo = _mm256_add_epi64(mul_ll, mul_lh_lo); - let res_hi0 = mul_hh; - let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 31)); - let res_hi2 = _mm256_sub_epi64(res_hi1, carry); - - (res_hi2, res_lo1_s) + (res_hi, res_lo) } /// Multiply two integers modulo FIELD_ORDER. #[inline] -unsafe fn mul(x: __m256i, y: __m256i) -> __m256i { - shift(F::reduce128s_s(mul64_64_s(x, y))) +unsafe fn mul(x: __m256i, y: __m256i) -> __m256i { + F::reduce128(mul64_64(x, y)) } /// Square an integer modulo FIELD_ORDER. #[inline] -unsafe fn square(x: __m256i) -> __m256i { - shift(F::reduce128s_s(square64_s(x))) +unsafe fn square(x: __m256i) -> __m256i { + F::reduce128(square64(x)) } #[inline] -unsafe fn interleave0(x: __m256i, y: __m256i) -> (__m256i, __m256i) { +unsafe fn interleave1(x: __m256i, y: __m256i) -> (__m256i, __m256i) { let a = _mm256_unpacklo_epi64(x, y); let b = _mm256_unpackhi_epi64(x, y); (a, b) } #[inline] -unsafe fn interleave1(x: __m256i, y: __m256i) -> (__m256i, __m256i) { +unsafe fn interleave2(x: __m256i, y: __m256i) -> (__m256i, __m256i) { let y_lo = _mm256_castsi256_si128(y); // This has 0 cost. // 1 places y_lo in the high half of x; 0 would place it in the lower half. diff --git a/src/field/packed_avx2/common.rs b/src/field/packed_avx2/common.rs index 97674a17..48f9524d 100644 --- a/src/field/packed_avx2/common.rs +++ b/src/field/packed_avx2/common.rs @@ -2,8 +2,22 @@ use core::arch::x86_64::*; use crate::field::field_types::PrimeField; -pub trait ReducibleAVX2: PrimeField { - unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i; +pub trait ReducibleAvx2: PrimeField { + unsafe fn reduce128(x: (__m256i, __m256i)) -> __m256i; +} + +const SIGN_BIT: u64 = 1 << 63; + +#[inline] +unsafe fn sign_bit() -> __m256i { + _mm256_set1_epi64x(SIGN_BIT as i64) +} + +/// Add 2^63 with overflow. Needed to emulate unsigned comparisons (see point 3. in +/// packed_prime_field.rs). 
+#[inline] +pub unsafe fn shift(x: __m256i) -> __m256i { + _mm256_xor_si256(x, sign_bit()) } #[inline] diff --git a/src/field/packed_avx2/goldilocks.rs b/src/field/packed_avx2/goldilocks.rs index 2cea1767..954516b8 100644 --- a/src/field/packed_avx2/goldilocks.rs +++ b/src/field/packed_avx2/goldilocks.rs @@ -2,19 +2,21 @@ use core::arch::x86_64::*; use crate::field::goldilocks_field::GoldilocksField; use crate::field::packed_avx2::common::{ - add_no_canonicalize_64_64s_s, epsilon, sub_no_canonicalize_64s_64_s, ReducibleAVX2, + add_no_canonicalize_64_64s_s, epsilon, shift, sub_no_canonicalize_64s_64_s, ReducibleAvx2, }; /// Reduce a u128 modulo FIELD_ORDER. The input is (u64, u64), pre-shifted by 2^63. The result is /// similarly shifted. -impl ReducibleAVX2 for GoldilocksField { +impl ReducibleAvx2 for GoldilocksField { #[inline] - unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i { - let (hi0, lo0_s) = x_s; + unsafe fn reduce128(x: (__m256i, __m256i)) -> __m256i { + let (hi0, lo0) = x; + let lo0_s = shift(lo0); let hi_hi0 = _mm256_srli_epi64(hi0, 32); let lo1_s = sub_no_canonicalize_64s_64_s::(lo0_s, hi_hi0); let t1 = _mm256_mul_epu32(hi0, epsilon::()); let lo2_s = add_no_canonicalize_64_64s_s::(t1, lo1_s); - lo2_s + let lo2 = shift(lo2_s); + lo2 } } diff --git a/src/field/packed_avx2/mod.rs b/src/field/packed_avx2/mod.rs index eddbb5c9..5f6294a4 100644 --- a/src/field/packed_avx2/mod.rs +++ b/src/field/packed_avx2/mod.rs @@ -1,21 +1,21 @@ +mod avx2_prime_field; mod common; mod goldilocks; -mod packed_prime_field; -use packed_prime_field::PackedPrimeField; +use avx2_prime_field::Avx2PrimeField; use crate::field::goldilocks_field::GoldilocksField; -pub type PackedGoldilocksAVX2 = PackedPrimeField; +pub type PackedGoldilocksAvx2 = Avx2PrimeField; #[cfg(test)] mod tests { use crate::field::goldilocks_field::GoldilocksField; - use crate::field::packed_avx2::common::ReducibleAVX2; - use crate::field::packed_avx2::packed_prime_field::PackedPrimeField; + use 
crate::field::packed_avx2::avx2_prime_field::Avx2PrimeField; + use crate::field::packed_avx2::common::ReducibleAvx2; use crate::field::packed_field::PackedField; - fn test_vals_a() -> [F; 4] { + fn test_vals_a() -> [F; 4] { [ F::from_noncanonical_u64(14479013849828404771), F::from_noncanonical_u64(9087029921428221768), @@ -23,7 +23,7 @@ mod tests { F::from_noncanonical_u64(5646033492608483824), ] } - fn test_vals_b() -> [F; 4] { + fn test_vals_b() -> [F; 4] { [ F::from_noncanonical_u64(17891926589593242302), F::from_noncanonical_u64(11009798273260028228), @@ -32,17 +32,17 @@ mod tests { ] } - fn test_add() + fn test_add() where - [(); PackedPrimeField::::WIDTH]: , + [(); Avx2PrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); - let packed_a = PackedPrimeField::::from_arr(a_arr); - let packed_b = PackedPrimeField::::from_arr(b_arr); + let packed_a = Avx2PrimeField::::from_arr(a_arr); + let packed_b = Avx2PrimeField::::from_arr(b_arr); let packed_res = packed_a + packed_b; - let arr_res = packed_res.to_arr(); + let arr_res = packed_res.as_arr(); let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a + b); for (exp, res) in expected.zip(arr_res) { @@ -50,17 +50,17 @@ mod tests { } } - fn test_mul() + fn test_mul() where - [(); PackedPrimeField::::WIDTH]: , + [(); Avx2PrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); - let packed_a = PackedPrimeField::::from_arr(a_arr); - let packed_b = PackedPrimeField::::from_arr(b_arr); + let packed_a = Avx2PrimeField::::from_arr(a_arr); + let packed_b = Avx2PrimeField::::from_arr(b_arr); let packed_res = packed_a * packed_b; - let arr_res = packed_res.to_arr(); + let arr_res = packed_res.as_arr(); let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a * b); for (exp, res) in expected.zip(arr_res) { @@ -68,15 +68,15 @@ mod tests { } } - fn test_square() + fn test_square() where - [(); PackedPrimeField::::WIDTH]: , + [(); Avx2PrimeField::::WIDTH]:, { let a_arr = 
test_vals_a::(); - let packed_a = PackedPrimeField::::from_arr(a_arr); + let packed_a = Avx2PrimeField::::from_arr(a_arr); let packed_res = packed_a.square(); - let arr_res = packed_res.to_arr(); + let arr_res = packed_res.as_arr(); let expected = a_arr.iter().map(|&a| a.square()); for (exp, res) in expected.zip(arr_res) { @@ -84,15 +84,15 @@ mod tests { } } - fn test_neg() + fn test_neg() where - [(); PackedPrimeField::::WIDTH]: , + [(); Avx2PrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); - let packed_a = PackedPrimeField::::from_arr(a_arr); + let packed_a = Avx2PrimeField::::from_arr(a_arr); let packed_res = -packed_a; - let arr_res = packed_res.to_arr(); + let arr_res = packed_res.as_arr(); let expected = a_arr.iter().map(|&a| -a); for (exp, res) in expected.zip(arr_res) { @@ -100,17 +100,17 @@ mod tests { } } - fn test_sub() + fn test_sub() where - [(); PackedPrimeField::::WIDTH]: , + [(); Avx2PrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); - let packed_a = PackedPrimeField::::from_arr(a_arr); - let packed_b = PackedPrimeField::::from_arr(b_arr); + let packed_a = Avx2PrimeField::::from_arr(a_arr); + let packed_b = Avx2PrimeField::::from_arr(b_arr); let packed_res = packed_a - packed_b; - let arr_res = packed_res.to_arr(); + let arr_res = packed_res.as_arr(); let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a - b); for (exp, res) in expected.zip(arr_res) { @@ -118,33 +118,39 @@ mod tests { } } - fn test_interleave_is_involution() + fn test_interleave_is_involution() where - [(); PackedPrimeField::::WIDTH]: , + [(); Avx2PrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); - let packed_a = PackedPrimeField::::from_arr(a_arr); - let packed_b = PackedPrimeField::::from_arr(b_arr); + let packed_a = Avx2PrimeField::::from_arr(a_arr); + let packed_b = Avx2PrimeField::::from_arr(b_arr); { // Interleave, then deinterleave. 
- let (x, y) = packed_a.interleave(packed_b, 0); - let (res_a, res_b) = x.interleave(y, 0); - assert_eq!(res_a.to_arr(), a_arr); - assert_eq!(res_b.to_arr(), b_arr); - } - { let (x, y) = packed_a.interleave(packed_b, 1); let (res_a, res_b) = x.interleave(y, 1); - assert_eq!(res_a.to_arr(), a_arr); - assert_eq!(res_b.to_arr(), b_arr); + assert_eq!(res_a.as_arr(), a_arr); + assert_eq!(res_b.as_arr(), b_arr); + } + { + let (x, y) = packed_a.interleave(packed_b, 2); + let (res_a, res_b) = x.interleave(y, 2); + assert_eq!(res_a.as_arr(), a_arr); + assert_eq!(res_b.as_arr(), b_arr); + } + { + let (x, y) = packed_a.interleave(packed_b, 4); + let (res_a, res_b) = x.interleave(y, 4); + assert_eq!(res_a.as_arr(), a_arr); + assert_eq!(res_b.as_arr(), b_arr); } } - fn test_interleave() + fn test_interleave() where - [(); PackedPrimeField::::WIDTH]: , + [(); Avx2PrimeField::::WIDTH]:, { let in_a: [F; 4] = [ F::from_noncanonical_u64(00), @@ -158,42 +164,47 @@ mod tests { F::from_noncanonical_u64(12), F::from_noncanonical_u64(13), ]; - let int0_a: [F; 4] = [ + let int1_a: [F; 4] = [ F::from_noncanonical_u64(00), F::from_noncanonical_u64(10), F::from_noncanonical_u64(02), F::from_noncanonical_u64(12), ]; - let int0_b: [F; 4] = [ + let int1_b: [F; 4] = [ F::from_noncanonical_u64(01), F::from_noncanonical_u64(11), F::from_noncanonical_u64(03), F::from_noncanonical_u64(13), ]; - let int1_a: [F; 4] = [ + let int2_a: [F; 4] = [ F::from_noncanonical_u64(00), F::from_noncanonical_u64(01), F::from_noncanonical_u64(10), F::from_noncanonical_u64(11), ]; - let int1_b: [F; 4] = [ + let int2_b: [F; 4] = [ F::from_noncanonical_u64(02), F::from_noncanonical_u64(03), F::from_noncanonical_u64(12), F::from_noncanonical_u64(13), ]; - let packed_a = PackedPrimeField::::from_arr(in_a); - let packed_b = PackedPrimeField::::from_arr(in_b); - { - let (x0, y0) = packed_a.interleave(packed_b, 0); - assert_eq!(x0.to_arr(), int0_a); - assert_eq!(y0.to_arr(), int0_b); - } + let packed_a = 
Avx2PrimeField::::from_arr(in_a); + let packed_b = Avx2PrimeField::::from_arr(in_b); { let (x1, y1) = packed_a.interleave(packed_b, 1); - assert_eq!(x1.to_arr(), int1_a); - assert_eq!(y1.to_arr(), int1_b); + assert_eq!(x1.as_arr(), int1_a); + assert_eq!(y1.as_arr(), int1_b); + } + { + let (x2, y2) = packed_a.interleave(packed_b, 2); + assert_eq!(x2.as_arr(), int2_a); + assert_eq!(y2.as_arr(), int2_b); + } + { + let (x4, y4) = packed_a.interleave(packed_b, 4); + assert_eq!(x4.as_arr(), in_a); + assert_eq!(y4.as_arr(), in_b); } } diff --git a/src/field/packed_field.rs b/src/field/packed_field.rs index a4b1945a..00b99d6c 100644 --- a/src/field/packed_field.rs +++ b/src/field/packed_field.rs @@ -1,77 +1,81 @@ -use std::fmt; -use std::fmt::{Debug, Formatter}; +use std::fmt::Debug; use std::iter::{Product, Sum}; -use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; +use std::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}; +use std::slice; use crate::field::field_types::Field; -pub trait PackedField: +/// # Safety +/// - WIDTH is assumed to be a power of 2. +/// - If P implements PackedField then P must be castable to/from [P::Scalar; P::WIDTH] without UB. +pub unsafe trait PackedField: 'static + Add - + Add + + Add + AddAssign - + AddAssign + + AddAssign + Copy + Debug + Default - // TODO: Implementing Div sounds like a pain so it's a worry for later. 
+ + From + // TODO: Implement packed / packed division + + Div + Mul - + Mul + + Mul + MulAssign - + MulAssign + + MulAssign + Neg + Product + Send + Sub - + Sub + + Sub + SubAssign - + SubAssign + + SubAssign + Sum + Sync +where + Self::Scalar: Add, + Self::Scalar: Mul, + Self::Scalar: Sub, { - type FieldType: Field; + type Scalar: Field; - const LOG2_WIDTH: usize; - const WIDTH: usize = 1 << Self::LOG2_WIDTH; + const WIDTH: usize; + const ZERO: Self; + const ONE: Self; fn square(&self) -> Self { *self * *self } - fn zero() -> Self { - Self::broadcast(Self::FieldType::ZERO) - } - fn one() -> Self { - Self::broadcast(Self::FieldType::ONE) - } + fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self; + fn as_arr(&self) -> [Self::Scalar; Self::WIDTH]; - fn broadcast(x: Self::FieldType) -> Self; + fn from_slice(slice: &[Self::Scalar]) -> &Self; + fn from_slice_mut(slice: &mut [Self::Scalar]) -> &mut Self; + fn as_slice(&self) -> &[Self::Scalar]; + fn as_slice_mut(&mut self) -> &mut [Self::Scalar]; - fn from_arr(arr: [Self::FieldType; Self::WIDTH]) -> Self; - fn to_arr(&self) -> [Self::FieldType; Self::WIDTH]; - - fn from_slice(slice: &[Self::FieldType]) -> Self; - fn to_vec(&self) -> Vec; - - /// Take interpret two vectors as chunks of (1 << r) elements. Unpack and interleave those + /// Take interpret two vectors as chunks of block_len elements. Unpack and interleave those /// chunks. This is best seen with an example. If we have: /// A = [x0, y0, x1, y1], /// B = [x2, y2, x3, y3], /// then - /// interleave(A, B, 0) = ([x0, x2, x1, x3], [y0, y2, y1, y3]). + /// interleave(A, B, 1) = ([x0, x2, x1, x3], [y0, y2, y1, y3]). /// Pairs that were adjacent in the input are at corresponding positions in the output. - /// r lets us set the size of chunks we're interleaving. If we set r = 1, then for + /// r lets us set the size of chunks we're interleaving. 
If we set block_len = 2, then for /// A = [x0, x1, y0, y1], /// B = [x2, x3, y2, y3], /// we obtain - /// interleave(A, B, r) = ([x0, x1, x2, x3], [y0, y1, y2, y3]). + /// interleave(A, B, block_len) = ([x0, x1, x2, x3], [y0, y1, y2, y3]). /// We can also think about this as stacking the vectors, dividing them into 2x2 matrices, and /// transposing those matrices. - /// When r = LOG2_WIDTH, this operation is a no-op. Values of r > LOG2_WIDTH are not - /// permitted. - fn interleave(&self, other: Self, r: usize) -> (Self, Self); + /// When block_len = WIDTH, this operation is a no-op. block_len must divide WIDTH. Since + /// WIDTH is specified to be a power of 2, block_len must also be a power of 2. It cannot be 0 + /// and it cannot be > WIDTH. + fn interleave(&self, other: Self, block_len: usize) -> (Self, Self); - fn pack_slice(buf: &[Self::FieldType]) -> &[Self] { + fn pack_slice(buf: &[Self::Scalar]) -> &[Self] { assert!( buf.len() % Self::WIDTH == 0, "Slice length (got {}) must be a multiple of packed field width ({}).", @@ -82,7 +86,7 @@ pub trait PackedField: let n = buf.len() / Self::WIDTH; unsafe { std::slice::from_raw_parts(buf_ptr, n) } } - fn pack_slice_mut(buf: &mut [Self::FieldType]) -> &mut [Self] { + fn pack_slice_mut(buf: &mut [Self::Scalar]) -> &mut [Self] { assert!( buf.len() % Self::WIDTH == 0, "Slice length (got {}) must be a multiple of packed field width ({}).", @@ -95,143 +99,41 @@ pub trait PackedField: } } -#[derive(Copy, Clone)] -#[repr(transparent)] -pub struct Singleton(pub F); +unsafe impl PackedField for F { + type Scalar = Self; -impl Add for Singleton { - type Output = Self; - fn add(self, rhs: Self) -> Self { - Self(self.0 + rhs.0) - } -} -impl Add for Singleton { - type Output = Self; - fn add(self, rhs: F) -> Self { - self + Self::broadcast(rhs) - } -} -impl AddAssign for Singleton { - fn add_assign(&mut self, rhs: Self) { - *self = *self + rhs; - } -} -impl AddAssign for Singleton { - fn add_assign(&mut self, rhs: F) { - *self = 
*self + rhs; - } -} - -impl Debug for Singleton { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "({:?})", self.0) - } -} - -impl Default for Singleton { - fn default() -> Self { - Self::zero() - } -} - -impl Mul for Singleton { - type Output = Self; - fn mul(self, rhs: Self) -> Self { - Self(self.0 * rhs.0) - } -} -impl Mul for Singleton { - type Output = Self; - fn mul(self, rhs: F) -> Self { - self * Self::broadcast(rhs) - } -} -impl MulAssign for Singleton { - fn mul_assign(&mut self, rhs: Self) { - *self = *self * rhs; - } -} -impl MulAssign for Singleton { - fn mul_assign(&mut self, rhs: F) { - *self = *self * rhs; - } -} - -impl Neg for Singleton { - type Output = Self; - fn neg(self) -> Self { - Self(-self.0) - } -} - -impl Product for Singleton { - fn product>(iter: I) -> Self { - Self(iter.map(|x| x.0).product()) - } -} - -impl PackedField for Singleton { - const LOG2_WIDTH: usize = 0; - type FieldType = F; - - fn broadcast(x: F) -> Self { - Self(x) - } - - fn from_arr(arr: [Self::FieldType; Self::WIDTH]) -> Self { - Self(arr[0]) - } - - fn to_arr(&self) -> [Self::FieldType; Self::WIDTH] { - [self.0] - } - - fn from_slice(slice: &[Self::FieldType]) -> Self { - assert!(slice.len() == 1); - Self(slice[0]) - } - - fn to_vec(&self) -> Vec { - vec![self.0] - } - - fn interleave(&self, other: Self, r: usize) -> (Self, Self) { - match r { - 0 => (*self, other), // This is a no-op whenever r == LOG2_WIDTH. 
- _ => panic!("r cannot be more than LOG2_WIDTH"), - } - } + const WIDTH: usize = 1; + const ZERO: Self = ::ZERO; + const ONE: Self = ::ONE; fn square(&self) -> Self { - Self(self.0.square()) + ::square(self) } -} -impl Sub for Singleton { - type Output = Self; - fn sub(self, rhs: Self) -> Self { - Self(self.0 - rhs.0) + fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self { + arr[0] } -} -impl Sub for Singleton { - type Output = Self; - fn sub(self, rhs: F) -> Self { - self - Self::broadcast(rhs) + fn as_arr(&self) -> [Self::Scalar; Self::WIDTH] { + [*self] } -} -impl SubAssign for Singleton { - fn sub_assign(&mut self, rhs: Self) { - *self = *self - rhs; - } -} -impl SubAssign for Singleton { - fn sub_assign(&mut self, rhs: F) { - *self = *self - rhs; - } -} -impl Sum for Singleton { - fn sum>(iter: I) -> Self { - Self(iter.map(|x| x.0).sum()) + fn from_slice(slice: &[Self::Scalar]) -> &Self { + &slice[0] + } + fn from_slice_mut(slice: &mut [Self::Scalar]) -> &mut Self { + &mut slice[0] + } + fn as_slice(&self) -> &[Self::Scalar] { + slice::from_ref(self) + } + fn as_slice_mut(&mut self) -> &mut [Self::Scalar] { + slice::from_mut(self) + } + + fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) { + match block_len { + 1 => (*self, other), + _ => panic!("unsupported block length"), + } } } diff --git a/src/field/prime_field_testing.rs b/src/field/prime_field_testing.rs index 4febc3a8..1b7b97eb 100644 --- a/src/field/prime_field_testing.rs +++ b/src/field/prime_field_testing.rs @@ -24,7 +24,7 @@ where ExpectedOp: Fn(u64) -> u64, { let inputs = test_inputs(F::ORDER); - let expected: Vec<_> = inputs.iter().map(|x| expected_op(x.clone())).collect(); + let expected: Vec<_> = inputs.iter().map(|&x| expected_op(x)).collect(); let output: Vec<_> = inputs .iter() .cloned() @@ -144,7 +144,7 @@ macro_rules! 
test_prime_field_arithmetic { fn inverse_2exp() { type F = $field; - let v = ::PrimeField::TWO_ADICITY; + let v = ::TWO_ADICITY; for e in [0, 1, 2, 3, 4, v - 2, v - 1, v, v + 1, v + 2, 123 * v] { let x = F::TWO.exp_u64(e as u64); diff --git a/src/field/secp256k1.rs b/src/field/secp256k1_base.rs similarity index 86% rename from src/field/secp256k1.rs rename to src/field/secp256k1_base.rs index 56d506d6..0d79000f 100644 --- a/src/field/secp256k1.rs +++ b/src/field/secp256k1_base.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::fmt; use std::fmt::{Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; @@ -12,7 +11,6 @@ use rand::Rng; use serde::{Deserialize, Serialize}; use crate::field::field_types::Field; -use crate::field::goldilocks_field::GoldilocksField; /// The base field of the secp256k1 elliptic curve. /// @@ -36,8 +34,80 @@ fn biguint_from_array(arr: [u64; 4]) -> BigUint { ]) } -impl Secp256K1Base { - fn to_canonical_biguint(&self) -> BigUint { +impl Default for Secp256K1Base { + fn default() -> Self { + Self::ZERO + } +} + +impl PartialEq for Secp256K1Base { + fn eq(&self, other: &Self) -> bool { + self.to_biguint() == other.to_biguint() + } +} + +impl Eq for Secp256K1Base {} + +impl Hash for Secp256K1Base { + fn hash(&self, state: &mut H) { + self.to_biguint().hash(state) + } +} + +impl Display for Secp256K1Base { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(&self.to_biguint(), f) + } +} + +impl Debug for Secp256K1Base { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Debug::fmt(&self.to_biguint(), f) + } +} + +impl Field for Secp256K1Base { + const ZERO: Self = Self([0; 4]); + const ONE: Self = Self([1, 0, 0, 0]); + const TWO: Self = Self([2, 0, 0, 0]); + const NEG_ONE: Self = Self([ + 0xFFFFFFFEFFFFFC2E, + 0xFFFFFFFFFFFFFFFF, + 0xFFFFFFFFFFFFFFFF, + 0xFFFFFFFFFFFFFFFF, + ]); + + const TWO_ADICITY: usize = 1; + const CHARACTERISTIC_TWO_ADICITY: usize = Self::TWO_ADICITY; + + // Sage: `g = 
GF(p).multiplicative_generator()` + const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([5, 0, 0, 0]); + + // Sage: `g_2 = g^((p - 1) / 2)` + const POWER_OF_TWO_GENERATOR: Self = Self::NEG_ONE; + + const BITS: usize = 256; + + fn order() -> BigUint { + BigUint::from_slice(&[ + 0xFFFFFC2F, 0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0xFFFFFFFF, + ]) + } + fn characteristic() -> BigUint { + Self::order() + } + + fn try_inverse(&self) -> Option { + if self.is_zero() { + return None; + } + + // Fermat's Little Theorem + Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one()))) + } + + fn to_biguint(&self) -> BigUint { let mut result = biguint_from_array(self.0); if result >= Self::order() { result -= Self::order(); @@ -55,79 +125,6 @@ impl Secp256K1Base { .expect("error converting to u64 array"), ) } -} - -impl Default for Secp256K1Base { - fn default() -> Self { - Self::ZERO - } -} - -impl PartialEq for Secp256K1Base { - fn eq(&self, other: &Self) -> bool { - self.to_canonical_biguint() == other.to_canonical_biguint() - } -} - -impl Eq for Secp256K1Base {} - -impl Hash for Secp256K1Base { - fn hash(&self, state: &mut H) { - self.to_canonical_biguint().hash(state) - } -} - -impl Display for Secp256K1Base { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Display::fmt(&self.to_canonical_biguint(), f) - } -} - -impl Debug for Secp256K1Base { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Debug::fmt(&self.to_canonical_biguint(), f) - } -} - -impl Field for Secp256K1Base { - // TODO: fix - type PrimeField = GoldilocksField; - - const ZERO: Self = Self([0; 4]); - const ONE: Self = Self([1, 0, 0, 0]); - const TWO: Self = Self([2, 0, 0, 0]); - const NEG_ONE: Self = Self([ - 0xFFFFFFFEFFFFFC2E, - 0xFFFFFFFFFFFFFFFF, - 0xFFFFFFFFFFFFFFFF, - 0xFFFFFFFFFFFFFFFF, - ]); - - // TODO: fix - const CHARACTERISTIC: u64 = 0; - const TWO_ADICITY: usize = 1; - - // Sage: `g = GF(p).multiplicative_generator()` - const 
MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([5, 0, 0, 0]); - - // Sage: `g_2 = g^((p - 1) / 2)` - const POWER_OF_TWO_GENERATOR: Self = Self::NEG_ONE; - - fn order() -> BigUint { - BigUint::from_slice(&[ - 0xFFFFFC2F, 0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, - ]) - } - - fn try_inverse(&self) -> Option { - if self.is_zero() { - return None; - } - - // Fermat's Little Theorem - Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one()))) - } #[inline] fn from_canonical_u64(n: u64) -> Self { @@ -157,7 +154,7 @@ impl Neg for Secp256K1Base { if self.is_zero() { Self::ZERO } else { - Self::from_biguint(Self::order() - self.to_canonical_biguint()) + Self::from_biguint(Self::order() - self.to_biguint()) } } } @@ -167,7 +164,7 @@ impl Add for Secp256K1Base { #[inline] fn add(self, rhs: Self) -> Self { - let mut result = self.to_canonical_biguint() + rhs.to_canonical_biguint(); + let mut result = self.to_biguint() + rhs.to_biguint(); if result >= Self::order() { result -= Self::order(); } @@ -210,9 +207,7 @@ impl Mul for Secp256K1Base { #[inline] fn mul(self, rhs: Self) -> Self { - Self::from_biguint( - (self.to_canonical_biguint() * rhs.to_canonical_biguint()).mod_floor(&Self::order()), - ) + Self::from_biguint((self.to_biguint() * rhs.to_biguint()).mod_floor(&Self::order())) } } @@ -244,3 +239,10 @@ impl DivAssign for Secp256K1Base { *self = *self / rhs; } } + +#[cfg(test)] +mod tests { + use crate::test_field_arithmetic; + + test_field_arithmetic!(crate::field::secp256k1_base::Secp256K1Base); +} diff --git a/src/field/secp256k1_scalar.rs b/src/field/secp256k1_scalar.rs new file mode 100644 index 00000000..a5b7a315 --- /dev/null +++ b/src/field/secp256k1_scalar.rs @@ -0,0 +1,257 @@ +use std::convert::TryInto; +use std::fmt; +use std::fmt::{Debug, Display, Formatter}; +use std::hash::{Hash, Hasher}; +use std::iter::{Product, Sum}; +use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; 
+ +use itertools::Itertools; +use num::bigint::{BigUint, RandBigInt}; +use num::{Integer, One}; +use rand::Rng; +use serde::{Deserialize, Serialize}; + +use crate::field::field_types::Field; + +/// The scalar field of the secp256k1 elliptic curve. +/// +/// Its order is +/// ```ignore +/// P = 0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE BAAEDCE6 AF48A03B BFD25E8C D0364141 +/// = 115792089237316195423570985008687907852837564279074904382605163141518161494337 +/// = 2**256 - 432420386565659656852420866394968145599 +/// ``` +#[derive(Copy, Clone, Serialize, Deserialize)] +pub struct Secp256K1Scalar(pub [u64; 4]); + +fn biguint_from_array(arr: [u64; 4]) -> BigUint { + BigUint::from_slice(&[ + arr[0] as u32, + (arr[0] >> 32) as u32, + arr[1] as u32, + (arr[1] >> 32) as u32, + arr[2] as u32, + (arr[2] >> 32) as u32, + arr[3] as u32, + (arr[3] >> 32) as u32, + ]) +} + +impl Default for Secp256K1Scalar { + fn default() -> Self { + Self::ZERO + } +} + +impl PartialEq for Secp256K1Scalar { + fn eq(&self, other: &Self) -> bool { + self.to_biguint() == other.to_biguint() + } +} + +impl Eq for Secp256K1Scalar {} + +impl Hash for Secp256K1Scalar { + fn hash(&self, state: &mut H) { + self.to_biguint().hash(state) + } +} + +impl Display for Secp256K1Scalar { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(&self.to_biguint(), f) + } +} + +impl Debug for Secp256K1Scalar { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Debug::fmt(&self.to_biguint(), f) + } +} + +impl Field for Secp256K1Scalar { + const ZERO: Self = Self([0; 4]); + const ONE: Self = Self([1, 0, 0, 0]); + const TWO: Self = Self([2, 0, 0, 0]); + const NEG_ONE: Self = Self([ + 0xBFD25E8CD0364140, + 0xBAAEDCE6AF48A03B, + 0xFFFFFFFFFFFFFFFE, + 0xFFFFFFFFFFFFFFFF, + ]); + + const TWO_ADICITY: usize = 6; + const CHARACTERISTIC_TWO_ADICITY: usize = Self::TWO_ADICITY; + + // Sage: `g = GF(p).multiplicative_generator()` + const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([7, 0, 0, 0]); + + // Sage: `g_2 
= power_mod(g, (p - 1) // 2^6), p)` + // 5480320495727936603795231718619559942670027629901634955707709633242980176626 + const POWER_OF_TWO_GENERATOR: Self = Self([ + 0x992f4b5402b052f2, + 0x98BDEAB680756045, + 0xDF9879A3FBC483A8, + 0xC1DC060E7A91986, + ]); + + const BITS: usize = 256; + + fn order() -> BigUint { + BigUint::from_slice(&[ + 0xD0364141, 0xBFD25E8C, 0xAF48A03B, 0xBAAEDCE6, 0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFFFF, + 0xFFFFFFFF, + ]) + } + fn characteristic() -> BigUint { + Self::order() + } + + fn try_inverse(&self) -> Option { + if self.is_zero() { + return None; + } + + // Fermat's Little Theorem + Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one()))) + } + + fn to_biguint(&self) -> BigUint { + let mut result = biguint_from_array(self.0); + if result >= Self::order() { + result -= Self::order(); + } + result + } + + fn from_biguint(val: BigUint) -> Self { + Self( + val.to_u64_digits() + .into_iter() + .pad_using(4, |_| 0) + .collect::>()[..] + .try_into() + .expect("error converting to u64 array"), + ) + } + + #[inline] + fn from_canonical_u64(n: u64) -> Self { + Self([n, 0, 0, 0]) + } + + #[inline] + fn from_noncanonical_u128(n: u128) -> Self { + Self([n as u64, (n >> 64) as u64, 0, 0]) + } + + #[inline] + fn from_noncanonical_u96(n: (u64, u32)) -> Self { + Self([n.0, n.1 as u64, 0, 0]) + } + + fn rand_from_rng(rng: &mut R) -> Self { + Self::from_biguint(rng.gen_biguint_below(&Self::order())) + } +} + +impl Neg for Secp256K1Scalar { + type Output = Self; + + #[inline] + fn neg(self) -> Self { + if self.is_zero() { + Self::ZERO + } else { + Self::from_biguint(Self::order() - self.to_biguint()) + } + } +} + +impl Add for Secp256K1Scalar { + type Output = Self; + + #[inline] + fn add(self, rhs: Self) -> Self { + let mut result = self.to_biguint() + rhs.to_biguint(); + if result >= Self::order() { + result -= Self::order(); + } + Self::from_biguint(result) + } +} + +impl AddAssign for Secp256K1Scalar { + #[inline] + fn add_assign(&mut 
self, rhs: Self) { + *self = *self + rhs; + } +} + +impl Sum for Secp256K1Scalar { + fn sum>(iter: I) -> Self { + iter.fold(Self::ZERO, |acc, x| acc + x) + } +} + +impl Sub for Secp256K1Scalar { + type Output = Self; + + #[inline] + #[allow(clippy::suspicious_arithmetic_impl)] + fn sub(self, rhs: Self) -> Self { + self + -rhs + } +} + +impl SubAssign for Secp256K1Scalar { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl Mul for Secp256K1Scalar { + type Output = Self; + + #[inline] + fn mul(self, rhs: Self) -> Self { + Self::from_biguint((self.to_biguint() * rhs.to_biguint()).mod_floor(&Self::order())) + } +} + +impl MulAssign for Secp256K1Scalar { + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl Product for Secp256K1Scalar { + #[inline] + fn product>(iter: I) -> Self { + iter.reduce(|acc, x| acc * x).unwrap_or(Self::ONE) + } +} + +impl Div for Secp256K1Scalar { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, rhs: Self) -> Self::Output { + self * rhs.inverse() + } +} + +impl DivAssign for Secp256K1Scalar { + fn div_assign(&mut self, rhs: Self) { + *self = *self / rhs; + } +} + +#[cfg(test)] +mod tests { + use crate::test_field_arithmetic; + + test_field_arithmetic!(crate::field::secp256k1_scalar::Secp256K1Scalar); +} diff --git a/src/fri/commitment.rs b/src/fri/commitment.rs index 195d082b..54a2ebbc 100644 --- a/src/fri/commitment.rs +++ b/src/fri/commitment.rs @@ -11,14 +11,14 @@ use crate::plonk::circuit_data::CommonCircuitData; use crate::plonk::config::GenericConfig; use crate::plonk::plonk_common::PlonkPolynomials; use crate::plonk::proof::OpeningSet; -use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; +use crate::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::timed; use crate::util::reducing::ReducingFactor; use crate::util::timing::TimingTree; use crate::util::{log2_strict, reverse_bits, 
reverse_index_bits_in_place, transpose}; -/// Two (~64 bit) field elements gives ~128 bit security. -pub const SALT_SIZE: usize = 2; +/// Four (~64 bit) field elements gives ~128 bit security. +pub const SALT_SIZE: usize = 4; /// Represents a batch FRI based commitment to a list of polynomials. pub struct PolynomialBatchCommitment, C: GenericConfig, const D: usize> { diff --git a/src/fri/mod.rs b/src/fri/mod.rs index 2419b06b..bfb2ebfd 100644 --- a/src/fri/mod.rs +++ b/src/fri/mod.rs @@ -36,8 +36,4 @@ impl FriParams { pub(crate) fn max_arity_bits(&self) -> Option { self.reduction_arity_bits.iter().copied().max() } - - pub(crate) fn max_arity(&self) -> Option { - self.max_arity_bits().map(|bits| 1 << bits) - } } diff --git a/src/fri/proof.rs b/src/fri/proof.rs index c3e69971..ccd80a37 100644 --- a/src/fri/proof.rs +++ b/src/fri/proof.rs @@ -16,7 +16,7 @@ use crate::plonk::circuit_data::CommonCircuitData; use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::plonk_common::PolynomialsIndexBlinding; use crate::plonk::proof::{FriInferredElements, ProofChallenges}; -use crate::polynomial::polynomial::PolynomialCoeffs; +use crate::polynomial::PolynomialCoeffs; /// Evaluations and Merkle proof produced by the prover in a FRI query step. 
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] diff --git a/src/fri/prover.rs b/src/fri/prover.rs index ac55be41..e902986b 100644 --- a/src/fri/prover.rs +++ b/src/fri/prover.rs @@ -9,7 +9,7 @@ use crate::iop::challenger::Challenger; use crate::plonk::circuit_data::CommonCircuitData; use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::plonk_common::reduce_with_powers; -use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; +use crate::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::timed; use crate::util::reverse_index_bits_in_place; use crate::util::timing::TimingTree; diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index 18c696da..86e30602 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -3,13 +3,16 @@ use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::fri::proof::{FriInitialTreeProofTarget, FriProofTarget, FriQueryRoundTarget}; use crate::fri::FriConfig; +use crate::gadgets::interpolation::InterpolationGate; use crate::gates::gate::Gate; -use crate::gates::interpolation::InterpolationGate; +use crate::gates::interpolation::HighDegreeInterpolationGate; +use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; use crate::gates::random_access::RandomAccessGate; use crate::hash::hash_types::MerkleCapTarget; use crate::iop::challenger::RecursiveChallenger; use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData}; use crate::plonk::circuit_data::CommonCircuitData; use crate::plonk::config::{AlgebraicConfig, AlgebraicHasher, GenericConfig}; use crate::plonk::plonk_common::PlonkPolynomials; @@ -28,6 +31,7 @@ impl, const D: usize> CircuitBuilder { arity_bits: usize, evals: &[ExtensionTarget], beta: ExtensionTarget, + common_data: &CommonCircuitData, ) -> ExtensionTarget { 
let arity = 1 << arity_bits; debug_assert_eq!(evals.len(), arity); @@ -44,37 +48,62 @@ impl, const D: usize> CircuitBuilder { let coset_start = self.mul(start, x); // The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta. - self.interpolate_coset(arity_bits, coset_start, &evals, beta) + // `HighDegreeInterpolationGate` has degree `arity`, so we use the low-degree gate if + // the arity is too large. + if arity > common_data.quotient_degree_factor { + self.interpolate_coset::>( + arity_bits, + coset_start, + &evals, + beta, + ) + } else { + self.interpolate_coset::>( + arity_bits, + coset_start, + &evals, + beta, + ) + } } /// Make sure we have enough wires and routed wires to do the FRI checks efficiently. This check /// isn't required -- without it we'd get errors elsewhere in the stack -- but just gives more /// helpful errors. - fn check_recursion_config(&self, max_fri_arity: usize) { + fn check_recursion_config( + &self, + max_fri_arity_bits: usize, + common_data: &CommonCircuitData, + ) { let random_access = RandomAccessGate::::new_from_config( &self.config, - max_fri_arity.max(1 << self.config.cap_height), + max_fri_arity_bits.max(self.config.cap_height), ); - let interpolation_gate = InterpolationGate::::new(log2_strict(max_fri_arity)); + let (interpolation_wires, interpolation_routed_wires) = + if 1 << max_fri_arity_bits > common_data.quotient_degree_factor { + let gate = LowDegreeInterpolationGate::::new(max_fri_arity_bits); + (gate.num_wires(), gate.num_routed_wires()) + } else { + let gate = HighDegreeInterpolationGate::::new(max_fri_arity_bits); + (gate.num_wires(), gate.num_routed_wires()) + }; - let min_wires = random_access - .num_wires() - .max(interpolation_gate.num_wires()); + let min_wires = random_access.num_wires().max(interpolation_wires); let min_routed_wires = random_access .num_routed_wires() - .max(interpolation_gate.num_routed_wires()); + .max(interpolation_routed_wires); assert!( self.config.num_wires >= 
min_wires, - "To efficiently perform FRI checks with an arity of {}, at least {} wires are needed. Consider reducing arity.", - max_fri_arity, + "To efficiently perform FRI checks with an arity of 2^{}, at least {} wires are needed. Consider reducing arity.", + max_fri_arity_bits, min_wires ); assert!( self.config.num_routed_wires >= min_routed_wires, - "To efficiently perform FRI checks with an arity of {}, at least {} routed wires are needed. Consider reducing arity.", - max_fri_arity, + "To efficiently perform FRI checks with an arity of 2^{}, at least {} routed wires are needed. Consider reducing arity.", + max_fri_arity_bits, min_routed_wires ); } @@ -108,8 +137,8 @@ impl, const D: usize> CircuitBuilder { ) { let config = &common_data.config; - if let Some(max_arity) = common_data.fri_params.max_arity() { - self.check_recursion_config(max_arity); + if let Some(max_arity_bits) = common_data.fri_params.max_arity_bits() { + self.check_recursion_config(max_arity_bits, common_data); } debug_assert_eq!( @@ -233,7 +262,7 @@ impl, const D: usize> CircuitBuilder { common_data: &CommonCircuitData, ) -> ExtensionTarget { assert!(D > 1, "Not implemented for D=1."); - let config = self.config.clone(); + let config = &common_data.config; let degree_log = common_data.degree_bits; debug_assert_eq!( degree_log, @@ -306,9 +335,13 @@ impl, const D: usize> CircuitBuilder { common_data: &CommonCircuitData, ) { let n_log = log2_strict(n); - // TODO: Do we need to range check `x_index` to a target smaller than `p`? + + // Note that this `low_bits` decomposition permits non-canonical binary encodings. Here we + // verify that this has a negligible impact on soundness error. 
+ Self::assert_noncanonical_indices_ok(&common_data.config); let x_index = challenger.get_challenge(self); - let mut x_index_bits = self.low_bits(x_index, n_log, 64); + let mut x_index_bits = self.low_bits(x_index, n_log, F::BITS); + let cap_index = self.le_sum(x_index_bits[x_index_bits.len() - common_data.config.cap_height..].iter()); with_context!( @@ -376,6 +409,7 @@ impl, const D: usize> CircuitBuilder { arity_bits, evals, betas[i], + common_data ) ); @@ -409,6 +443,26 @@ impl, const D: usize> CircuitBuilder { ); self.connect_extension(eval, old_eval); } + + /// We decompose FRI query indices into bits without verifying that the decomposition given by + /// the prover is the canonical one. In particular, if `x_index < 2^field_bits - p`, then the + /// prover could supply the binary encoding of either `x_index` or `x_index + p`, since they are + /// congruent mod `p`. However, this only occurs with probability + /// p_ambiguous = (2^field_bits - p) / p + /// which is small for the field that we use in practice. + /// + /// In particular, the soundness error of one FRI query is roughly the codeword rate, which + /// is much larger than this ambiguous-element probability given any reasonable parameters. + /// Thus ambiguous elements contribute a negligible amount to soundness error. + /// + /// Here we compare the probabilities as a sanity check, to verify the claim above. + fn assert_noncanonical_indices_ok(config: &CircuitConfig) { + let num_ambiguous_elems = u64::MAX - F::ORDER + 1; + let query_error = config.rate(); + let p_ambiguous = (num_ambiguous_elems as f64) / (F::ORDER as f64); + assert!(p_ambiguous < query_error * 1e-5, + "A non-negligible portion of field elements are in the range that permits non-canonical encodings. 
Need to do more analysis or enforce canonical encodings."); + } } #[derive(Copy, Clone)] diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 74cc890d..1d8a4835 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -2,6 +2,8 @@ use std::borrow::Borrow; use crate::field::extension_field::Extendable; use crate::gates::arithmetic::ArithmeticExtensionGate; +use crate::field::field_types::{PrimeField, RichField}; +use crate::gates::arithmetic_base::ArithmeticGate; use crate::gates::exponentiation::ExponentiationGate; use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; @@ -32,18 +34,117 @@ impl, const D: usize> CircuitBuilder { multiplicand_1: Target, addend: Target, ) -> Target { - let multiplicand_0_ext = self.convert_to_ext(multiplicand_0); - let multiplicand_1_ext = self.convert_to_ext(multiplicand_1); - let addend_ext = self.convert_to_ext(addend); + // If we're not configured to use the base arithmetic gate, just call arithmetic_extension. + if !self.config.use_base_arithmetic_gate { + let multiplicand_0_ext = self.convert_to_ext(multiplicand_0); + let multiplicand_1_ext = self.convert_to_ext(multiplicand_1); + let addend_ext = self.convert_to_ext(addend); - self.arithmetic_extension( + return self + .arithmetic_extension( + const_0, + const_1, + multiplicand_0_ext, + multiplicand_1_ext, + addend_ext, + ) + .0[0]; + } + + // See if we can determine the result without adding an `ArithmeticGate`. + if let Some(result) = + self.arithmetic_special_cases(const_0, const_1, multiplicand_0, multiplicand_1, addend) + { + return result; + } + + // See if we've already computed the same operation. 
+ let operation = BaseArithmeticOperation { const_0, const_1, - multiplicand_0_ext, - multiplicand_1_ext, - addend_ext, - ) - .0[0] + multiplicand_0, + multiplicand_1, + addend, + }; + if let Some(&result) = self.base_arithmetic_results.get(&operation) { + return result; + } + + // Otherwise, we must actually perform the operation using an ArithmeticGate slot. + let result = self.add_base_arithmetic_operation(operation); + self.base_arithmetic_results.insert(operation, result); + result + } + + fn add_base_arithmetic_operation(&mut self, operation: BaseArithmeticOperation) -> Target { + let (gate, i) = self.find_base_arithmetic_gate(operation.const_0, operation.const_1); + let wires_multiplicand_0 = Target::wire(gate, ArithmeticGate::wire_ith_multiplicand_0(i)); + let wires_multiplicand_1 = Target::wire(gate, ArithmeticGate::wire_ith_multiplicand_1(i)); + let wires_addend = Target::wire(gate, ArithmeticGate::wire_ith_addend(i)); + + self.connect(operation.multiplicand_0, wires_multiplicand_0); + self.connect(operation.multiplicand_1, wires_multiplicand_1); + self.connect(operation.addend, wires_addend); + + Target::wire(gate, ArithmeticGate::wire_ith_output(i)) + } + + /// Checks for special cases where the value of + /// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend` + /// can be determined without adding an `ArithmeticGate`. + fn arithmetic_special_cases( + &mut self, + const_0: F, + const_1: F, + multiplicand_0: Target, + multiplicand_1: Target, + addend: Target, + ) -> Option { + let zero = self.zero(); + + let mul_0_const = self.target_as_constant(multiplicand_0); + let mul_1_const = self.target_as_constant(multiplicand_1); + let addend_const = self.target_as_constant(addend); + + let first_term_zero = + const_0 == F::ZERO || multiplicand_0 == zero || multiplicand_1 == zero; + let second_term_zero = const_1 == F::ZERO || addend == zero; + + // If both terms are constant, return their (constant) sum. 
+ let first_term_const = if first_term_zero { + Some(F::ZERO) + } else if let (Some(x), Some(y)) = (mul_0_const, mul_1_const) { + Some(x * y * const_0) + } else { + None + }; + let second_term_const = if second_term_zero { + Some(F::ZERO) + } else { + addend_const.map(|x| x * const_1) + }; + if let (Some(x), Some(y)) = (first_term_const, second_term_const) { + return Some(self.constant(x + y)); + } + + if first_term_zero && const_1.is_one() { + return Some(addend); + } + + if second_term_zero { + if let Some(x) = mul_0_const { + if (x * const_0).is_one() { + return Some(multiplicand_1); + } + } + if let Some(x) = mul_1_const { + if (x * const_0).is_one() { + return Some(multiplicand_0); + } + } + } + + None } /// Computes `x * y + z`. @@ -53,20 +154,20 @@ impl, const D: usize> CircuitBuilder { /// Computes `x + C`. pub fn add_const(&mut self, x: Target, c: F) -> Target { - let one = self.one(); - self.arithmetic(F::ONE, c, one, x, one) + let c = self.constant(c); + self.add(x, c) } /// Computes `C * x`. pub fn mul_const(&mut self, c: F, x: Target) -> Target { - let zero = self.zero(); - self.mul_const_add(c, x, zero) + let c = self.constant(c); + self.mul(c, x) } /// Computes `C * x + y`. pub fn mul_const_add(&mut self, c: F, x: Target, y: Target) -> Target { - let one = self.one(); - self.arithmetic(c, F::ONE, x, one, y) + let c = self.constant(c); + self.mul_add(c, x, y) } /// Computes `x * y - z`. @@ -82,13 +183,8 @@ impl, const D: usize> CircuitBuilder { } /// Add `n` `Target`s. - // TODO: Can be made `D` times more efficient by using all wires of an `ArithmeticExtensionGate`. pub fn add_many(&mut self, terms: &[Target]) -> Target { - let terms_ext = terms - .iter() - .map(|&t| self.convert_to_ext(t)) - .collect::>(); - self.add_many_extension(&terms_ext).to_target_array()[0] + terms.iter().fold(self.zero(), |acc, &t| self.add(acc, t)) } /// Computes `x - y`. @@ -106,16 +202,16 @@ impl, const D: usize> CircuitBuilder { /// Multiply `n` `Target`s. 
pub fn mul_many(&mut self, terms: &[Target]) -> Target { - let terms_ext = terms + terms .iter() - .map(|&t| self.convert_to_ext(t)) - .collect::>(); - self.mul_many_extension(&terms_ext).to_target_array()[0] + .copied() + .reduce(|acc, t| self.mul(acc, t)) + .unwrap_or_else(|| self.one()) } /// Exponentiate `base` to the power of `2^power_log`. pub fn exp_power_of_2(&mut self, base: Target, power_log: usize) -> Target { - if power_log > ArithmeticExtensionGate::::new_from_config(&self.config).num_ops { + if power_log > self.num_base_arithmetic_ops_per_gate() { // Cheaper to just use `ExponentiateGate`. return self.exp_u64(base, 1 << power_log); } @@ -169,8 +265,7 @@ impl, const D: usize> CircuitBuilder { let base_t = self.constant(base); let exponent_bits: Vec<_> = exponent_bits.into_iter().map(|b| *b.borrow()).collect(); - if exponent_bits.len() > ArithmeticExtensionGate::::new_from_config(&self.config).num_ops - { + if exponent_bits.len() > self.num_base_arithmetic_ops_per_gate() { // Cheaper to just use `ExponentiateGate`. return self.exp_from_bits(base_t, exponent_bits); } @@ -220,3 +315,13 @@ impl, const D: usize> CircuitBuilder { self.inverse_extension(x_ext).0[0] } } + +/// Represents a base arithmetic operation in the circuit. Used to memoize results. 
+#[derive(Copy, Clone, Eq, PartialEq, Hash)] +pub(crate) struct BaseArithmeticOperation { + const_0: F, + const_1: F, + multiplicand_0: Target, + multiplicand_1: Target, + addend: Target, +} diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index cf62b9dd..b48b25bb 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -1,8 +1,9 @@ -use std::convert::TryInto; - use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::field::extension_field::FieldExtension; use crate::field::extension_field::{Extendable, OEF}; +use crate::field::field_types::{Field, PrimeField, RichField}; +use crate::gates::arithmetic_extension::ArithmeticExtensionGate; +use crate::gates::multiplication_extension::MulExtensionGate; use crate::field::field_types::{Field, PrimeField}; use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; @@ -12,33 +13,6 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::bits_u64; impl, const D: usize> CircuitBuilder { - /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. - /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index - /// `g` and the gate's `i`-th operation is available. - fn find_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) { - let (gate, i) = self - .free_arithmetic - .get(&(const_0, const_1)) - .copied() - .unwrap_or_else(|| { - let gate = self.add_gate( - ArithmeticExtensionGate::new_from_config(&self.config), - vec![const_0, const_1], - ); - (gate, 0) - }); - - // Update `free_arithmetic` with new values. 
- if i < ArithmeticExtensionGate::::num_ops(&self.config) - 1 { - self.free_arithmetic - .insert((const_0, const_1), (gate, i + 1)); - } else { - self.free_arithmetic.remove(&(const_0, const_1)); - } - - (gate, i) - } - pub fn arithmetic_extension( &mut self, const_0: F, @@ -59,7 +33,7 @@ impl, const D: usize> CircuitBuilder { } // See if we've already computed the same operation. - let operation = ArithmeticOperation { + let operation = ExtensionArithmeticOperation { const_0, const_1, multiplicand_0, @@ -70,15 +44,21 @@ impl, const D: usize> CircuitBuilder { return result; } + let result = if self.target_as_constant_ext(addend) == Some(F::Extension::ZERO) { + // If the addend is zero, we use a multiplication gate. + self.compute_mul_extension_operation(operation) + } else { + // Otherwise, we use an arithmetic gate. + self.compute_arithmetic_extension_operation(operation) + }; // Otherwise, we must actually perform the operation using an ArithmeticExtensionGate slot. - let result = self.add_arithmetic_extension_operation(operation); self.arithmetic_results.insert(operation, result); result } - fn add_arithmetic_extension_operation( + fn compute_arithmetic_extension_operation( &mut self, - operation: ArithmeticOperation, + operation: ExtensionArithmeticOperation, ) -> ExtensionTarget { let (gate, i) = self.find_arithmetic_gate(operation.const_0, operation.const_1); let wires_multiplicand_0 = ExtensionTarget::from_range( @@ -99,6 +79,22 @@ impl, const D: usize> CircuitBuilder { ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_ith_output(i)) } + fn compute_mul_extension_operation( + &mut self, + operation: ExtensionArithmeticOperation, + ) -> ExtensionTarget { + let (gate, i) = self.find_mul_gate(operation.const_0); + let wires_multiplicand_0 = + ExtensionTarget::from_range(gate, MulExtensionGate::::wires_ith_multiplicand_0(i)); + let wires_multiplicand_1 = + ExtensionTarget::from_range(gate, MulExtensionGate::::wires_ith_multiplicand_1(i)); + + 
self.connect_extension(operation.multiplicand_0, wires_multiplicand_0); + self.connect_extension(operation.multiplicand_1, wires_multiplicand_1); + + ExtensionTarget::from_range(gate, MulExtensionGate::::wires_ith_output(i)) + } + /// Checks for special cases where the value of /// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend` /// can be determined without adding an `ArithmeticGate`. @@ -302,11 +298,11 @@ impl, const D: usize> CircuitBuilder { /// Multiply `n` `ExtensionTarget`s. pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { - let mut product = self.one_extension(); - for &term in terms { - product = self.mul_extension(product, term); - } - product + terms + .iter() + .copied() + .reduce(|acc, t| self.mul_extension(acc, t)) + .unwrap_or_else(|| self.one_extension()) } /// Like `mul_add`, but for `ExtensionTarget`s. @@ -321,14 +317,14 @@ impl, const D: usize> CircuitBuilder { /// Like `add_const`, but for `ExtensionTarget`s. pub fn add_const_extension(&mut self, x: ExtensionTarget, c: F) -> ExtensionTarget { - let one = self.one_extension(); - self.arithmetic_extension(F::ONE, c, one, x, one) + let c = self.constant_extension(c.into()); + self.add_extension(x, c) } /// Like `mul_const`, but for `ExtensionTarget`s. pub fn mul_const_extension(&mut self, c: F, x: ExtensionTarget) -> ExtensionTarget { - let zero = self.zero_extension(); - self.mul_const_add_extension(c, x, zero) + let c = self.constant_extension(c.into()); + self.mul_extension(c, x) } /// Like `mul_const_add`, but for `ExtensionTarget`s. @@ -338,8 +334,8 @@ impl, const D: usize> CircuitBuilder { x: ExtensionTarget, y: ExtensionTarget, ) -> ExtensionTarget { - let one = self.one_extension(); - self.arithmetic_extension(c, F::ONE, x, one, y) + let c = self.constant_extension(c.into()); + self.mul_add_extension(c, x, y) } /// Like `mul_add`, but for `ExtensionTarget`s. 
@@ -544,9 +540,9 @@ impl, const D: usize> CircuitBuilder { } } -/// Represents an arithmetic operation in the circuit. Used to memoize results. +/// Represents an extension arithmetic operation in the circuit. Used to memoize results. #[derive(Copy, Clone, Eq, PartialEq, Hash)] -pub(crate) struct ArithmeticOperation, const D: usize> { +pub(crate) struct ExtensionArithmeticOperation, const D: usize> { const_0: F, const_1: F, multiplicand_0: ExtensionTarget, @@ -556,11 +552,11 @@ pub(crate) struct ArithmeticOperation, const D: us #[cfg(test)] mod tests { - use std::convert::TryInto; - use anyhow::Result; use crate::field::extension_field::algebra::ExtensionAlgebra; + use crate::field::extension_field::quartic::QuarticExtension; + use crate::field::extension_field::target::ExtensionAlgebraTarget; use crate::field::field_types::Field; use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; @@ -623,9 +619,7 @@ mod tests { let yt = builder.constant_extension(y); let zt = builder.constant_extension(z); let comp_zt = builder.div_extension(xt, yt); - let comp_zt_unsafe = builder.div_extension(xt, yt); builder.connect_extension(zt, comp_zt); - builder.connect_extension(zt, comp_zt_unsafe); let data = builder.build::(); let proof = data.prove(pw)?; @@ -642,23 +636,29 @@ mod tests { let config = CircuitConfig::standard_recursion_config(); - let pw = PartialWitness::new(); + let mut pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let x = FF::rand_vec(D); - let y = FF::rand_vec(D); - let xa = ExtensionAlgebra(x.try_into().unwrap()); - let ya = ExtensionAlgebra(y.try_into().unwrap()); - let za = xa * ya; - - let xt = builder.constant_ext_algebra(xa); - let yt = builder.constant_ext_algebra(ya); - let zt = builder.constant_ext_algebra(za); + let xt = + ExtensionAlgebraTarget(builder.add_virtual_extension_targets(D).try_into().unwrap()); + let yt = + 
ExtensionAlgebraTarget(builder.add_virtual_extension_targets(D).try_into().unwrap()); + let zt = + ExtensionAlgebraTarget(builder.add_virtual_extension_targets(D).try_into().unwrap()); let comp_zt = builder.mul_ext_algebra(xt, yt); for i in 0..D { builder.connect_extension(zt.0[i], comp_zt.0[i]); } + let x = ExtensionAlgebra::(FF::rand_arr()); + let y = ExtensionAlgebra::(FF::rand_arr()); + let z = x * y; + for i in 0..D { + pw.set_extension_target(xt.0[i], x.0[i]); + pw.set_extension_target(yt.0[i], y.0[i]); + pw.set_extension_target(zt.0[i], z.0[i]); + } + let data = builder.build::(); let proof = data.prove(pw)?; diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs new file mode 100644 index 00000000..3bf6ce58 --- /dev/null +++ b/src/gadgets/arithmetic_u32.rs @@ -0,0 +1,154 @@ +use crate::field::extension_field::Extendable; +use crate::field::field_types::RichField; +use crate::gates::arithmetic_u32::U32ArithmeticGate; +use crate::gates::subtraction_u32::U32SubtractionGate; +use crate::iop::target::Target; +use crate::plonk::circuit_builder::CircuitBuilder; + +#[derive(Clone, Copy, Debug)] +pub struct U32Target(pub Target); + +impl, const D: usize> CircuitBuilder { + pub fn add_virtual_u32_target(&mut self) -> U32Target { + U32Target(self.add_virtual_target()) + } + + pub fn add_virtual_u32_targets(&mut self, n: usize) -> Vec { + self.add_virtual_targets(n) + .into_iter() + .map(U32Target) + .collect() + } + + pub fn zero_u32(&mut self) -> U32Target { + U32Target(self.zero()) + } + + pub fn one_u32(&mut self) -> U32Target { + U32Target(self.one()) + } + + pub fn connect_u32(&mut self, x: U32Target, y: U32Target) { + self.connect(x.0, y.0) + } + + pub fn assert_zero_u32(&mut self, x: U32Target) { + self.assert_zero(x.0) + } + + /// Checks for special cases where the value of + /// `x * y + z` + /// can be determined without adding a `U32ArithmeticGate`. 
+ pub fn arithmetic_u32_special_cases( + &mut self, + x: U32Target, + y: U32Target, + z: U32Target, + ) -> Option<(U32Target, U32Target)> { + let x_const = self.target_as_constant(x.0); + let y_const = self.target_as_constant(y.0); + let z_const = self.target_as_constant(z.0); + + // If both terms are constant, return their (constant) sum. + let first_term_const = if let (Some(xx), Some(yy)) = (x_const, y_const) { + Some(xx * yy) + } else { + None + }; + + if let (Some(a), Some(b)) = (first_term_const, z_const) { + let sum = (a + b).to_canonical_u64(); + let (low, high) = (sum as u32, (sum >> 32) as u32); + return Some((self.constant_u32(low), self.constant_u32(high))); + } + + None + } + + // Returns x * y + z. + pub fn mul_add_u32( + &mut self, + x: U32Target, + y: U32Target, + z: U32Target, + ) -> (U32Target, U32Target) { + if let Some(result) = self.arithmetic_u32_special_cases(x, y, z) { + return result; + } + + let gate = U32ArithmeticGate::::new_from_config(&self.config); + let (gate_index, copy) = self.find_u32_arithmetic_gate(); + + self.connect( + Target::wire(gate_index, gate.wire_ith_multiplicand_0(copy)), + x.0, + ); + self.connect( + Target::wire(gate_index, gate.wire_ith_multiplicand_1(copy)), + y.0, + ); + self.connect(Target::wire(gate_index, gate.wire_ith_addend(copy)), z.0); + + let output_low = U32Target(Target::wire( + gate_index, + gate.wire_ith_output_low_half(copy), + )); + let output_high = U32Target(Target::wire( + gate_index, + gate.wire_ith_output_high_half(copy), + )); + + (output_low, output_high) + } + + pub fn add_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { + let one = self.one_u32(); + self.mul_add_u32(a, one, b) + } + + pub fn add_many_u32(&mut self, to_add: &[U32Target]) -> (U32Target, U32Target) { + match to_add.len() { + 0 => (self.zero_u32(), self.zero_u32()), + 1 => (to_add[0], self.zero_u32()), + 2 => self.add_u32(to_add[0], to_add[1]), + _ => { + let (mut low, mut carry) = self.add_u32(to_add[0], 
to_add[1]); + for i in 2..to_add.len() { + let (new_low, new_carry) = self.add_u32(to_add[i], low); + let (combined_carry, _zero) = self.add_u32(carry, new_carry); + low = new_low; + carry = combined_carry; + } + (low, carry) + } + } + } + + pub fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { + let zero = self.zero_u32(); + self.mul_add_u32(a, b, zero) + } + + // Returns x - y - borrow, as a pair (result, borrow), where borrow is 0 or 1 depending on whether borrowing from the next digit is required (iff y + borrow > x). + pub fn sub_u32( + &mut self, + x: U32Target, + y: U32Target, + borrow: U32Target, + ) -> (U32Target, U32Target) { + let gate = U32SubtractionGate::::new_from_config(&self.config); + let (gate_index, copy) = self.find_u32_subtraction_gate(); + + self.connect(Target::wire(gate_index, gate.wire_ith_input_x(copy)), x.0); + self.connect(Target::wire(gate_index, gate.wire_ith_input_y(copy)), y.0); + self.connect( + Target::wire(gate_index, gate.wire_ith_input_borrow(copy)), + borrow.0, + ); + + let output_result = U32Target(Target::wire(gate_index, gate.wire_ith_output_result(copy))); + let output_borrow = U32Target(Target::wire(gate_index, gate.wire_ith_output_borrow(copy))); + + (output_result, output_borrow) + } +} diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs new file mode 100644 index 00000000..e037c402 --- /dev/null +++ b/src/gadgets/biguint.rs @@ -0,0 +1,395 @@ +use std::marker::PhantomData; + +use num::{BigUint, Integer}; + +use crate::field::extension_field::Extendable; +use crate::field::field_types::RichField; +use crate::gadgets::arithmetic_u32::U32Target; +use crate::iop::generator::{GeneratedValues, SimpleGenerator}; +use crate::iop::target::{BoolTarget, Target}; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; + +#[derive(Clone, Debug)] +pub struct BigUintTarget { + pub limbs: Vec, +} + +impl BigUintTarget { + pub fn num_limbs(&self) -> 
usize { + self.limbs.len() + } + + pub fn get_limb(&self, i: usize) -> U32Target { + self.limbs[i] + } +} + +impl, const D: usize> CircuitBuilder { + pub fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget { + let limb_values = value.to_u32_digits(); + let limbs = limb_values.iter().map(|&l| self.constant_u32(l)).collect(); + + BigUintTarget { limbs } + } + + pub fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) { + let min_limbs = lhs.num_limbs().min(rhs.num_limbs()); + for i in 0..min_limbs { + self.connect_u32(lhs.get_limb(i), rhs.get_limb(i)); + } + + for i in min_limbs..lhs.num_limbs() { + self.assert_zero_u32(lhs.get_limb(i)); + } + for i in min_limbs..rhs.num_limbs() { + self.assert_zero_u32(rhs.get_limb(i)); + } + } + + pub fn pad_biguints( + &mut self, + a: &BigUintTarget, + b: &BigUintTarget, + ) -> (BigUintTarget, BigUintTarget) { + if a.num_limbs() > b.num_limbs() { + let mut padded_b = b.clone(); + for _ in b.num_limbs()..a.num_limbs() { + padded_b.limbs.push(self.zero_u32()); + } + + (a.clone(), padded_b) + } else { + let mut padded_a = a.clone(); + for _ in a.num_limbs()..b.num_limbs() { + padded_a.limbs.push(self.zero_u32()); + } + + (padded_a, b.clone()) + } + } + + pub fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget { + let (a, b) = self.pad_biguints(a, b); + + self.list_le_u32(a.limbs, b.limbs) + } + + pub fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget { + let limbs = (0..num_limbs) + .map(|_| self.add_virtual_u32_target()) + .collect(); + + BigUintTarget { limbs } + } + + // Add two `BigUintTarget`s. 
+ pub fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let num_limbs = a.num_limbs().max(b.num_limbs()); + + let mut combined_limbs = vec![]; + let mut carry = self.zero_u32(); + for i in 0..num_limbs { + let a_limb = (i < a.num_limbs()) + .then(|| a.limbs[i]) + .unwrap_or_else(|| self.zero_u32()); + let b_limb = (i < b.num_limbs()) + .then(|| b.limbs[i]) + .unwrap_or_else(|| self.zero_u32()); + + let (new_limb, new_carry) = self.add_many_u32(&[carry, a_limb, b_limb]); + carry = new_carry; + combined_limbs.push(new_limb); + } + combined_limbs.push(carry); + + BigUintTarget { + limbs: combined_limbs, + } + } + + // Subtract two `BigUintTarget`s. We assume that the first is larger than the second. + pub fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let (a, b) = self.pad_biguints(a, b); + let num_limbs = a.limbs.len(); + + let mut result_limbs = vec![]; + + let mut borrow = self.zero_u32(); + for i in 0..num_limbs { + let (result, new_borrow) = self.sub_u32(a.limbs[i], b.limbs[i], borrow); + result_limbs.push(result); + borrow = new_borrow; + } + // Borrow should be zero here. + + BigUintTarget { + limbs: result_limbs, + } + } + + pub fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let total_limbs = a.limbs.len() + b.limbs.len(); + + let mut to_add = vec![vec![]; total_limbs]; + for i in 0..a.limbs.len() { + for j in 0..b.limbs.len() { + let (product, carry) = self.mul_u32(a.limbs[i], b.limbs[j]); + to_add[i + j].push(product); + to_add[i + j + 1].push(carry); + } + } + + let mut combined_limbs = vec![]; + let mut carry = self.zero_u32(); + for summands in &mut to_add { + summands.push(carry); + let (new_result, new_carry) = self.add_many_u32(summands); + combined_limbs.push(new_result); + carry = new_carry; + } + combined_limbs.push(carry); + + BigUintTarget { + limbs: combined_limbs, + } + } + + // Returns x * y + z. 
This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function). + pub fn mul_add_biguint( + &mut self, + x: &BigUintTarget, + y: &BigUintTarget, + z: &BigUintTarget, + ) -> BigUintTarget { + let prod = self.mul_biguint(x, y); + self.add_biguint(&prod, z) + } + + pub fn div_rem_biguint( + &mut self, + a: &BigUintTarget, + b: &BigUintTarget, + ) -> (BigUintTarget, BigUintTarget) { + let a_len = a.limbs.len(); + let b_len = b.limbs.len(); + let div_num_limbs = if b_len > a_len + 1 { + 0 + } else { + a_len - b_len + 1 + }; + let div = self.add_virtual_biguint_target(div_num_limbs); + let rem = self.add_virtual_biguint_target(b_len); + + self.add_simple_generator(BigUintDivRemGenerator:: { + a: a.clone(), + b: b.clone(), + div: div.clone(), + rem: rem.clone(), + _phantom: PhantomData, + }); + + let div_b = self.mul_biguint(&div, b); + let div_b_plus_rem = self.add_biguint(&div_b, &rem); + self.connect_biguint(a, &div_b_plus_rem); + + let cmp_rem_b = self.cmp_biguint(&rem, b); + self.assert_one(cmp_rem_b.target); + + (div, rem) + } + + pub fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let (div, _rem) = self.div_rem_biguint(a, b); + div + } + + pub fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let (_div, rem) = self.div_rem_biguint(a, b); + rem + } +} + +#[derive(Debug)] +struct BigUintDivRemGenerator, const D: usize> { + a: BigUintTarget, + b: BigUintTarget, + div: BigUintTarget, + rem: BigUintTarget, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for BigUintDivRemGenerator +{ + fn dependencies(&self) -> Vec { + self.a + .limbs + .iter() + .chain(&self.b.limbs) + .map(|&l| l.0) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = witness.get_biguint_target(self.a.clone()); + let b = witness.get_biguint_target(self.b.clone()); + let (div, rem) = 
a.div_rem(&b); + + out_buffer.set_biguint_target(self.div.clone(), div); + out_buffer.set_biguint_target(self.rem.clone(), rem); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use num::{BigUint, FromPrimitive, Integer}; + use rand::Rng; + + use crate::iop::witness::Witness; + use crate::{ + field::goldilocks_field::GoldilocksField, + iop::witness::PartialWitness, + plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig, verifier::verify}, + }; + + #[test] + fn test_biguint_add() -> Result<()> { + let mut rng = rand::thread_rng(); + + let x_value = BigUint::from_u128(rng.gen()).unwrap(); + let y_value = BigUint::from_u128(rng.gen()).unwrap(); + let expected_z_value = &x_value + &y_value; + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let mut pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len()); + let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len()); + let z = builder.add_biguint(&x, &y); + let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); + builder.connect_biguint(&z, &expected_z); + + pw.set_biguint_target(&x, &x_value); + pw.set_biguint_target(&y, &y_value); + pw.set_biguint_target(&expected_z, &expected_z_value); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_biguint_sub() -> Result<()> { + let mut rng = rand::thread_rng(); + + let mut x_value = BigUint::from_u128(rng.gen()).unwrap(); + let mut y_value = BigUint::from_u128(rng.gen()).unwrap(); + if y_value > x_value { + (x_value, y_value) = (y_value, x_value); + } + let expected_z_value = &x_value - &y_value; + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); 
+ + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let z = builder.sub_biguint(&x, &y); + let expected_z = builder.constant_biguint(&expected_z_value); + + builder.connect_biguint(&z, &expected_z); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_biguint_mul() -> Result<()> { + let mut rng = rand::thread_rng(); + + let x_value = BigUint::from_u128(rng.gen()).unwrap(); + let y_value = BigUint::from_u128(rng.gen()).unwrap(); + let expected_z_value = &x_value * &y_value; + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let mut pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len()); + let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len()); + let z = builder.mul_biguint(&x, &y); + let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); + builder.connect_biguint(&z, &expected_z); + + pw.set_biguint_target(&x, &x_value); + pw.set_biguint_target(&y, &y_value); + pw.set_biguint_target(&expected_z, &expected_z_value); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_biguint_cmp() -> Result<()> { + let mut rng = rand::thread_rng(); + + let x_value = BigUint::from_u128(rng.gen()).unwrap(); + let y_value = BigUint::from_u128(rng.gen()).unwrap(); + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let cmp = builder.cmp_biguint(&x, &y); + let expected_cmp = builder.constant_bool(x_value <= y_value); + + builder.connect(cmp.target, 
expected_cmp.target); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_biguint_div_rem() -> Result<()> { + let mut rng = rand::thread_rng(); + + let mut x_value = BigUint::from_u128(rng.gen()).unwrap(); + let mut y_value = BigUint::from_u128(rng.gen()).unwrap(); + if y_value > x_value { + (x_value, y_value) = (y_value, x_value); + } + let (expected_div_value, expected_rem_value) = x_value.div_rem(&y_value); + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let (div, rem) = builder.div_rem_biguint(&x, &y); + + let expected_div = builder.constant_biguint(&expected_div_value); + let expected_rem = builder.constant_biguint(&expected_rem_value); + + builder.connect_biguint(&div, &expected_div); + builder.connect_biguint(&rem, &expected_rem); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } +} diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs new file mode 100644 index 00000000..c86c3c0d --- /dev/null +++ b/src/gadgets/curve.rs @@ -0,0 +1,368 @@ +use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; +use crate::field::extension_field::Extendable; +use crate::field::field_types::{Field, RichField}; +use crate::gadgets::nonnative::NonNativeTarget; +use crate::plonk::circuit_builder::CircuitBuilder; + +/// A Target representing an affine point on the curve `C`. We use incomplete arithmetic for efficiency, +/// so we assume these points are not zero. 
+#[derive(Clone, Debug)] +pub struct AffinePointTarget { + pub x: NonNativeTarget, + pub y: NonNativeTarget, +} + +impl AffinePointTarget { + pub fn to_vec(&self) -> Vec> { + vec![self.x.clone(), self.y.clone()] + } +} + +impl, const D: usize> CircuitBuilder { + pub fn constant_affine_point( + &mut self, + point: AffinePoint, + ) -> AffinePointTarget { + debug_assert!(!point.zero); + AffinePointTarget { + x: self.constant_nonnative(point.x), + y: self.constant_nonnative(point.y), + } + } + + pub fn connect_affine_point( + &mut self, + lhs: &AffinePointTarget, + rhs: &AffinePointTarget, + ) { + self.connect_nonnative(&lhs.x, &rhs.x); + self.connect_nonnative(&lhs.y, &rhs.y); + } + + pub fn add_virtual_affine_point_target(&mut self) -> AffinePointTarget { + let x = self.add_virtual_nonnative_target(); + let y = self.add_virtual_nonnative_target(); + + AffinePointTarget { x, y } + } + + pub fn curve_assert_valid(&mut self, p: &AffinePointTarget) { + let a = self.constant_nonnative(C::A); + let b = self.constant_nonnative(C::B); + + let y_squared = self.mul_nonnative(&p.y, &p.y); + let x_squared = self.mul_nonnative(&p.x, &p.x); + let x_cubed = self.mul_nonnative(&x_squared, &p.x); + let a_x = self.mul_nonnative(&a, &p.x); + let a_x_plus_b = self.add_nonnative(&a_x, &b); + let rhs = self.add_nonnative(&x_cubed, &a_x_plus_b); + + self.connect_nonnative(&y_squared, &rhs); + } + + pub fn curve_neg(&mut self, p: &AffinePointTarget) -> AffinePointTarget { + let neg_y = self.neg_nonnative(&p.y); + AffinePointTarget { + x: p.x.clone(), + y: neg_y, + } + } + + pub fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget { + let AffinePointTarget { x, y } = p; + let double_y = self.add_nonnative(y, y); + let inv_double_y = self.inv_nonnative(&double_y); + let x_squared = self.mul_nonnative(x, x); + let double_x_squared = self.add_nonnative(&x_squared, &x_squared); + let triple_x_squared = self.add_nonnative(&double_x_squared, &x_squared); + + let a = 
self.constant_nonnative(C::A); + let triple_xx_a = self.add_nonnative(&triple_x_squared, &a); + let lambda = self.mul_nonnative(&triple_xx_a, &inv_double_y); + let lambda_squared = self.mul_nonnative(&lambda, &lambda); + let x_double = self.add_nonnative(x, x); + + let x3 = self.sub_nonnative(&lambda_squared, &x_double); + + let x_diff = self.sub_nonnative(x, &x3); + let lambda_x_diff = self.mul_nonnative(&lambda, &x_diff); + + let y3 = self.sub_nonnative(&lambda_x_diff, y); + + AffinePointTarget { x: x3, y: y3 } + } + + // Add two points, which are assumed to be non-equal. + pub fn curve_add( + &mut self, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + ) -> AffinePointTarget { + let AffinePointTarget { x: x1, y: y1 } = p1; + let AffinePointTarget { x: x2, y: y2 } = p2; + + let u = self.sub_nonnative(y2, y1); + let uu = self.mul_nonnative(&u, &u); + let v = self.sub_nonnative(x2, x1); + let vv = self.mul_nonnative(&v, &v); + let vvv = self.mul_nonnative(&v, &vv); + let r = self.mul_nonnative(&vv, x1); + let diff = self.sub_nonnative(&uu, &vvv); + let r2 = self.add_nonnative(&r, &r); + let a = self.sub_nonnative(&diff, &r2); + let x3 = self.mul_nonnative(&v, &a); + + let r_a = self.sub_nonnative(&r, &a); + let y3_first = self.mul_nonnative(&u, &r_a); + let y3_second = self.mul_nonnative(&vvv, y1); + let y3 = self.sub_nonnative(&y3_first, &y3_second); + + let z3_inv = self.inv_nonnative(&vvv); + let x3_norm = self.mul_nonnative(&x3, &z3_inv); + let y3_norm = self.mul_nonnative(&y3, &z3_inv); + + AffinePointTarget { + x: x3_norm, + y: y3_norm, + } + } + + pub fn curve_scalar_mul( + &mut self, + p: &AffinePointTarget, + n: &NonNativeTarget, + ) -> AffinePointTarget { + let one = self.constant_nonnative(C::BaseField::ONE); + + let bits = self.split_nonnative_to_bits(n); + let bits_as_base: Vec> = + bits.iter().map(|b| self.bool_to_nonnative(b)).collect(); + + let rando = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine(); + let randot = 
+ self.constant_affine_point(rando); + // Result starts at `rando`, which is later subtracted, because we don't support arithmetic with the zero point. + let mut result = self.add_virtual_affine_point_target(); + self.connect_affine_point(&randot, &result); + + let mut two_i_times_p = self.add_virtual_affine_point_target(); + self.connect_affine_point(p, &two_i_times_p); + + for bit in bits_as_base.iter() { + let not_bit = self.sub_nonnative(&one, bit); + + let result_plus_2_i_p = self.curve_add(&result, &two_i_times_p); + + let new_x_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p.x); + let new_x_if_not_bit = self.mul_nonnative(&not_bit, &result.x); + let new_y_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p.y); + let new_y_if_not_bit = self.mul_nonnative(&not_bit, &result.y); + + let new_x = self.add_nonnative(&new_x_if_bit, &new_x_if_not_bit); + let new_y = self.add_nonnative(&new_y_if_bit, &new_y_if_not_bit); + + result = AffinePointTarget { x: new_x, y: new_y }; + + two_i_times_p = self.curve_double(&two_i_times_p); + } + + // Subtract off result's initial value of `rando`.
+ let neg_r = self.curve_neg(&randot); + result = self.curve_add(&result, &neg_r); + + result + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; + use crate::curve::secp256k1::Secp256K1; + use crate::field::field_types::Field; + use crate::field::goldilocks_field::GoldilocksField; + use crate::field::secp256k1_base::Secp256K1Base; + use crate::field::secp256k1_scalar::Secp256K1Scalar; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::verifier::verify; + + #[test] + fn test_curve_point_is_valid() -> Result<()> { + type F = GoldilocksField; + const D: usize = 4; + + let config = CircuitConfig::standard_recursion_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let g_target = builder.constant_affine_point(g); + let neg_g_target = builder.curve_neg(&g_target); + + builder.curve_assert_valid(&g_target); + builder.curve_assert_valid(&neg_g_target); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + #[should_panic] + fn test_curve_point_is_not_valid() { + type F = GoldilocksField; + const D: usize = 4; + + let config = CircuitConfig::standard_recursion_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let not_g = AffinePoint:: { + x: g.x, + y: g.y + Secp256K1Base::ONE, + zero: g.zero, + }; + let not_g_target = builder.constant_affine_point(not_g); + + builder.curve_assert_valid(&not_g_target); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common).unwrap(); + } + + #[test] + fn test_curve_double() -> Result<()> { + type F = GoldilocksField; + const D:
usize = 4; + + let config = CircuitConfig::standard_recursion_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let g_target = builder.constant_affine_point(g); + let neg_g_target = builder.curve_neg(&g_target); + + let double_g = g.double(); + let double_g_expected = builder.constant_affine_point(double_g); + builder.curve_assert_valid(&double_g_expected); + + let double_neg_g = (-g).double(); + let double_neg_g_expected = builder.constant_affine_point(double_neg_g); + builder.curve_assert_valid(&double_neg_g_expected); + + let double_g_actual = builder.curve_double(&g_target); + let double_neg_g_actual = builder.curve_double(&neg_g_target); + builder.curve_assert_valid(&double_g_actual); + builder.curve_assert_valid(&double_neg_g_actual); + + builder.connect_affine_point(&double_g_expected, &double_g_actual); + builder.connect_affine_point(&double_neg_g_expected, &double_neg_g_actual); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_curve_add() -> Result<()> { + type F = GoldilocksField; + const D: usize = 4; + + let config = CircuitConfig::standard_recursion_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let double_g = g.double(); + let g_plus_2g = (g + double_g).to_affine(); + let g_plus_2g_expected = builder.constant_affine_point(g_plus_2g); + builder.curve_assert_valid(&g_plus_2g_expected); + + let g_target = builder.constant_affine_point(g); + let double_g_target = builder.curve_double(&g_target); + let g_plus_2g_actual = builder.curve_add(&g_target, &double_g_target); + builder.curve_assert_valid(&g_plus_2g_actual); + + builder.connect_affine_point(&g_plus_2g_expected, &g_plus_2g_actual); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, 
&data.verifier_only, &data.common) + } + + #[test] + #[ignore] + fn test_curve_mul() -> Result<()> { + type F = GoldilocksField; + const D: usize = 4; + + let config = CircuitConfig { + num_routed_wires: 33, + ..CircuitConfig::standard_recursion_config() + }; + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let five = Secp256K1Scalar::from_canonical_usize(5); + let five_scalar = CurveScalar::(five); + let five_g = (five_scalar * g.to_projective()).to_affine(); + let five_g_expected = builder.constant_affine_point(five_g); + builder.curve_assert_valid(&five_g_expected); + + let g_target = builder.constant_affine_point(g); + let five_target = builder.constant_nonnative(five); + let five_g_actual = builder.curve_scalar_mul(&g_target, &five_target); + builder.curve_assert_valid(&five_g_actual); + + builder.connect_affine_point(&five_g_expected, &five_g_actual); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + #[ignore] + fn test_curve_random() -> Result<()> { + type F = GoldilocksField; + const D: usize = 4; + + let config = CircuitConfig { + num_routed_wires: 33, + ..CircuitConfig::standard_recursion_config() + }; + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let rando = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let randot = builder.constant_affine_point(rando); + + let two_target = builder.constant_nonnative(Secp256K1Scalar::TWO); + let randot_doubled = builder.curve_double(&randot); + let randot_times_two = builder.curve_scalar_mul(&randot, &two_target); + builder.connect_affine_point(&randot_doubled, &randot_times_two); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } +} diff --git a/src/gadgets/interpolation.rs 
b/src/gadgets/interpolation.rs index 4206e810..4081404c 100644 --- a/src/gadgets/interpolation.rs +++ b/src/gadgets/interpolation.rs @@ -1,22 +1,94 @@ +use std::ops::Range; + use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; +use crate::field::field_types::RichField; +use crate::gates::gate::Gate; use crate::gates::interpolation::InterpolationGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; -impl, const D: usize> CircuitBuilder { +/// Trait for gates which interpolate a polynomial, whose points are a (base field) coset of the multiplicative subgroup +/// with the given size, and whose values are extension field elements, given by input wires. +/// Outputs the evaluation of the interpolant at a given (extension field) evaluation point. +pub(crate) trait InterpolationGate, const D: usize>: + Gate + Copy +{ + fn new(subgroup_bits: usize) -> Self; + + fn num_points(&self) -> usize; + + /// Wire index of the coset shift. + fn wire_shift(&self) -> usize { + 0 + } + + fn start_values(&self) -> usize { + 1 + } + + /// Wire indices of the `i`th interpolant value. + fn wires_value(&self, i: usize) -> Range { + debug_assert!(i < self.num_points()); + let start = self.start_values() + i * D; + start..start + D + } + + fn start_evaluation_point(&self) -> usize { + self.start_values() + self.num_points() * D + } + + /// Wire indices of the point to evaluate the interpolant at. + fn wires_evaluation_point(&self) -> Range { + let start = self.start_evaluation_point(); + start..start + D + } + + fn start_evaluation_value(&self) -> usize { + self.start_evaluation_point() + D + } + + /// Wire indices of the interpolated value. 
+ fn wires_evaluation_value(&self) -> Range { + let start = self.start_evaluation_value(); + start..start + D + } + + fn start_coeffs(&self) -> usize { + self.start_evaluation_value() + D + } + + /// The number of routed wires required in the typical usage of this gate, where the points to + /// interpolate, the evaluation point, and the corresponding value are all routed. + fn num_routed_wires(&self) -> usize { + self.start_coeffs() + } + + /// Wire indices of the interpolant's `i`th coefficient. + fn wires_coeff(&self, i: usize) -> Range { + debug_assert!(i < self.num_points()); + let start = self.start_coeffs() + i * D; + start..start + D + } + + fn end_coeffs(&self) -> usize { + self.start_coeffs() + D * self.num_points() + } +} + +impl, const D: usize> CircuitBuilder { /// Interpolates a polynomial, whose points are a coset of the multiplicative subgroup with the /// given size, and whose values are given. Returns the evaluation of the interpolant at /// `evaluation_point`. - pub fn interpolate_coset( + pub(crate) fn interpolate_coset>( &mut self, subgroup_bits: usize, coset_shift: Target, values: &[ExtensionTarget], evaluation_point: ExtensionTarget, ) -> ExtensionTarget { - let gate = InterpolationGate::new(subgroup_bits); - let gate_index = self.add_gate(gate.clone(), vec![]); + let gate = G::new(subgroup_bits); + let gate_index = self.add_gate(gate, vec![]); self.connect(coset_shift, Target::wire(gate_index, gate.wire_shift())); for (i, &v) in values.iter().enumerate() { self.connect_extension( @@ -37,6 +109,7 @@ impl, const D: usize> CircuitBuilder { mod tests { use anyhow::Result; + use crate::field::extension_field::quadratic::QuadraticExtension; use crate::field::extension_field::FieldExtension; use crate::field::field_types::Field; use crate::field::interpolation::interpolant; @@ -83,9 +156,21 @@ mod tests { let zt = builder.constant_extension(z); - let eval = builder.interpolate_coset(subgroup_bits, coset_shift_target, &value_targets, zt); + let 
eval_hd = builder.interpolate_coset::>( + subgroup_bits, + coset_shift_target, + &value_targets, + zt, + ); + let eval_ld = builder.interpolate_coset::>( + subgroup_bits, + coset_shift_target, + &value_targets, + zt, + ); let true_eval_target = builder.constant_extension(true_eval); - builder.connect_extension(eval, true_eval_target); + builder.connect_extension(eval_hd, true_eval_target); + builder.connect_extension(eval_ld, true_eval_target); let data = builder.build::(); let proof = data.prove(pw)?; diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index aa18fbeb..09acb9de 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -1,8 +1,13 @@ pub mod arithmetic; pub mod arithmetic_extension; +pub mod arithmetic_u32; +pub mod biguint; +pub mod curve; pub mod hash; pub mod insert; pub mod interpolation; +pub mod multiple_comparison; +pub mod nonnative; pub mod permutation; pub mod polynomial; pub mod random_access; diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs new file mode 100644 index 00000000..3a5f2421 --- /dev/null +++ b/src/gadgets/multiple_comparison.rs @@ -0,0 +1,138 @@ +use super::arithmetic_u32::U32Target; +use crate::field::extension_field::Extendable; +use crate::field::field_types::RichField; +use crate::gates::comparison::ComparisonGate; +use crate::iop::target::{BoolTarget, Target}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::util::ceil_div_usize; + +impl, const D: usize> CircuitBuilder { + /// Returns true if a is less than or equal to b, considered as base-`2^num_bits` limbs of a large value. + /// This range-checks its inputs. 
+ pub fn list_le(&mut self, a: Vec, b: Vec, num_bits: usize) -> BoolTarget { + assert_eq!( + a.len(), + b.len(), + "Comparison must be between same number of inputs and outputs" + ); + let n = a.len(); + + let chunk_bits = 2; + let num_chunks = ceil_div_usize(num_bits, chunk_bits); + + let one = self.one(); + let mut result = one; + for i in 0..n { + let a_le_b_gate = ComparisonGate::new(num_bits, num_chunks); + let a_le_b_gate_index = self.add_gate(a_le_b_gate.clone(), vec![]); + self.connect( + Target::wire(a_le_b_gate_index, a_le_b_gate.wire_first_input()), + a[i], + ); + self.connect( + Target::wire(a_le_b_gate_index, a_le_b_gate.wire_second_input()), + b[i], + ); + let a_le_b_result = Target::wire(a_le_b_gate_index, a_le_b_gate.wire_result_bool()); + + let b_le_a_gate = ComparisonGate::new(num_bits, num_chunks); + let b_le_a_gate_index = self.add_gate(b_le_a_gate.clone(), vec![]); + self.connect( + Target::wire(b_le_a_gate_index, b_le_a_gate.wire_first_input()), + b[i], + ); + self.connect( + Target::wire(b_le_a_gate_index, b_le_a_gate.wire_second_input()), + a[i], + ); + let b_le_a_result = Target::wire(b_le_a_gate_index, b_le_a_gate.wire_result_bool()); + + let these_limbs_equal = self.mul(a_le_b_result, b_le_a_result); + let these_limbs_less_than = self.sub(one, b_le_a_result); + result = self.mul_add(these_limbs_equal, result, these_limbs_less_than); + } + + // `result` being boolean is an invariant, maintained because its new value is always + // `x * result + y`, where `x` and `y` are booleans that are not simultaneously true. + BoolTarget::new_unsafe(result) + } + + /// Helper function for comparing, specifically, lists of `U32Target`s. 
+ pub fn list_le_u32(&mut self, a: Vec, b: Vec) -> BoolTarget { + let a_targets = a.iter().map(|&t| t.0).collect(); + let b_targets = b.iter().map(|&t| t.0).collect(); + self.list_le(a_targets, b_targets, 32) + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use num::BigUint; + use rand::Rng; + + use crate::field::field_types::Field; + use crate::field::goldilocks_field::GoldilocksField; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::verifier::verify; + + fn test_list_le(size: usize, num_bits: usize) -> Result<()> { + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let mut rng = rand::thread_rng(); + + let lst1: Vec = (0..size) + .map(|_| rng.gen_range(0..(1 << num_bits))) + .collect(); + let lst2: Vec = (0..size) + .map(|_| rng.gen_range(0..(1 << num_bits))) + .collect(); + + let a_biguint = BigUint::from_slice( + &lst1 + .iter() + .flat_map(|&x| [x as u32, (x >> 32) as u32]) + .collect::>(), + ); + let b_biguint = BigUint::from_slice( + &lst2 + .iter() + .flat_map(|&x| [x as u32, (x >> 32) as u32]) + .collect::>(), + ); + + let a = lst1 + .iter() + .map(|&x| builder.constant(F::from_canonical_u64(x))) + .collect(); + let b = lst2 + .iter() + .map(|&x| builder.constant(F::from_canonical_u64(x))) + .collect(); + + let result = builder.list_le(a, b, num_bits); + + let expected_result = builder.constant_bool(a_biguint <= b_biguint); + builder.connect(result.target, expected_result.target); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_multiple_comparison() -> Result<()> { + for size in [1, 3, 6] { + for num_bits in [20, 32, 40, 44] { + test_list_le(size, num_bits).unwrap(); + } + } + + Ok(()) + } +} diff --git 
a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs new file mode 100644 index 00000000..56d717e3 --- /dev/null +++ b/src/gadgets/nonnative.rs @@ -0,0 +1,342 @@ +use std::marker::PhantomData; + +use num::{BigUint, Zero}; + +use crate::field::field_types::RichField; +use crate::field::{extension_field::Extendable, field_types::Field}; +use crate::gadgets::arithmetic_u32::U32Target; +use crate::gadgets::biguint::BigUintTarget; +use crate::iop::generator::{GeneratedValues, SimpleGenerator}; +use crate::iop::target::{BoolTarget, Target}; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::util::ceil_div_usize; + +#[derive(Clone, Debug)] +pub struct NonNativeTarget { + pub(crate) value: BigUintTarget, + _phantom: PhantomData, +} + +impl, const D: usize> CircuitBuilder { + fn num_nonnative_limbs() -> usize { + ceil_div_usize(FF::BITS, 32) + } + + pub fn biguint_to_nonnative(&mut self, x: &BigUintTarget) -> NonNativeTarget { + NonNativeTarget { + value: x.clone(), + _phantom: PhantomData, + } + } + + pub fn nonnative_to_biguint(&mut self, x: &NonNativeTarget) -> BigUintTarget { + x.value.clone() + } + + pub fn constant_nonnative(&mut self, x: FF) -> NonNativeTarget { + let x_biguint = self.constant_biguint(&x.to_biguint()); + self.biguint_to_nonnative(&x_biguint) + } + + // Assert that two NonNativeTarget's, both assumed to be in reduced form, are equal. + pub fn connect_nonnative( + &mut self, + lhs: &NonNativeTarget, + rhs: &NonNativeTarget, + ) { + self.connect_biguint(&lhs.value, &rhs.value); + } + + pub fn add_virtual_nonnative_target(&mut self) -> NonNativeTarget { + let num_limbs = Self::num_nonnative_limbs::(); + let value = self.add_virtual_biguint_target(num_limbs); + + NonNativeTarget { + value, + _phantom: PhantomData, + } + } + + // Add two `NonNativeTarget`s. 
+ pub fn add_nonnative( + &mut self, + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget { + let result = self.add_biguint(&a.value, &b.value); + + // TODO: reduce add result with only one conditional subtraction + self.reduce(&result) + } + + // Subtract two `NonNativeTarget`s. + pub fn sub_nonnative( + &mut self, + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget { + let order = self.constant_biguint(&FF::order()); + let a_plus_order = self.add_biguint(&order, &a.value); + let result = self.sub_biguint(&a_plus_order, &b.value); + + // TODO: reduce sub result with only one conditional addition? + self.reduce(&result) + } + + pub fn mul_nonnative( + &mut self, + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget { + let result = self.mul_biguint(&a.value, &b.value); + + self.reduce(&result) + } + + pub fn neg_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { + let zero_target = self.constant_biguint(&BigUint::zero()); + let zero_ff = self.biguint_to_nonnative(&zero_target); + + self.sub_nonnative(&zero_ff, x) + } + + pub fn inv_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { + let num_limbs = x.value.num_limbs(); + let inv_biguint = self.add_virtual_biguint_target(num_limbs); + let inv = NonNativeTarget:: { + value: inv_biguint, + _phantom: PhantomData, + }; + + self.add_simple_generator(NonNativeInverseGenerator:: { + x: x.clone(), + inv: inv.clone(), + _phantom: PhantomData, + }); + + let product = self.mul_nonnative(x, &inv); + let one = self.constant_nonnative(FF::ONE); + self.connect_nonnative(&product, &one); + + inv + } + + pub fn div_rem_nonnative( + &mut self, + x: &NonNativeTarget, + y: &NonNativeTarget, + ) -> (NonNativeTarget, NonNativeTarget) { + let x_biguint = self.nonnative_to_biguint(x); + let y_biguint = self.nonnative_to_biguint(y); + + let (div_biguint, rem_biguint) = self.div_rem_biguint(&x_biguint, &y_biguint); + let div = self.biguint_to_nonnative(&div_biguint); + 
let rem = self.biguint_to_nonnative(&rem_biguint); + (div, rem) + } + + /// Returns `x % |FF|` as a `NonNativeTarget`. + fn reduce(&mut self, x: &BigUintTarget) -> NonNativeTarget { + let modulus = FF::order(); + let order_target = self.constant_biguint(&modulus); + let value = self.rem_biguint(x, &order_target); + + NonNativeTarget { + value, + _phantom: PhantomData, + } + } + + #[allow(dead_code)] + fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { + let x_biguint = self.nonnative_to_biguint(x); + self.reduce(&x_biguint) + } + + pub fn bool_to_nonnative(&mut self, b: &BoolTarget) -> NonNativeTarget { + let limbs = vec![U32Target(b.target)]; + let value = BigUintTarget { limbs }; + + NonNativeTarget { + value, + _phantom: PhantomData, + } + } + + // Split a nonnative field element to bits. + pub fn split_nonnative_to_bits( + &mut self, + x: &NonNativeTarget, + ) -> Vec { + let num_limbs = x.value.num_limbs(); + let mut result = Vec::with_capacity(num_limbs * 32); + + for i in 0..num_limbs { + let limb = x.value.get_limb(i); + let bit_targets = self.split_le_base::<2>(limb.0, 32); + let mut bits: Vec<_> = bit_targets + .iter() + .map(|&t| BoolTarget::new_unsafe(t)) + .collect(); + + result.append(&mut bits); + } + + result + } +} + +#[derive(Debug)] +struct NonNativeInverseGenerator, const D: usize, FF: Field> { + x: NonNativeTarget, + inv: NonNativeTarget, + _phantom: PhantomData, +} + +impl, const D: usize, FF: Field> SimpleGenerator + for NonNativeInverseGenerator +{ + fn dependencies(&self) -> Vec { + self.x.value.limbs.iter().map(|&l| l.0).collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let x = witness.get_nonnative_target(self.x.clone()); + let inv = x.inverse(); + + out_buffer.set_nonnative_target(self.inv.clone(), inv); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::field::field_types::Field; + use crate::field::goldilocks_field::GoldilocksField; + use 
crate::field::secp256k1_base::Secp256K1Base; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::verifier::verify; + + #[test] + fn test_nonnative_add() -> Result<()> { + type FF = Secp256K1Base; + let x_ff = FF::rand(); + let y_ff = FF::rand(); + let sum_ff = x_ff + y_ff; + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let y = builder.constant_nonnative(y_ff); + let sum = builder.add_nonnative(&x, &y); + + let sum_expected = builder.constant_nonnative(sum_ff); + builder.connect_nonnative(&sum, &sum_expected); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_nonnative_sub() -> Result<()> { + type FF = Secp256K1Base; + let x_ff = FF::rand(); + let mut y_ff = FF::rand(); + while y_ff.to_biguint() > x_ff.to_biguint() { + y_ff = FF::rand(); + } + let diff_ff = x_ff - y_ff; + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let y = builder.constant_nonnative(y_ff); + let diff = builder.sub_nonnative(&x, &y); + + let diff_expected = builder.constant_nonnative(diff_ff); + builder.connect_nonnative(&diff, &diff_expected); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_nonnative_mul() -> Result<()> { + type FF = Secp256K1Base; + let x_ff = FF::rand(); + let y_ff = FF::rand(); + let product_ff = x_ff * y_ff; + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); 
+ let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let y = builder.constant_nonnative(y_ff); + let product = builder.mul_nonnative(&x, &y); + + let product_expected = builder.constant_nonnative(product_ff); + builder.connect_nonnative(&product, &product_expected); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_nonnative_neg() -> Result<()> { + type FF = Secp256K1Base; + let x_ff = FF::rand(); + let neg_x_ff = -x_ff; + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let neg_x = builder.neg_nonnative(&x); + + let neg_x_expected = builder.constant_nonnative(neg_x_ff); + builder.connect_nonnative(&neg_x, &neg_x_expected); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_nonnative_inv() -> Result<()> { + type FF = Secp256K1Base; + let x_ff = FF::rand(); + let inv_x_ff = x_ff.inverse(); + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let inv_x = builder.inv_nonnative(&x); + + let inv_x_expected = builder.constant_nonnative(inv_x_ff); + builder.connect_nonnative(&inv_x, &inv_x_expected); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } +} diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs index 644e17a7..9ded28ab 100644 --- a/src/gadgets/permutation.rs +++ b/src/gadgets/permutation.rs @@ -2,7 +2,6 @@ use std::collections::BTreeMap; use std::marker::PhantomData; use 
crate::field::{extension_field::Extendable, field_types::Field}; -use crate::gates::switch::SwitchGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::Target; use crate::iop::witness::{PartitionWitness, Witness}; @@ -34,7 +33,6 @@ impl, const D: usize> CircuitBuilder { self.assert_permutation_2x2(a[0].clone(), a[1].clone(), b[0].clone(), b[1].clone()) } // For larger lists, we recursively use two smaller permutation networks. - //_ => self.assert_permutation_recursive(a, b) _ => self.assert_permutation_recursive(a, b), } } @@ -72,22 +70,7 @@ impl, const D: usize> CircuitBuilder { let chunk_size = a1.len(); - if self.current_switch_gates.len() < chunk_size { - self.current_switch_gates - .extend(vec![None; chunk_size - self.current_switch_gates.len()]); - } - - let (gate, gate_index, mut next_copy) = - match self.current_switch_gates[chunk_size - 1].clone() { - None => { - let gate = SwitchGate::::new_from_config(&self.config, chunk_size); - let gate_index = self.add_gate(gate.clone(), vec![]); - (gate, gate_index, 0) - } - Some((gate, idx, next_copy)) => (gate, idx, next_copy), - }; - - let num_copies = gate.num_copies; + let (gate, gate_index, next_copy) = self.find_switch_gate(chunk_size); let mut c = Vec::new(); let mut d = Vec::new(); @@ -112,13 +95,6 @@ impl, const D: usize> CircuitBuilder { let switch = Target::wire(gate_index, gate.wire_switch_bool(next_copy)); - next_copy += 1; - if next_copy == num_copies { - self.current_switch_gates[chunk_size - 1] = None; - } else { - self.current_switch_gates[chunk_size - 1] = Some((gate, gate_index, next_copy)); - } - (switch, c, d) } @@ -402,7 +378,7 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let lst: Vec = (0..size * 2).map(|n| F::from_canonical_usize(n)).collect(); + let lst: Vec = (0..size * 2).map(F::from_canonical_usize).collect(); let a: Vec> = lst[..] 
.chunks(2) .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) diff --git a/src/gadgets/polynomial.rs b/src/gadgets/polynomial.rs index 3d371c53..9e0229fe 100644 --- a/src/gadgets/polynomial.rs +++ b/src/gadgets/polynomial.rs @@ -63,4 +63,21 @@ impl PolynomialCoeffsExtAlgebraTarget { } acc } + + /// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1. + pub fn eval_with_powers( + &self, + builder: &mut CircuitBuilder, + powers: &[ExtensionAlgebraTarget], + ) -> ExtensionAlgebraTarget + where + F: RichField + Extendable, + { + debug_assert_eq!(self.0.len(), powers.len() + 1); + let acc = self.0[0]; + self.0[1..] + .iter() + .zip(powers) + .fold(acc, |acc, (&x, &c)| builder.mul_add_ext_algebra(c, x, acc)) + } } diff --git a/src/gadgets/random_access.rs b/src/gadgets/random_access.rs index 50f3ac76..c9f150f8 100644 --- a/src/gadgets/random_access.rs +++ b/src/gadgets/random_access.rs @@ -3,49 +3,20 @@ use crate::field::extension_field::Extendable; use crate::gates::random_access::RandomAccessGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::util::log2_strict; impl, const D: usize> CircuitBuilder { - /// Finds the last available random access gate with the given `vec_size` or add one if there aren't any. - /// Returns `(g,i)` such that there is a random access gate with the given `vec_size` at index - /// `g` and the gate's `i`-th random access is available. - fn find_random_access_gate(&mut self, vec_size: usize) -> (usize, usize) { - let (gate, i) = self - .free_random_access - .get(&vec_size) - .copied() - .unwrap_or_else(|| { - let gate = self.add_gate( - RandomAccessGate::new_from_config(&self.config, vec_size), - vec![], - ); - (gate, 0) - }); - - // Update `free_random_access` with new values. 
- if i < RandomAccessGate::::max_num_copies( - self.config.num_routed_wires, - self.config.num_wires, - vec_size, - ) - 1 - { - self.free_random_access.insert(vec_size, (gate, i + 1)); - } else { - self.free_random_access.remove(&vec_size); - } - - (gate, i) - } - /// Checks that a `Target` matches a vector at a non-deterministic index. /// Note: `access_index` is not range-checked. pub fn random_access(&mut self, access_index: Target, claimed_element: Target, v: Vec) { let vec_size = v.len(); + let bits = log2_strict(vec_size); debug_assert!(vec_size > 0); if vec_size == 1 { return self.connect(claimed_element, v[0]); } - let (gate_index, copy) = self.find_random_access_gate(vec_size); - let dummy_gate = RandomAccessGate::::new_from_config(&self.config, vec_size); + let (gate_index, copy) = self.find_random_access_gate(bits); + let dummy_gate = RandomAccessGate::::new_from_config(&self.config, bits); v.iter().enumerate().for_each(|(i, &val)| { self.connect( diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 95560f92..f4754d75 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -3,6 +3,8 @@ use std::marker::PhantomData; use itertools::izip; use crate::field::extension_field::Extendable; +use crate::field::field_types::{Field, RichField}; +use crate::gates::assert_le::AssertLessThanGate; use crate::field::field_types::Field; use crate::gates::comparison::ComparisonGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; @@ -40,9 +42,9 @@ impl, const D: usize> CircuitBuilder { self.assert_permutation(a_chunks, b_chunks); } - /// Add a ComparisonGate to assert that `lhs` is less than `rhs`, where their values are at most `bits` bits. + /// Add an AssertLessThanGate to assert that `lhs` is less than `rhs`, where their values are at most `bits` bits. 
pub fn assert_le(&mut self, lhs: Target, rhs: Target, bits: usize, num_chunks: usize) { - let gate = ComparisonGate::new(bits, num_chunks); + let gate = AssertLessThanGate::new(bits, num_chunks); let gate_index = self.add_gate(gate.clone(), vec![]); self.connect(Target::wire(gate_index, gate.wire_first_input()), lhs); @@ -126,8 +128,7 @@ impl, const D: usize> SimpleGenerator for MemoryOpSortGenera fn dependencies(&self) -> Vec { self.input_ops .iter() - .map(|op| vec![op.is_write.target, op.address, op.timestamp, op.value]) - .flatten() + .flat_map(|op| vec![op.is_write.target, op.address, op.timestamp, op.value]) .collect() } @@ -223,7 +224,7 @@ mod tests { izip!(is_write_vals, address_vals, timestamp_vals, value_vals) .zip(combined_vals_u64) .collect::>(); - input_ops_and_keys.sort_by_key(|(_, val)| val.clone()); + input_ops_and_keys.sort_by_key(|(_, val)| *val); let input_ops_sorted: Vec<_> = input_ops_and_keys.iter().map(|(x, _)| x).collect(); let output_ops = diff --git a/src/gadgets/split_base.rs b/src/gadgets/split_base.rs index e35a0d20..1105710c 100644 --- a/src/gadgets/split_base.rs +++ b/src/gadgets/split_base.rs @@ -1,5 +1,7 @@ use std::borrow::Borrow; +use itertools::Itertools; + use crate::field::extension_field::Extendable; use crate::field::field_types::Field; use crate::gates::base_sum::BaseSumGate; @@ -27,23 +29,25 @@ impl, const D: usize> CircuitBuilder { /// Takes an iterator of bits `(b_i)` and returns `sum b_i * 2^i`, i.e., /// the number with little-endian bit representation given by `bits`. 
- pub(crate) fn le_sum( - &mut self, - bits: impl ExactSizeIterator> + Clone, - ) -> Target { + pub(crate) fn le_sum(&mut self, bits: impl Iterator>) -> Target { + let bits = bits.map(|b| *b.borrow()).collect_vec(); let num_bits = bits.len(); if num_bits == 0 { return self.zero(); - } else if num_bits == 1 { - let mut bits = bits; - return bits.next().unwrap().borrow().target; - } else if num_bits == 2 { - let two = self.two(); - let mut bits = bits; - let b0 = bits.next().unwrap().borrow().target; - let b1 = bits.next().unwrap().borrow().target; - return self.mul_add(two, b1, b0); } + + // Check if it's cheaper to just do this with arithmetic operations. + let arithmetic_ops = num_bits - 1; + if arithmetic_ops <= self.num_base_arithmetic_ops_per_gate() { + let two = self.two(); + let mut rev_bits = bits.iter().rev(); + let mut sum = rev_bits.next().unwrap().target; + for &bit in rev_bits { + sum = self.mul_add(two, sum, bit.target); + } + return sum; + } + debug_assert!( BaseSumGate::<2>::START_LIMBS + num_bits <= self.config.num_routed_wires, "Not enough routed wires." 
@@ -51,10 +55,10 @@ impl, const D: usize> CircuitBuilder { let gate_type = BaseSumGate::<2>::new_from_config::(&self.config); let gate_index = self.add_gate(gate_type, vec![]); for (limb, wire) in bits - .clone() + .iter() .zip(BaseSumGate::<2>::START_LIMBS..BaseSumGate::<2>::START_LIMBS + num_bits) { - self.connect(limb.borrow().target, Target::wire(gate_index, wire)); + self.connect(limb.target, Target::wire(gate_index, wire)); } for l in gate_type.limbs().skip(num_bits) { self.assert_zero(Target::wire(gate_index, l)); @@ -62,7 +66,7 @@ impl, const D: usize> CircuitBuilder { self.add_simple_generator(BaseSumGenerator::<2> { gate_index, - limbs: bits.map(|l| *l.borrow()).collect(), + limbs: bits, }); Target::wire(gate_index, BaseSumGate::<2>::WIRE_SUM) @@ -146,14 +150,14 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let n = thread_rng().gen_range(0..(1 << 10)); + let n = thread_rng().gen_range(0..(1 << 30)); let x = builder.constant(F::from_canonical_usize(n)); let zero = builder._false(); let one = builder._true(); let y = builder.le_sum( - (0..10) + (0..30) .scan(n, |acc, _| { let tmp = *acc % 2; *acc /= 2; diff --git a/src/gadgets/split_join.rs b/src/gadgets/split_join.rs index 39527c6a..72786bd8 100644 --- a/src/gadgets/split_join.rs +++ b/src/gadgets/split_join.rs @@ -24,8 +24,7 @@ impl, const D: usize> CircuitBuilder { let mut bits = Vec::with_capacity(num_bits); for &gate in &gates { - let start_limbs = BaseSumGate::<2>::START_LIMBS; - for limb_input in start_limbs..start_limbs + gate_type.num_limbs { + for limb_input in gate_type.limbs() { // `new_unsafe` is safe here because BaseSumGate::<2> forces it to be in `{0, 1}`. 
bits.push(BoolTarget::new_unsafe(Target::wire(gate, limb_input))); } @@ -35,10 +34,11 @@ impl, const D: usize> CircuitBuilder { } let zero = self.zero(); + let base = F::TWO.exp_u64(gate_type.num_limbs as u64); let mut acc = zero; for &gate in gates.iter().rev() { let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM); - acc = self.mul_const_add(F::from_canonical_usize(1 << gate_type.num_limbs), acc, sum); + acc = self.mul_const_add(base, acc, sum); } self.connect(acc, integer); @@ -96,11 +96,18 @@ impl SimpleGenerator for WireSplitGenerator { for &gate in &self.gates { let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM); - out_buffer.set_target( - sum, - F::from_canonical_u64(integer_value & ((1 << self.num_limbs) - 1)), - ); - integer_value >>= self.num_limbs; + + // If num_limbs >= 64, we don't need to truncate since `integer_value` is already + // limited to 64 bits, and trying to do so would cause overflow. Hence the conditional. + let mut truncated_value = integer_value; + if self.num_limbs < 64 { + truncated_value = integer_value & ((1 << self.num_limbs) - 1); + integer_value >>= self.num_limbs; + } else { + integer_value = 0; + }; + + out_buffer.set_target(sum, F::from_canonical_u64(truncated_value)); } debug_assert_eq!( diff --git a/src/gates/arithmetic_base.rs b/src/gates/arithmetic_base.rs new file mode 100644 index 00000000..d5c131a5 --- /dev/null +++ b/src/gates/arithmetic_base.rs @@ -0,0 +1,212 @@ +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::field_types::RichField; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::CircuitConfig; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// A gate which 
can perform a weighted multiply-add, i.e. `result = c0 x y + c1 z`. If the config +/// supports enough routed wires, it can support several such operations in one gate. +#[derive(Debug)] +pub struct ArithmeticGate { + /// Number of arithmetic operations performed by an arithmetic gate. + pub num_ops: usize, +} + +impl ArithmeticGate { + pub fn new_from_config(config: &CircuitConfig) -> Self { + Self { + num_ops: Self::num_ops(config), + } + } + + /// Determine the maximum number of operations that can fit in one gate for the given config. + pub(crate) fn num_ops(config: &CircuitConfig) -> usize { + let wires_per_op = 4; + config.num_routed_wires / wires_per_op + } + + pub fn wire_ith_multiplicand_0(i: usize) -> usize { + 4 * i + } + pub fn wire_ith_multiplicand_1(i: usize) -> usize { + 4 * i + 1 + } + pub fn wire_ith_addend(i: usize) -> usize { + 4 * i + 2 + } + pub fn wire_ith_output(i: usize) -> usize { + 4 * i + 3 + } +} + +impl, const D: usize> Gate for ArithmeticGate { + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let const_0 = vars.local_constants[0]; + let const_1 = vars.local_constants[1]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[Self::wire_ith_addend(i)]; + let output = vars.local_wires[Self::wire_ith_output(i)]; + let computed_output = multiplicand_0 * multiplicand_1 * const_0 + addend * const_1; + + constraints.push(output - computed_output); + } + + constraints + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let const_0 = vars.local_constants[0]; + let const_1 = vars.local_constants[1]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; + let multiplicand_1 = 
vars.local_wires[Self::wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[Self::wire_ith_addend(i)]; + let output = vars.local_wires[Self::wire_ith_output(i)]; + let computed_output = multiplicand_0 * multiplicand_1 * const_0 + addend * const_1; + + constraints.push(output - computed_output); + } + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let const_0 = vars.local_constants[0]; + let const_1 = vars.local_constants[1]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[Self::wire_ith_addend(i)]; + let output = vars.local_wires[Self::wire_ith_output(i)]; + let computed_output = { + let scaled_mul = + builder.mul_many_extension(&[const_0, multiplicand_0, multiplicand_1]); + builder.mul_add_extension(const_1, addend, scaled_mul) + }; + + let diff = builder.sub_extension(output, computed_output); + constraints.push(diff); + } + + constraints + } + + fn generators( + &self, + gate_index: usize, + local_constants: &[F], + ) -> Vec>> { + (0..self.num_ops) + .map(|i| { + let g: Box> = Box::new( + ArithmeticBaseGenerator { + gate_index, + const_0: local_constants[0], + const_1: local_constants[1], + i, + } + .adapter(), + ); + g + }) + .collect::>() + } + + fn num_wires(&self) -> usize { + self.num_ops * 4 + } + + fn num_constants(&self) -> usize { + 2 + } + + fn degree(&self) -> usize { + 3 + } + + fn num_constraints(&self) -> usize { + self.num_ops + } +} + +#[derive(Clone, Debug)] +struct ArithmeticBaseGenerator, const D: usize> { + gate_index: usize, + const_0: F, + const_1: F, + i: usize, +} + +impl, const D: usize> SimpleGenerator + for ArithmeticBaseGenerator +{ + fn dependencies(&self) -> Vec { + [ + ArithmeticGate::wire_ith_multiplicand_0(self.i), + 
ArithmeticGate::wire_ith_multiplicand_1(self.i), + ArithmeticGate::wire_ith_addend(self.i), + ] + .iter() + .map(|&i| Target::wire(self.gate_index, i)) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let get_wire = + |wire: usize| -> F { witness.get_target(Target::wire(self.gate_index, wire)) }; + + let multiplicand_0 = get_wire(ArithmeticGate::wire_ith_multiplicand_0(self.i)); + let multiplicand_1 = get_wire(ArithmeticGate::wire_ith_multiplicand_1(self.i)); + let addend = get_wire(ArithmeticGate::wire_ith_addend(self.i)); + + let output_target = Target::wire(self.gate_index, ArithmeticGate::wire_ith_output(self.i)); + + let computed_output = + multiplicand_0 * multiplicand_1 * self.const_0 + addend * self.const_1; + + out_buffer.set_target(output_target, computed_output) + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::arithmetic_base::ArithmeticGate; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::plonk::circuit_data::CircuitConfig; + + #[test] + fn low_degree() { + let gate = ArithmeticGate::new_from_config(&CircuitConfig::standard_recursion_config()); + test_low_degree::(gate); + } + + #[test] + fn eval_fns() -> Result<()> { + let gate = ArithmeticGate::new_from_config(&CircuitConfig::standard_recursion_config()); + test_eval_fns::(gate) + } +} diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic_extension.rs similarity index 96% rename from src/gates/arithmetic.rs rename to src/gates/arithmetic_extension.rs index 2898c596..62a89af4 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic_extension.rs @@ -11,7 +11,8 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -/// A gate which can a linear combination `c0*x*y+c1*z` twice with the same 
`x`. +/// A gate which can perform a weighted multiply-add, i.e. `result = c0 x y + c1 z`. If the config +/// supports enough routed wires, it can support several such operations in one gate. #[derive(Debug)] pub struct ArithmeticExtensionGate { /// Number of arithmetic operations performed by an arithmetic gate. @@ -203,7 +204,7 @@ mod tests { use anyhow::Result; use crate::field::goldilocks_field::GoldilocksField; - use crate::gates::arithmetic::ArithmeticExtensionGate; + use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; diff --git a/src/gates/arithmetic_u32.rs b/src/gates/arithmetic_u32.rs index 220c2dd0..472e1592 100644 --- a/src/gates/arithmetic_u32.rs +++ b/src/gates/arithmetic_u32.rs @@ -11,37 +11,49 @@ use crate::iop::target::Target; use crate::iop::wire::Wire; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -/// Number of arithmetic operations performed by an arithmetic gate. -pub const NUM_U32_ARITHMETIC_OPS: usize = 3; - /// A gate to perform a basic mul-add on 32-bit values (we assume they are range-checked beforehand). 
-#[derive(Debug)] +#[derive(Copy, Clone, Debug)] pub struct U32ArithmeticGate, const D: usize> { + pub num_ops: usize, _phantom: PhantomData, } impl, const D: usize> U32ArithmeticGate { - pub fn wire_ith_multiplicand_0(i: usize) -> usize { - debug_assert!(i < NUM_U32_ARITHMETIC_OPS); + pub fn new_from_config(config: &CircuitConfig) -> Self { + Self { + num_ops: Self::num_ops(config), + _phantom: PhantomData, + } + } + + pub(crate) fn num_ops(config: &CircuitConfig) -> usize { + let wires_per_op = 5 + Self::num_limbs(); + let routed_wires_per_op = 5; + (config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op) + } + + pub fn wire_ith_multiplicand_0(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i } - pub fn wire_ith_multiplicand_1(i: usize) -> usize { - debug_assert!(i < NUM_U32_ARITHMETIC_OPS); + pub fn wire_ith_multiplicand_1(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 1 } - pub fn wire_ith_addend(i: usize) -> usize { - debug_assert!(i < NUM_U32_ARITHMETIC_OPS); + pub fn wire_ith_addend(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 2 } - pub fn wire_ith_output_low_half(i: usize) -> usize { - debug_assert!(i < NUM_U32_ARITHMETIC_OPS); + pub fn wire_ith_output_low_half(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 3 } - pub fn wire_ith_output_high_half(i: usize) -> usize { - debug_assert!(i < NUM_U32_ARITHMETIC_OPS); + pub fn wire_ith_output_high_half(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 4 } @@ -52,10 +64,10 @@ impl, const D: usize> U32ArithmeticGate { 64 / Self::limb_bits() } - pub fn wire_ith_output_jth_limb(i: usize, j: usize) -> usize { - debug_assert!(i < NUM_U32_ARITHMETIC_OPS); + pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize { + debug_assert!(i < self.num_ops); debug_assert!(j < Self::num_limbs()); - 5 * NUM_U32_ARITHMETIC_OPS + Self::num_limbs() * i + j + 5 * self.num_ops + 
Self::num_limbs() * i + j } } @@ -66,15 +78,15 @@ impl, const D: usize> Gate for U32ArithmeticGate { fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..NUM_U32_ARITHMETIC_OPS { - let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; - let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; - let addend = vars.local_wires[Self::wire_ith_addend(i)]; + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[self.wire_ith_addend(i)]; let computed_output = multiplicand_0 * multiplicand_1 + addend; - let output_low = vars.local_wires[Self::wire_ith_output_low_half(i)]; - let output_high = vars.local_wires[Self::wire_ith_output_high_half(i)]; + let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; + let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; let base = F::Extension::from_canonical_u64(1 << 32u64); let combined_output = output_high * base + output_low; @@ -86,7 +98,7 @@ impl, const D: usize> Gate for U32ArithmeticGate { let midpoint = Self::num_limbs() / 2; let base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits()); for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); let product = (0..max_limb) .map(|x| this_limb - F::Extension::from_canonical_usize(x)) @@ -108,15 +120,15 @@ impl, const D: usize> Gate for U32ArithmeticGate { fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..NUM_U32_ARITHMETIC_OPS { - let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; - let 
multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; - let addend = vars.local_wires[Self::wire_ith_addend(i)]; + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[self.wire_ith_addend(i)]; let computed_output = multiplicand_0 * multiplicand_1 + addend; - let output_low = vars.local_wires[Self::wire_ith_output_low_half(i)]; - let output_high = vars.local_wires[Self::wire_ith_output_high_half(i)]; + let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; + let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; let base = F::from_canonical_u64(1 << 32u64); let combined_output = output_high * base + output_low; @@ -128,7 +140,7 @@ impl, const D: usize> Gate for U32ArithmeticGate { let midpoint = Self::num_limbs() / 2; let base = F::from_canonical_u64(1u64 << Self::limb_bits()); for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); let product = (0..max_limb) .map(|x| this_limb - F::from_canonical_usize(x)) @@ -155,15 +167,15 @@ impl, const D: usize> Gate for U32ArithmeticGate { ) -> Vec> { let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..NUM_U32_ARITHMETIC_OPS { - let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; - let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; - let addend = vars.local_wires[Self::wire_ith_addend(i)]; + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[self.wire_ith_addend(i)]; let computed_output = builder.mul_add_extension(multiplicand_0, 
multiplicand_1, addend); - let output_low = vars.local_wires[Self::wire_ith_output_low_half(i)]; - let output_high = vars.local_wires[Self::wire_ith_output_high_half(i)]; + let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; + let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; let base: F::Extension = F::from_canonical_u64(1 << 32u64).into(); let base_target = builder.constant_extension(base); @@ -177,7 +189,7 @@ impl, const D: usize> Gate for U32ArithmeticGate { let base = builder .constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); let mut product = builder.one_extension(); @@ -210,10 +222,11 @@ impl, const D: usize> Gate for U32ArithmeticGate { gate_index: usize, _local_constants: &[F], ) -> Vec>> { - (0..NUM_U32_ARITHMETIC_OPS) + (0..self.num_ops) .map(|i| { let g: Box> = Box::new( U32ArithmeticGenerator { + gate: *self, gate_index, i, _phantom: PhantomData, @@ -226,7 +239,7 @@ impl, const D: usize> Gate for U32ArithmeticGate { } fn num_wires(&self) -> usize { - NUM_U32_ARITHMETIC_OPS * (5 + Self::num_limbs()) + self.num_ops * (5 + Self::num_limbs()) } fn num_constants(&self) -> usize { @@ -238,12 +251,13 @@ impl, const D: usize> Gate for U32ArithmeticGate { } fn num_constraints(&self) -> usize { - NUM_U32_ARITHMETIC_OPS * (3 + Self::num_limbs()) + self.num_ops * (3 + Self::num_limbs()) } } #[derive(Clone, Debug)] struct U32ArithmeticGenerator, const D: usize> { + gate: U32ArithmeticGate, gate_index: usize, i: usize, _phantom: PhantomData, @@ -253,17 +267,11 @@ impl, const D: usize> SimpleGenerator for U32ArithmeticGener fn dependencies(&self) -> Vec { let local_target = |input| Target::wire(self.gate_index, input); - let mut deps = Vec::with_capacity(3); - 
deps.push(local_target( - U32ArithmeticGate::::wire_ith_multiplicand_0(self.i), - )); - deps.push(local_target( - U32ArithmeticGate::::wire_ith_multiplicand_1(self.i), - )); - deps.push(local_target(U32ArithmeticGate::::wire_ith_addend( - self.i, - ))); - deps + vec![ + local_target(self.gate.wire_ith_multiplicand_0(self.i)), + local_target(self.gate.wire_ith_multiplicand_1(self.i)), + local_target(self.gate.wire_ith_addend(self.i)), + ] } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { @@ -274,11 +282,9 @@ impl, const D: usize> SimpleGenerator for U32ArithmeticGener let get_local_wire = |input| witness.get_wire(local_wire(input)); - let multiplicand_0 = - get_local_wire(U32ArithmeticGate::::wire_ith_multiplicand_0(self.i)); - let multiplicand_1 = - get_local_wire(U32ArithmeticGate::::wire_ith_multiplicand_1(self.i)); - let addend = get_local_wire(U32ArithmeticGate::::wire_ith_addend(self.i)); + let multiplicand_0 = get_local_wire(self.gate.wire_ith_multiplicand_0(self.i)); + let multiplicand_1 = get_local_wire(self.gate.wire_ith_multiplicand_1(self.i)); + let addend = get_local_wire(self.gate.wire_ith_addend(self.i)); let output = multiplicand_0 * multiplicand_1 + addend; let mut output_u64 = output.to_canonical_u64(); @@ -289,34 +295,25 @@ impl, const D: usize> SimpleGenerator for U32ArithmeticGener let output_high = F::from_canonical_u64(output_high_u64); let output_low = F::from_canonical_u64(output_low_u64); - let output_high_wire = - local_wire(U32ArithmeticGate::::wire_ith_output_high_half(self.i)); - let output_low_wire = - local_wire(U32ArithmeticGate::::wire_ith_output_low_half(self.i)); + let output_high_wire = local_wire(self.gate.wire_ith_output_high_half(self.i)); + let output_low_wire = local_wire(self.gate.wire_ith_output_low_half(self.i)); out_buffer.set_wire(output_high_wire, output_high); out_buffer.set_wire(output_low_wire, output_low); let num_limbs = U32ArithmeticGate::::num_limbs(); let limb_base = 1 << 
U32ArithmeticGate::::limb_bits(); - let output_limbs_u64: Vec<_> = unfold((), move |_| { + let output_limbs_u64 = unfold((), move |_| { let ret = output_u64 % limb_base; output_u64 /= limb_base; Some(ret) }) - .take(num_limbs) - .collect(); - let output_limbs_f: Vec<_> = output_limbs_u64 - .iter() - .cloned() - .map(F::from_canonical_u64) - .collect(); + .take(num_limbs); + let output_limbs_f = output_limbs_u64.map(F::from_canonical_u64); - for j in 0..num_limbs { - let wire = local_wire(U32ArithmeticGate::::wire_ith_output_jth_limb( - self.i, j, - )); - out_buffer.set_wire(wire, output_limbs_f[j]); + for (j, output_limb) in output_limbs_f.enumerate() { + let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j)); + out_buffer.set_wire(wire, output_limb); } } } @@ -330,7 +327,7 @@ mod tests { use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; - use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; + use crate::gates::arithmetic_u32::U32ArithmeticGate; use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::hash::hash_types::HashOut; @@ -340,16 +337,15 @@ mod tests { #[test] fn low_degree() { test_low_degree::(U32ArithmeticGate:: { + num_ops: 3, _phantom: PhantomData, }) } #[test] fn eval_fns() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - test_eval_fns::(U32ArithmeticGate:: { + test_eval_fns::(U32ArithmeticGate:: { + num_ops: 3, _phantom: PhantomData, }) } @@ -360,6 +356,7 @@ mod tests { type C = PoseidonGoldilocksConfig; type F = >::F; type FF = >::FE; + const NUM_U32_ARITHMETIC_OPS: usize = 3; fn get_wires( multiplicands_0: Vec, @@ -387,8 +384,7 @@ mod tests { output /= limb_base; } let mut output_limbs_f: Vec<_> = output_limbs - .iter() - .cloned() + .into_iter() .map(F::from_canonical_u64) .collect(); @@ -418,6 +414,7 @@ mod tests { .collect(); let gate = U32ArithmeticGate:: { + num_ops: 
NUM_U32_ARITHMETIC_OPS, _phantom: PhantomData, }; diff --git a/src/gates/assert_le.rs b/src/gates/assert_le.rs new file mode 100644 index 00000000..4da3c44b --- /dev/null +++ b/src/gates/assert_le.rs @@ -0,0 +1,607 @@ +use std::marker::PhantomData; + +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::field_types::{Field, PrimeField, RichField}; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::wire::Wire; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recursive}; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::util::{bits_u64, ceil_div_usize}; + +// TODO: replace/merge this gate with `ComparisonGate`. + +/// A gate for checking that one value is less than or equal to another. 
+#[derive(Clone, Debug)] +pub struct AssertLessThanGate, const D: usize> { + pub(crate) num_bits: usize, + pub(crate) num_chunks: usize, + _phantom: PhantomData, +} + +impl, const D: usize> AssertLessThanGate { + pub fn new(num_bits: usize, num_chunks: usize) -> Self { + debug_assert!(num_bits < bits_u64(F::ORDER)); + Self { + num_bits, + num_chunks, + _phantom: PhantomData, + } + } + + pub fn chunk_bits(&self) -> usize { + ceil_div_usize(self.num_bits, self.num_chunks) + } + + pub fn wire_first_input(&self) -> usize { + 0 + } + + pub fn wire_second_input(&self) -> usize { + 1 + } + + pub fn wire_most_significant_diff(&self) -> usize { + 2 + } + + pub fn wire_first_chunk_val(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 3 + chunk + } + + pub fn wire_second_chunk_val(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 3 + self.num_chunks + chunk + } + + pub fn wire_equality_dummy(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 3 + 2 * self.num_chunks + chunk + } + + pub fn wire_chunks_equal(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 3 + 3 * self.num_chunks + chunk + } + + pub fn wire_intermediate_value(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 3 + 4 * self.num_chunks + chunk + } +} + +impl, const D: usize> Gate for AssertLessThanGate { + fn id(&self) -> String { + format!("{:?}", self, D) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let first_input = vars.local_wires[self.wire_first_input()]; + let second_input = vars.local_wires[self.wire_second_input()]; + + // Get chunks and assert that they match + let first_chunks: Vec = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) + .collect(); + let second_chunks: Vec = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) + 
.collect(); + + let first_chunks_combined = reduce_with_powers( + &first_chunks, + F::Extension::from_canonical_usize(1 << self.chunk_bits()), + ); + let second_chunks_combined = reduce_with_powers( + &second_chunks, + F::Extension::from_canonical_usize(1 << self.chunk_bits()), + ); + + constraints.push(first_chunks_combined - first_input); + constraints.push(second_chunks_combined - second_input); + + let chunk_size = 1 << self.chunk_bits(); + + let mut most_significant_diff_so_far = F::Extension::ZERO; + + for i in 0..self.num_chunks { + // Range-check the chunks to be less than `chunk_size`. + let first_product = (0..chunk_size) + .map(|x| first_chunks[i] - F::Extension::from_canonical_usize(x)) + .product(); + let second_product = (0..chunk_size) + .map(|x| second_chunks[i] - F::Extension::from_canonical_usize(x)) + .product(); + constraints.push(first_product); + constraints.push(second_product); + + let difference = second_chunks[i] - first_chunks[i]; + let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; + let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; + + // Two constraints to assert that `chunks_equal` is valid. + constraints.push(difference * equality_dummy - (F::Extension::ONE - chunks_equal)); + constraints.push(chunks_equal * difference); + + // Update `most_significant_diff_so_far`. + let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; + constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far); + most_significant_diff_so_far = + intermediate_value + (F::Extension::ONE - chunks_equal) * difference; + } + + let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; + constraints.push(most_significant_diff - most_significant_diff_so_far); + + // Range check `most_significant_diff` to be less than `chunk_size`. 
+ let product = (0..chunk_size) + .map(|x| most_significant_diff - F::Extension::from_canonical_usize(x)) + .product(); + constraints.push(product); + + constraints + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let first_input = vars.local_wires[self.wire_first_input()]; + let second_input = vars.local_wires[self.wire_second_input()]; + + // Get chunks and assert that they match + let first_chunks: Vec = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) + .collect(); + let second_chunks: Vec = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) + .collect(); + + let first_chunks_combined = reduce_with_powers( + &first_chunks, + F::from_canonical_usize(1 << self.chunk_bits()), + ); + let second_chunks_combined = reduce_with_powers( + &second_chunks, + F::from_canonical_usize(1 << self.chunk_bits()), + ); + + constraints.push(first_chunks_combined - first_input); + constraints.push(second_chunks_combined - second_input); + + let chunk_size = 1 << self.chunk_bits(); + + let mut most_significant_diff_so_far = F::ZERO; + + for i in 0..self.num_chunks { + // Range-check the chunks to be less than `chunk_size`. + let first_product = (0..chunk_size) + .map(|x| first_chunks[i] - F::from_canonical_usize(x)) + .product(); + let second_product = (0..chunk_size) + .map(|x| second_chunks[i] - F::from_canonical_usize(x)) + .product(); + constraints.push(first_product); + constraints.push(second_product); + + let difference = second_chunks[i] - first_chunks[i]; + let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; + let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; + + // Two constraints to assert that `chunks_equal` is valid. + constraints.push(difference * equality_dummy - (F::ONE - chunks_equal)); + constraints.push(chunks_equal * difference); + + // Update `most_significant_diff_so_far`. 
+ let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; + constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far); + most_significant_diff_so_far = + intermediate_value + (F::ONE - chunks_equal) * difference; + } + + let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; + constraints.push(most_significant_diff - most_significant_diff_so_far); + + // Range check `most_significant_diff` to be less than `chunk_size`. + let product = (0..chunk_size) + .map(|x| most_significant_diff - F::from_canonical_usize(x)) + .product(); + constraints.push(product); + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let first_input = vars.local_wires[self.wire_first_input()]; + let second_input = vars.local_wires[self.wire_second_input()]; + + // Get chunks and assert that they match + let first_chunks: Vec> = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) + .collect(); + let second_chunks: Vec> = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) + .collect(); + + let chunk_base = builder.constant(F::from_canonical_usize(1 << self.chunk_bits())); + let first_chunks_combined = + reduce_with_powers_ext_recursive(builder, &first_chunks, chunk_base); + let second_chunks_combined = + reduce_with_powers_ext_recursive(builder, &second_chunks, chunk_base); + + constraints.push(builder.sub_extension(first_chunks_combined, first_input)); + constraints.push(builder.sub_extension(second_chunks_combined, second_input)); + + let chunk_size = 1 << self.chunk_bits(); + + let mut most_significant_diff_so_far = builder.zero_extension(); + + let one = builder.one_extension(); + // Find the chosen chunk. + for i in 0..self.num_chunks { + // Range-check the chunks to be less than `chunk_size`. 
+ let mut first_product = one; + let mut second_product = one; + for x in 0..chunk_size { + let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x)); + let first_diff = builder.sub_extension(first_chunks[i], x_f); + let second_diff = builder.sub_extension(second_chunks[i], x_f); + first_product = builder.mul_extension(first_product, first_diff); + second_product = builder.mul_extension(second_product, second_diff); + } + constraints.push(first_product); + constraints.push(second_product); + + let difference = builder.sub_extension(second_chunks[i], first_chunks[i]); + let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; + let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; + + // Two constraints to assert that `chunks_equal` is valid. + let diff_times_equal = builder.mul_extension(difference, equality_dummy); + let not_equal = builder.sub_extension(one, chunks_equal); + constraints.push(builder.sub_extension(diff_times_equal, not_equal)); + constraints.push(builder.mul_extension(chunks_equal, difference)); + + // Update `most_significant_diff_so_far`. + let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; + let old_diff = builder.mul_extension(chunks_equal, most_significant_diff_so_far); + constraints.push(builder.sub_extension(intermediate_value, old_diff)); + + let not_equal = builder.sub_extension(one, chunks_equal); + let new_diff = builder.mul_extension(not_equal, difference); + most_significant_diff_so_far = builder.add_extension(intermediate_value, new_diff); + } + + let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; + constraints + .push(builder.sub_extension(most_significant_diff, most_significant_diff_so_far)); + + // Range check `most_significant_diff` to be less than `chunk_size`. 
+ let mut product = builder.one_extension(); + for x in 0..chunk_size { + let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x)); + let diff = builder.sub_extension(most_significant_diff, x_f); + product = builder.mul_extension(product, diff); + } + constraints.push(product); + + constraints + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + let gen = AssertLessThanGenerator:: { + gate_index, + gate: self.clone(), + }; + vec![Box::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + self.wire_intermediate_value(self.num_chunks - 1) + 1 + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 1 << self.chunk_bits() + } + + fn num_constraints(&self) -> usize { + 4 + 5 * self.num_chunks + } +} + +#[derive(Debug)] +struct AssertLessThanGenerator, const D: usize> { + gate_index: usize, + gate: AssertLessThanGate, +} + +impl, const D: usize> SimpleGenerator + for AssertLessThanGenerator +{ + fn dependencies(&self) -> Vec { + let local_target = |input| Target::wire(self.gate_index, input); + + vec![ + local_target(self.gate.wire_first_input()), + local_target(self.gate.wire_second_input()), + ] + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |input| Wire { + gate: self.gate_index, + input, + }; + + let get_local_wire = |input| witness.get_wire(local_wire(input)); + + let first_input = get_local_wire(self.gate.wire_first_input()); + let second_input = get_local_wire(self.gate.wire_second_input()); + + let first_input_u64 = first_input.to_canonical_u64(); + let second_input_u64 = second_input.to_canonical_u64(); + + debug_assert!(first_input_u64 < second_input_u64); + + let chunk_size = 1 << self.gate.chunk_bits(); + let first_input_chunks: Vec = (0..self.gate.num_chunks) + .scan(first_input_u64, |acc, _| { + let tmp = *acc % chunk_size; + *acc /= chunk_size; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); 
+ let second_input_chunks: Vec = (0..self.gate.num_chunks) + .scan(second_input_u64, |acc, _| { + let tmp = *acc % chunk_size; + *acc /= chunk_size; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + + let chunks_equal: Vec = (0..self.gate.num_chunks) + .map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i])) + .collect(); + let equality_dummies: Vec = first_input_chunks + .iter() + .zip(second_input_chunks.iter()) + .map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) }) + .collect(); + + let mut most_significant_diff_so_far = F::ZERO; + let mut intermediate_values = Vec::new(); + for i in 0..self.gate.num_chunks { + if first_input_chunks[i] != second_input_chunks[i] { + most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i]; + intermediate_values.push(F::ZERO); + } else { + intermediate_values.push(most_significant_diff_so_far); + } + } + let most_significant_diff = most_significant_diff_so_far; + + out_buffer.set_wire( + local_wire(self.gate.wire_most_significant_diff()), + most_significant_diff, + ); + for i in 0..self.gate.num_chunks { + out_buffer.set_wire( + local_wire(self.gate.wire_first_chunk_val(i)), + first_input_chunks[i], + ); + out_buffer.set_wire( + local_wire(self.gate.wire_second_chunk_val(i)), + second_input_chunks[i], + ); + out_buffer.set_wire( + local_wire(self.gate.wire_equality_dummy(i)), + equality_dummies[i], + ); + out_buffer.set_wire(local_wire(self.gate.wire_chunks_equal(i)), chunks_equal[i]); + out_buffer.set_wire( + local_wire(self.gate.wire_intermediate_value(i)), + intermediate_values[i], + ); + } + } +} + +#[cfg(test)] +mod tests { + use std::marker::PhantomData; + + use anyhow::Result; + use rand::Rng; + + use crate::field::extension_field::quartic::QuarticExtension; + use crate::field::field_types::{Field, PrimeField}; + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::assert_le::AssertLessThanGate; + use crate::gates::gate::Gate; + use 
crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::hash::hash_types::HashOut; + use crate::plonk::vars::EvaluationVars; + + #[test] + fn wire_indices() { + type AG = AssertLessThanGate; + let num_bits = 40; + let num_chunks = 5; + + let gate = AG { + num_bits, + num_chunks, + _phantom: PhantomData, + }; + + assert_eq!(gate.wire_first_input(), 0); + assert_eq!(gate.wire_second_input(), 1); + assert_eq!(gate.wire_most_significant_diff(), 2); + assert_eq!(gate.wire_first_chunk_val(0), 3); + assert_eq!(gate.wire_first_chunk_val(4), 7); + assert_eq!(gate.wire_second_chunk_val(0), 8); + assert_eq!(gate.wire_second_chunk_val(4), 12); + assert_eq!(gate.wire_equality_dummy(0), 13); + assert_eq!(gate.wire_equality_dummy(4), 17); + assert_eq!(gate.wire_chunks_equal(0), 18); + assert_eq!(gate.wire_chunks_equal(4), 22); + assert_eq!(gate.wire_intermediate_value(0), 23); + assert_eq!(gate.wire_intermediate_value(4), 27); + } + + #[test] + fn low_degree() { + let num_bits = 20; + let num_chunks = 4; + + test_low_degree::(AssertLessThanGate::<_, 4>::new( + num_bits, num_chunks, + )) + } + + #[test] + fn eval_fns() -> Result<()> { + let num_bits = 20; + let num_chunks = 4; + + test_eval_fns::(AssertLessThanGate::<_, 4>::new( + num_bits, num_chunks, + )) + } + + #[test] + fn test_gate_constraint() { + type F = GoldilocksField; + type FF = QuarticExtension; + const D: usize = 4; + + let num_bits = 40; + let num_chunks = 5; + let chunk_bits = num_bits / num_chunks; + + // Returns the local wires for an AssertLessThanGate given the two inputs. 
+ let get_wires = |first_input: F, second_input: F| -> Vec { + let mut v = Vec::new(); + + let first_input_u64 = first_input.to_canonical_u64(); + let second_input_u64 = second_input.to_canonical_u64(); + + let chunk_size = 1 << chunk_bits; + let mut first_input_chunks: Vec = (0..num_chunks) + .scan(first_input_u64, |acc, _| { + let tmp = *acc % chunk_size; + *acc /= chunk_size; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + let mut second_input_chunks: Vec = (0..num_chunks) + .scan(second_input_u64, |acc, _| { + let tmp = *acc % chunk_size; + *acc /= chunk_size; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + + let mut chunks_equal: Vec = (0..num_chunks) + .map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i])) + .collect(); + let mut equality_dummies: Vec = first_input_chunks + .iter() + .zip(second_input_chunks.iter()) + .map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) }) + .collect(); + + let mut most_significant_diff_so_far = F::ZERO; + let mut intermediate_values = Vec::new(); + for i in 0..num_chunks { + if first_input_chunks[i] != second_input_chunks[i] { + most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i]; + intermediate_values.push(F::ZERO); + } else { + intermediate_values.push(most_significant_diff_so_far); + } + } + let most_significant_diff = most_significant_diff_so_far; + + v.push(first_input); + v.push(second_input); + v.push(most_significant_diff); + v.append(&mut first_input_chunks); + v.append(&mut second_input_chunks); + v.append(&mut equality_dummies); + v.append(&mut chunks_equal); + v.append(&mut intermediate_values); + + v.iter().map(|&x| x.into()).collect::>() + }; + + let mut rng = rand::thread_rng(); + let max: u64 = 1 << (num_bits - 1); + let first_input_u64 = rng.gen_range(0..max); + let second_input_u64 = { + let mut val = rng.gen_range(0..max); + while val < first_input_u64 { + val = rng.gen_range(0..max); + } + val + }; + + let first_input = 
F::from_canonical_u64(first_input_u64); + let second_input = F::from_canonical_u64(second_input_u64); + + let less_than_gate = AssertLessThanGate:: { + num_bits, + num_chunks, + _phantom: PhantomData, + }; + let less_than_vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(first_input, second_input), + public_inputs_hash: &HashOut::rand(), + }; + assert!( + less_than_gate + .eval_unfiltered(less_than_vars) + .iter() + .all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + + let equal_gate = AssertLessThanGate:: { + num_bits, + num_chunks, + _phantom: PhantomData, + }; + let equal_vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(first_input, first_input), + public_inputs_hash: &HashOut::rand(), + }; + assert!( + equal_gate + .eval_unfiltered(equal_vars) + .iter() + .all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } +} diff --git a/src/gates/base_sum.rs b/src/gates/base_sum.rs index 1a3dac47..cc4886ea 100644 --- a/src/gates/base_sum.rs +++ b/src/gates/base_sum.rs @@ -24,8 +24,7 @@ impl BaseSumGate { } pub fn new_from_config(config: &CircuitConfig) -> Self { - let num_limbs = ((F::ORDER as f64).log(B as f64).floor() as usize) - .min(config.num_routed_wires - Self::START_LIMBS); + let num_limbs = F::BITS.min(config.num_routed_wires - Self::START_LIMBS); Self::new(num_limbs) } diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index bb2d813c..5d1fcf4f 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -43,33 +43,42 @@ impl, const D: usize> ComparisonGate { 1 } - pub fn wire_most_significant_diff(&self) -> usize { + pub fn wire_result_bool(&self) -> usize { 2 } + pub fn wire_most_significant_diff(&self) -> usize { + 3 + } + pub fn wire_first_chunk_val(&self, chunk: usize) -> usize { debug_assert!(chunk < self.num_chunks); - 3 + chunk + 4 + chunk } pub fn wire_second_chunk_val(&self, chunk: usize) -> usize { debug_assert!(chunk < self.num_chunks); - 3 + 
self.num_chunks + chunk + 4 + self.num_chunks + chunk } pub fn wire_equality_dummy(&self, chunk: usize) -> usize { debug_assert!(chunk < self.num_chunks); - 3 + 2 * self.num_chunks + chunk + 4 + 2 * self.num_chunks + chunk } pub fn wire_chunks_equal(&self, chunk: usize) -> usize { debug_assert!(chunk < self.num_chunks); - 3 + 3 * self.num_chunks + chunk + 4 + 3 * self.num_chunks + chunk } pub fn wire_intermediate_value(&self, chunk: usize) -> usize { debug_assert!(chunk < self.num_chunks); - 3 + 4 * self.num_chunks + chunk + 4 + 4 * self.num_chunks + chunk + } + + /// The `bit_index`th bit of 2^n + most_significant_diff. + pub fn wire_most_significant_diff_bit(&self, bit_index: usize) -> usize { + 4 + 5 * self.num_chunks + bit_index } } @@ -110,10 +119,10 @@ impl, const D: usize> Gate for ComparisonGate { for i in 0..self.num_chunks { // Range-check the chunks to be less than `chunk_size`. - let first_product = (0..chunk_size) + let first_product: F::Extension = (0..chunk_size) .map(|x| first_chunks[i] - F::Extension::from_canonical_usize(x)) .product(); - let second_product = (0..chunk_size) + let second_product: F::Extension = (0..chunk_size) .map(|x| second_chunks[i] - F::Extension::from_canonical_usize(x)) .product(); constraints.push(first_product); @@ -137,11 +146,22 @@ impl, const D: usize> Gate for ComparisonGate { let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; constraints.push(most_significant_diff - most_significant_diff_so_far); - // Range check `most_significant_diff` to be less than `chunk_size`. - let product = (0..chunk_size) - .map(|x| most_significant_diff - F::Extension::from_canonical_usize(x)) - .product(); - constraints.push(product); + let most_significant_diff_bits: Vec = (0..self.chunk_bits() + 1) + .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) + .collect(); + + // Range-check the bits. 
+ for &bit in &most_significant_diff_bits { + constraints.push(bit * (F::Extension::ONE - bit)); + } + + let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::Extension::TWO); + let two_n = F::Extension::from_canonical_u64(1 << self.chunk_bits()); + constraints.push((two_n + most_significant_diff) - bits_combined); + + // Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1. + let result_bool = vars.local_wires[self.wire_result_bool()]; + constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]); constraints } @@ -178,10 +198,10 @@ impl, const D: usize> Gate for ComparisonGate { for i in 0..self.num_chunks { // Range-check the chunks to be less than `chunk_size`. - let first_product = (0..chunk_size) + let first_product: F = (0..chunk_size) .map(|x| first_chunks[i] - F::from_canonical_usize(x)) .product(); - let second_product = (0..chunk_size) + let second_product: F = (0..chunk_size) .map(|x| second_chunks[i] - F::from_canonical_usize(x)) .product(); constraints.push(first_product); @@ -205,11 +225,22 @@ impl, const D: usize> Gate for ComparisonGate { let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; constraints.push(most_significant_diff - most_significant_diff_so_far); - // Range check `most_significant_diff` to be less than `chunk_size`. - let product = (0..chunk_size) - .map(|x| most_significant_diff - F::from_canonical_usize(x)) - .product(); - constraints.push(product); + let most_significant_diff_bits: Vec = (0..self.chunk_bits() + 1) + .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) + .collect(); + + // Range-check the bits. 
+ for &bit in &most_significant_diff_bits { + constraints.push(bit * (F::ONE - bit)); + } + + let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO); + let two_n = F::from_canonical_u64(1 << self.chunk_bits()); + constraints.push((two_n + most_significant_diff) - bits_combined); + + // Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1. + let result_bool = vars.local_wires[self.wire_result_bool()]; + constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]); constraints } @@ -285,14 +316,29 @@ impl, const D: usize> Gate for ComparisonGate { constraints .push(builder.sub_extension(most_significant_diff, most_significant_diff_so_far)); - // Range check `most_significant_diff` to be less than `chunk_size`. - let mut product = builder.one_extension(); - for x in 0..chunk_size { - let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x)); - let diff = builder.sub_extension(most_significant_diff, x_f); - product = builder.mul_extension(product, diff); + let most_significant_diff_bits: Vec> = (0..self.chunk_bits() + 1) + .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) + .collect(); + + // Range-check the bits. + for &this_bit in &most_significant_diff_bits { + let inverse = builder.sub_extension(one, this_bit); + constraints.push(builder.mul_extension(this_bit, inverse)); } - constraints.push(product); + + let two = builder.two(); + let bits_combined = + reduce_with_powers_ext_recursive(builder, &most_significant_diff_bits, two); + let two_n = + builder.constant_extension(F::Extension::from_canonical_u64(1 << self.chunk_bits())); + let sum = builder.add_extension(two_n, most_significant_diff); + constraints.push(builder.sub_extension(sum, bits_combined)); + + // Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1. 
+ let result_bool = vars.local_wires[self.wire_result_bool()]; + constraints.push( + builder.sub_extension(result_bool, most_significant_diff_bits[self.chunk_bits()]), + ); constraints } @@ -310,7 +356,7 @@ impl, const D: usize> Gate for ComparisonGate { } fn num_wires(&self) -> usize { - self.wire_intermediate_value(self.num_chunks - 1) + 1 + 4 + 5 * self.num_chunks + (self.chunk_bits() + 1) } fn num_constants(&self) -> usize { @@ -322,7 +368,7 @@ impl, const D: usize> Gate for ComparisonGate { } fn num_constraints(&self) -> usize { - 4 + 5 * self.num_chunks + 6 + 5 * self.num_chunks + self.chunk_bits() } } @@ -336,10 +382,10 @@ impl, const D: usize> SimpleGenerator for ComparisonGenerato fn dependencies(&self) -> Vec { let local_target = |input| Target::wire(self.gate_index, input); - let mut deps = Vec::new(); - deps.push(local_target(self.gate.wire_first_input())); - deps.push(local_target(self.gate.wire_second_input())); - deps + vec![ + local_target(self.gate.wire_first_input()), + local_target(self.gate.wire_second_input()), + ] } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { @@ -356,7 +402,7 @@ impl, const D: usize> SimpleGenerator for ComparisonGenerato let first_input_u64 = first_input.to_canonical_u64(); let second_input_u64 = second_input.to_canonical_u64(); - debug_assert!(first_input_u64 < second_input_u64); + let result = F::from_canonical_usize((first_input_u64 <= second_input_u64) as usize); let chunk_size = 1 << self.gate.chunk_bits(); let first_input_chunks: Vec = (0..self.gate.num_chunks) @@ -395,6 +441,22 @@ impl, const D: usize> SimpleGenerator for ComparisonGenerato } let most_significant_diff = most_significant_diff_so_far; + let two_n = F::from_canonical_usize(1 << self.gate.chunk_bits()); + let two_n_plus_msd = (two_n + most_significant_diff).to_canonical_u64(); + + let msd_bits_u64: Vec = (0..self.gate.chunk_bits() + 1) + .scan(two_n_plus_msd, |acc, _| { + let tmp = *acc % 2; + *acc /= 2; + Some(tmp) + 
}) + .collect(); + let msd_bits: Vec = msd_bits_u64 + .iter() + .map(|x| F::from_canonical_u64(*x)) + .collect(); + + out_buffer.set_wire(local_wire(self.gate.wire_result_bool()), result); out_buffer.set_wire( local_wire(self.gate.wire_most_significant_diff()), most_significant_diff, @@ -418,6 +480,12 @@ impl, const D: usize> SimpleGenerator for ComparisonGenerato intermediate_values[i], ); } + for i in 0..self.gate.chunk_bits() + 1 { + out_buffer.set_wire( + local_wire(self.gate.wire_most_significant_diff_bit(i)), + msd_bits[i], + ); + } } } @@ -451,17 +519,20 @@ mod tests { assert_eq!(gate.wire_first_input(), 0); assert_eq!(gate.wire_second_input(), 1); - assert_eq!(gate.wire_most_significant_diff(), 2); - assert_eq!(gate.wire_first_chunk_val(0), 3); - assert_eq!(gate.wire_first_chunk_val(4), 7); - assert_eq!(gate.wire_second_chunk_val(0), 8); - assert_eq!(gate.wire_second_chunk_val(4), 12); - assert_eq!(gate.wire_equality_dummy(0), 13); - assert_eq!(gate.wire_equality_dummy(4), 17); - assert_eq!(gate.wire_chunks_equal(0), 18); - assert_eq!(gate.wire_chunks_equal(4), 22); - assert_eq!(gate.wire_intermediate_value(0), 23); - assert_eq!(gate.wire_intermediate_value(4), 27); + assert_eq!(gate.wire_result_bool(), 2); + assert_eq!(gate.wire_most_significant_diff(), 3); + assert_eq!(gate.wire_first_chunk_val(0), 4); + assert_eq!(gate.wire_first_chunk_val(4), 8); + assert_eq!(gate.wire_second_chunk_val(0), 9); + assert_eq!(gate.wire_second_chunk_val(4), 13); + assert_eq!(gate.wire_equality_dummy(0), 14); + assert_eq!(gate.wire_equality_dummy(4), 18); + assert_eq!(gate.wire_chunks_equal(0), 19); + assert_eq!(gate.wire_chunks_equal(4), 23); + assert_eq!(gate.wire_intermediate_value(0), 24); + assert_eq!(gate.wire_intermediate_value(4), 28); + assert_eq!(gate.wire_most_significant_diff_bit(0), 29); + assert_eq!(gate.wire_most_significant_diff_bit(8), 37); } #[test] @@ -501,6 +572,8 @@ mod tests { let first_input_u64 = first_input.to_canonical_u64(); let second_input_u64 = 
second_input.to_canonical_u64(); + let result_bool = F::from_bool(first_input_u64 <= second_input_u64); + let chunk_size = 1 << chunk_bits; let mut first_input_chunks: Vec = (0..num_chunks) .scan(first_input_u64, |acc, _| { @@ -538,20 +611,32 @@ mod tests { } let most_significant_diff = most_significant_diff_so_far; + let two_n_plus_msd = + (1 << chunk_bits) as u64 + most_significant_diff.to_canonical_u64(); + let mut msd_bits: Vec = (0..chunk_bits + 1) + .scan(two_n_plus_msd, |acc, _| { + let tmp = *acc % 2; + *acc /= 2; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + v.push(first_input); v.push(second_input); + v.push(result_bool); v.push(most_significant_diff); v.append(&mut first_input_chunks); v.append(&mut second_input_chunks); v.append(&mut equality_dummies); v.append(&mut chunks_equal); v.append(&mut intermediate_values); + v.append(&mut msd_bits); v.iter().map(|&x| x.into()).collect::>() }; let mut rng = rand::thread_rng(); - let max: u64 = 1 << num_bits - 1; + let max: u64 = 1 << (num_bits - 1); let first_input_u64 = rng.gen_range(0..max); let second_input_u64 = { let mut val = rng.gen_range(0..max); diff --git a/src/gates/exponentiation.rs b/src/gates/exponentiation.rs index 8bb3e593..cc2f970d 100644 --- a/src/gates/exponentiation.rs +++ b/src/gates/exponentiation.rs @@ -337,9 +337,8 @@ mod tests { .map(|b| F::from_canonical_u64(*b)) .collect(); - let mut v = Vec::new(); - v.push(base); - v.extend(power_bits_f.clone()); + let mut v = vec![base]; + v.extend(power_bits_f); let mut intermediate_values = Vec::new(); let mut current_intermediate_value = F::ONE; diff --git a/src/gates/gate_testing.rs b/src/gates/gate_testing.rs index 019d0204..97a69fdf 100644 --- a/src/gates/gate_testing.rs +++ b/src/gates/gate_testing.rs @@ -10,7 +10,7 @@ use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::config::GenericConfig; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use crate::plonk::verifier::verify; -use 
crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; +use crate::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::util::{log2_ceil, transpose}; const WITNESS_SIZE: usize = 1 << 5; diff --git a/src/gates/gate_tree.rs b/src/gates/gate_tree.rs index 11130c43..2d203ed0 100644 --- a/src/gates/gate_tree.rs +++ b/src/gates/gate_tree.rs @@ -1,4 +1,4 @@ -use log::info; +use log::debug; use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; @@ -86,7 +86,7 @@ impl, const D: usize> Tree> { } } } - info!( + debug!( "Found tree with max degree {} and {} constants wires in {:.4}s.", best_degree, best_num_constants, @@ -221,12 +221,17 @@ impl, const D: usize> Tree> { #[cfg(test)] mod tests { + use log::info; + use super::*; + use crate::field::goldilocks_field::GoldilocksField; + use crate::gadgets::interpolation::InterpolationGate; + use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::gates::base_sum::BaseSumGate; use crate::gates::constant::ConstantGate; use crate::gates::gmimc::GMiMCGate; - use crate::gates::interpolation::InterpolationGate; + use crate::gates::interpolation::HighDegreeInterpolationGate; use crate::gates::noop::NoopGate; use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; @@ -243,7 +248,7 @@ mod tests { GateRef::new(ArithmeticExtensionGate { num_ops: 4 }), GateRef::new(BaseSumGate::<4>::new(4)), GateRef::new(GMiMCGate::::new()), - GateRef::new(InterpolationGate::new(2)), + GateRef::new(HighDegreeInterpolationGate::new(2)), ]; let (tree, _, _) = Tree::from_gates(gates.clone()); diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index a88704c1..8a34943d 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -318,8 +318,6 @@ impl + GMiMC, const D: usize, const WIDTH: usize> Simple #[cfg(test)] mod tests { - use std::convert::TryInto; - use anyhow::Result; use crate::field::field_types::Field; diff 
--git a/src/gates/insertion.rs b/src/gates/insertion.rs index dcb7eb5d..73f96c0b 100644 --- a/src/gates/insertion.rs +++ b/src/gates/insertion.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::marker::PhantomData; use std::ops::Range; @@ -252,8 +251,7 @@ impl, const D: usize> SimpleGenerator for Insert let local_targets = |inputs: Range| inputs.map(local_target); - let mut deps = Vec::new(); - deps.push(local_target(self.gate.wires_insertion_index())); + let mut deps = vec![local_target(self.gate.wires_insertion_index())]; deps.extend(local_targets(self.gate.wires_element_to_insert())); for i in 0..self.gate.vec_size { deps.extend(local_targets(self.gate.wires_original_list_item(i))); @@ -292,7 +290,7 @@ impl, const D: usize> SimpleGenerator for Insert vec_size ); - let mut new_vec = orig_vec.clone(); + let mut new_vec = orig_vec; new_vec.insert(insertion_index, to_insert); let mut equality_dummy_vals = Vec::new(); @@ -377,14 +375,13 @@ mod tests { fn get_wires(orig_vec: Vec, insertion_index: usize, element_to_insert: FF) -> Vec { let vec_size = orig_vec.len(); - let mut v = Vec::new(); - v.push(F::from_canonical_usize(insertion_index)); + let mut v = vec![F::from_canonical_usize(insertion_index)]; v.extend(element_to_insert.0); for j in 0..vec_size { v.extend(orig_vec[j].0); } - let mut new_vec = orig_vec.clone(); + let mut new_vec = orig_vec; new_vec.insert(insertion_index, element_to_insert); let mut equality_dummy_vals = Vec::new(); for i in 0..=vec_size { diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index 01fff1ba..aa39ae15 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::marker::PhantomData; use std::ops::Range; @@ -6,6 +5,7 @@ use crate::field::extension_field::algebra::PolynomialCoeffsAlgebra; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use 
crate::field::interpolation::interpolant; +use crate::gadgets::interpolation::InterpolationGate; use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; use crate::gates::gate::Gate; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; @@ -14,19 +14,20 @@ use crate::iop::wire::Wire; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -use crate::polynomial::polynomial::PolynomialCoeffs; +use crate::polynomial::PolynomialCoeffs; -/// Interpolates a polynomial, whose points are a (base field) coset of the multiplicative subgroup -/// with the given size, and whose values are extension field elements, given by input wires. -/// Outputs the evaluation of the interpolant at a given (extension field) evaluation point. -#[derive(Clone, Debug)] -pub(crate) struct InterpolationGate, const D: usize> { +/// Interpolation gate with constraints of degree at most `1<, const D: usize> { pub subgroup_bits: usize, _phantom: PhantomData, } -impl, const D: usize> InterpolationGate { - pub fn new(subgroup_bits: usize) -> Self { +impl, const D: usize> InterpolationGate + for HighDegreeInterpolationGate +{ + fn new(subgroup_bits: usize) -> Self { Self { subgroup_bits, _phantom: PhantomData, @@ -36,60 +37,9 @@ impl, const D: usize> InterpolationGate { fn num_points(&self) -> usize { 1 << self.subgroup_bits } +} - /// Wire index of the coset shift. - pub fn wire_shift(&self) -> usize { - 0 - } - - fn start_values(&self) -> usize { - 1 - } - - /// Wire indices of the `i`th interpolant value. - pub fn wires_value(&self, i: usize) -> Range { - debug_assert!(i < self.num_points()); - let start = self.start_values() + i * D; - start..start + D - } - - fn start_evaluation_point(&self) -> usize { - self.start_values() + self.num_points() * D - } - - /// Wire indices of the point to evaluate the interpolant at. 
- pub fn wires_evaluation_point(&self) -> Range { - let start = self.start_evaluation_point(); - start..start + D - } - - fn start_evaluation_value(&self) -> usize { - self.start_evaluation_point() + D - } - - /// Wire indices of the interpolated value. - pub fn wires_evaluation_value(&self) -> Range { - let start = self.start_evaluation_value(); - start..start + D - } - - fn start_coeffs(&self) -> usize { - self.start_evaluation_value() + D - } - - /// The number of routed wires required in the typical usage of this gate, where the points to - /// interpolate, the evaluation point, and the corresponding value are all routed. - pub(crate) fn num_routed_wires(&self) -> usize { - self.start_coeffs() - } - - /// Wire indices of the interpolant's `i`th coefficient. - pub fn wires_coeff(&self, i: usize) -> Range { - debug_assert!(i < self.num_points()); - let start = self.start_coeffs() + i * D; - start..start + D - } - +impl, const D: usize> HighDegreeInterpolationGate { /// End of wire indices, exclusive. 
fn end(&self) -> usize { self.start_coeffs() + self.num_points() * D @@ -121,14 +71,16 @@ impl, const D: usize> InterpolationGate { g.powers() .take(size) .map(move |x| { - let subgroup_element = builder.constant(x.into()); + let subgroup_element = builder.constant(x); builder.scalar_mul_ext(subgroup_element, shift) }) .collect() } } -impl, const D: usize> Gate for InterpolationGate { +impl, const D: usize> Gate + for HighDegreeInterpolationGate +{ fn id(&self) -> String { format!("{:?}", self, D) } @@ -221,7 +173,7 @@ impl, const D: usize> Gate for InterpolationGate { ) -> Vec>> { let gen = InterpolationGenerator:: { gate_index, - gate: self.clone(), + gate: *self, _phantom: PhantomData, }; vec![Box::new(gen.adapter())] @@ -251,7 +203,7 @@ impl, const D: usize> Gate for InterpolationGate { #[derive(Debug)] struct InterpolationGenerator, const D: usize> { gate_index: usize, - gate: InterpolationGate, + gate: HighDegreeInterpolationGate, _phantom: PhantomData, } @@ -321,17 +273,18 @@ mod tests { use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; + use crate::gadgets::interpolation::InterpolationGate; use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; - use crate::gates::interpolation::InterpolationGate; + use crate::gates::interpolation::HighDegreeInterpolationGate; use crate::hash::hash_types::HashOut; use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::vars::EvaluationVars; - use crate::polynomial::polynomial::PolynomialCoeffs; + use crate::polynomial::PolynomialCoeffs; #[test] fn wire_indices() { - let gate = InterpolationGate:: { + let gate = HighDegreeInterpolationGate:: { subgroup_bits: 1, _phantom: PhantomData, }; @@ -350,7 +303,7 @@ mod tests { #[test] fn low_degree() { - test_low_degree::(InterpolationGate::new(2)); + test_low_degree::(HighDegreeInterpolationGate::new(2)); } #[test] @@ -358,7 +311,7 @@ mod tests { const D: usize = 2; type 
C = PoseidonGoldilocksConfig; type F = >::F; - test_eval_fns::(InterpolationGate::new(2)) + test_eval_fns::(HighDegreeInterpolationGate::new(2)) } #[test] @@ -370,7 +323,7 @@ mod tests { /// Returns the local wires for an interpolation gate for given coeffs, points and eval point. fn get_wires( - gate: &InterpolationGate, + gate: &HighDegreeInterpolationGate, shift: F, coeffs: PolynomialCoeffs, eval_point: FF, @@ -392,7 +345,7 @@ mod tests { let shift = F::rand(); let coeffs = PolynomialCoeffs::new(vec![FF::rand(), FF::rand()]); let eval_point = FF::rand(); - let gate = InterpolationGate::::new(1); + let gate = HighDegreeInterpolationGate::::new(1); let vars = EvaluationVars { local_constants: &[], local_wires: &get_wires(&gate, shift, coeffs, eval_point), diff --git a/src/gates/low_degree_interpolation.rs b/src/gates/low_degree_interpolation.rs new file mode 100644 index 00000000..709c5e9a --- /dev/null +++ b/src/gates/low_degree_interpolation.rs @@ -0,0 +1,459 @@ +use std::marker::PhantomData; +use std::ops::Range; + +use crate::field::extension_field::algebra::PolynomialCoeffsAlgebra; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::{Extendable, FieldExtension}; +use crate::field::field_types::{Field, RichField}; +use crate::field::interpolation::interpolant; +use crate::gadgets::interpolation::InterpolationGate; +use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::wire::Wire; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::polynomial::PolynomialCoeffs; + +/// Interpolation gate with constraints of degree 2. 
+/// `eval_unfiltered_recursively` uses more gates than `HighDegreeInterpolationGate`. +#[derive(Copy, Clone, Debug)] +pub(crate) struct LowDegreeInterpolationGate, const D: usize> { + pub subgroup_bits: usize, + _phantom: PhantomData, +} + +impl, const D: usize> InterpolationGate + for LowDegreeInterpolationGate +{ + fn new(subgroup_bits: usize) -> Self { + Self { + subgroup_bits, + _phantom: PhantomData, + } + } + + fn num_points(&self) -> usize { + 1 << self.subgroup_bits + } +} + +impl, const D: usize> LowDegreeInterpolationGate { + /// `powers_shift(i)` is the wire index of `wire_shift^i`. + pub fn powers_shift(&self, i: usize) -> usize { + debug_assert!(0 < i && i < self.num_points()); + if i == 1 { + return self.wire_shift(); + } + self.end_coeffs() + i - 2 + } + + /// `powers_evaluation_point(i)` are the wire indices of `evaluation_point^i`. + pub fn powers_evaluation_point(&self, i: usize) -> Range { + debug_assert!(0 < i && i < self.num_points()); + if i == 1 { + return self.wires_evaluation_point(); + } + let start = self.end_coeffs() + self.num_points() - 2 + (i - 2) * D; + start..start + D + } + + /// End of wire indices, exclusive. + fn end(&self) -> usize { + self.powers_evaluation_point(self.num_points() - 1).end + } + + /// The domain of the points we're interpolating. + fn coset(&self, shift: F) -> impl Iterator { + let g = F::primitive_root_of_unity(self.subgroup_bits); + let size = 1 << self.subgroup_bits; + // Speed matters here, so we avoid `cyclic_subgroup_coset_known_order` which allocates. 
+ g.powers().take(size).map(move |x| x * shift) + } +} + +impl, const D: usize> Gate for LowDegreeInterpolationGate { + fn id(&self) -> String { + format!("{:?}", self, D) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let coeffs = (0..self.num_points()) + .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) + .collect::>(); + + let mut powers_shift = (1..self.num_points()) + .map(|i| vars.local_wires[self.powers_shift(i)]) + .collect::>(); + let shift = powers_shift[0]; + for i in 1..self.num_points() - 1 { + constraints.push(powers_shift[i - 1] * shift - powers_shift[i]); + } + powers_shift.insert(0, F::Extension::ONE); + // `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient. + // Then, `altered(w^i) = original(shift*w^i)`. + let altered_coeffs = coeffs + .iter() + .zip(powers_shift) + .map(|(&c, p)| c.scalar_mul(p)) + .collect::>(); + let interpolant = PolynomialCoeffsAlgebra::new(coeffs); + let altered_interpolant = PolynomialCoeffsAlgebra::new(altered_coeffs); + + for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits) + .into_iter() + .enumerate() + { + let value = vars.get_local_ext_algebra(self.wires_value(i)); + let computed_value = altered_interpolant.eval_base(point); + constraints.extend(&(value - computed_value).to_basefield_array()); + } + + let evaluation_point_powers = (1..self.num_points()) + .map(|i| vars.get_local_ext_algebra(self.powers_evaluation_point(i))) + .collect::>(); + let evaluation_point = evaluation_point_powers[0]; + for i in 1..self.num_points() - 1 { + constraints.extend( + (evaluation_point_powers[i - 1] * evaluation_point - evaluation_point_powers[i]) + .to_basefield_array(), + ); + } + let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); + let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers); + constraints.extend(&(evaluation_value - 
computed_evaluation_value).to_basefield_array()); + + constraints + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let coeffs = (0..self.num_points()) + .map(|i| vars.get_local_ext(self.wires_coeff(i))) + .collect::>(); + + let mut powers_shift = (1..self.num_points()) + .map(|i| vars.local_wires[self.powers_shift(i)]) + .collect::>(); + let shift = powers_shift[0]; + for i in 1..self.num_points() - 1 { + constraints.push(powers_shift[i - 1] * shift - powers_shift[i]); + } + powers_shift.insert(0, F::ONE); + // `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient. + // Then, `altered(w^i) = original(shift*w^i)`. + let altered_coeffs = coeffs + .iter() + .zip(powers_shift) + .map(|(&c, p)| c.scalar_mul(p)) + .collect::>(); + let interpolant = PolynomialCoeffs::new(coeffs); + let altered_interpolant = PolynomialCoeffs::new(altered_coeffs); + + for (i, point) in F::two_adic_subgroup(self.subgroup_bits) + .into_iter() + .enumerate() + { + let value = vars.get_local_ext(self.wires_value(i)); + let computed_value = altered_interpolant.eval_base(point); + constraints.extend(&(value - computed_value).to_basefield_array()); + } + + let evaluation_point_powers = (1..self.num_points()) + .map(|i| vars.get_local_ext(self.powers_evaluation_point(i))) + .collect::>(); + let evaluation_point = evaluation_point_powers[0]; + for i in 1..self.num_points() - 1 { + constraints.extend( + (evaluation_point_powers[i - 1] * evaluation_point - evaluation_point_powers[i]) + .to_basefield_array(), + ); + } + let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); + let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers); + constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array()); + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: 
EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let coeffs = (0..self.num_points()) + .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) + .collect::>(); + + let mut powers_shift = (1..self.num_points()) + .map(|i| vars.local_wires[self.powers_shift(i)]) + .collect::>(); + let shift = powers_shift[0]; + for i in 1..self.num_points() - 1 { + constraints.push(builder.mul_sub_extension( + powers_shift[i - 1], + shift, + powers_shift[i], + )); + } + powers_shift.insert(0, builder.one_extension()); + // `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient. + // Then, `altered(w^i) = original(shift*w^i)`. + let altered_coeffs = coeffs + .iter() + .zip(powers_shift) + .map(|(&c, p)| builder.scalar_mul_ext_algebra(p, c)) + .collect::>(); + let interpolant = PolynomialCoeffsExtAlgebraTarget(coeffs); + let altered_interpolant = PolynomialCoeffsExtAlgebraTarget(altered_coeffs); + + for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits) + .into_iter() + .enumerate() + { + let value = vars.get_local_ext_algebra(self.wires_value(i)); + let point = builder.constant_extension(point); + let computed_value = altered_interpolant.eval_scalar(builder, point); + constraints.extend( + &builder + .sub_ext_algebra(value, computed_value) + .to_ext_target_array(), + ); + } + + let evaluation_point_powers = (1..self.num_points()) + .map(|i| vars.get_local_ext_algebra(self.powers_evaluation_point(i))) + .collect::>(); + let evaluation_point = evaluation_point_powers[0]; + for i in 1..self.num_points() - 1 { + let neg_one_ext = builder.neg_one_extension(); + let neg_new_power = + builder.scalar_mul_ext_algebra(neg_one_ext, evaluation_point_powers[i]); + let constraint = builder.mul_add_ext_algebra( + evaluation_point, + evaluation_point_powers[i - 1], + neg_new_power, + ); + constraints.extend(constraint.to_ext_target_array()); + } + let evaluation_value = 
vars.get_local_ext_algebra(self.wires_evaluation_value()); + let computed_evaluation_value = + interpolant.eval_with_powers(builder, &evaluation_point_powers); + // let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); + // let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); + // let computed_evaluation_value = interpolant.eval(builder, evaluation_point); + constraints.extend( + &builder + .sub_ext_algebra(evaluation_value, computed_evaluation_value) + .to_ext_target_array(), + ); + + constraints + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + let gen = InterpolationGenerator:: { + gate_index, + gate: *self, + _phantom: PhantomData, + }; + vec![Box::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + self.end() + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 2 + } + + fn num_constraints(&self) -> usize { + // `num_points * D` constraints to check for consistency between the coefficients and the + // point-value pairs, plus `D` constraints for the evaluation value, plus `(D+1)*(num_points-2)` + // to check power constraints for evaluation point and shift. 
+ self.num_points() * D + D + (D + 1) * (self.num_points() - 2) + } +} + +#[derive(Debug)] +struct InterpolationGenerator, const D: usize> { + gate_index: usize, + gate: LowDegreeInterpolationGate, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for InterpolationGenerator +{ + fn dependencies(&self) -> Vec { + let local_target = |input| { + Target::Wire(Wire { + gate: self.gate_index, + input, + }) + }; + + let local_targets = |inputs: Range| inputs.map(local_target); + + let num_points = self.gate.num_points(); + let mut deps = Vec::with_capacity(1 + D + num_points * D); + + deps.push(local_target(self.gate.wire_shift())); + deps.extend(local_targets(self.gate.wires_evaluation_point())); + for i in 0..num_points { + deps.extend(local_targets(self.gate.wires_value(i))); + } + deps + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |input| Wire { + gate: self.gate_index, + input, + }; + + let get_local_wire = |input| witness.get_wire(local_wire(input)); + + let get_local_ext = |wire_range: Range| { + debug_assert_eq!(wire_range.len(), D); + let values = wire_range.map(get_local_wire).collect::>(); + let arr = values.try_into().unwrap(); + F::Extension::from_basefield_array(arr) + }; + + let wire_shift = get_local_wire(self.gate.wire_shift()); + + for (i, power) in wire_shift + .powers() + .take(self.gate.num_points()) + .enumerate() + .skip(2) + { + out_buffer.set_wire(local_wire(self.gate.powers_shift(i)), power); + } + + // Compute the interpolant. 
+ let points = self.gate.coset(wire_shift); + let points = points + .into_iter() + .enumerate() + .map(|(i, point)| (point.into(), get_local_ext(self.gate.wires_value(i)))) + .collect::>(); + let interpolant = interpolant(&points); + + for (i, &coeff) in interpolant.coeffs.iter().enumerate() { + let wires = self.gate.wires_coeff(i).map(local_wire); + out_buffer.set_ext_wires(wires, coeff); + } + + let evaluation_point = get_local_ext(self.gate.wires_evaluation_point()); + for (i, power) in evaluation_point + .powers() + .take(self.gate.num_points()) + .enumerate() + .skip(2) + { + out_buffer.set_extension_target( + ExtensionTarget::from_range(self.gate_index, self.gate.powers_evaluation_point(i)), + power, + ); + } + let evaluation_value = interpolant.eval(evaluation_point); + let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire); + out_buffer.set_ext_wires(evaluation_value_wires, evaluation_value); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::field::extension_field::quadratic::QuadraticExtension; + use crate::field::field_types::Field; + use crate::field::goldilocks_field::GoldilocksField; + use crate::gadgets::interpolation::InterpolationGate; + use crate::gates::gate::Gate; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; + use crate::hash::hash_types::HashOut; + use crate::plonk::vars::EvaluationVars; + use crate::polynomial::PolynomialCoeffs; + + #[test] + fn low_degree() { + test_low_degree::(LowDegreeInterpolationGate::new(4)); + } + + #[test] + fn eval_fns() -> Result<()> { + test_eval_fns::(LowDegreeInterpolationGate::new(4)) + } + + #[test] + fn test_gate_constraint() { + type F = GoldilocksField; + type FF = QuadraticExtension; + const D: usize = 2; + + /// Returns the local wires for an interpolation gate for given coeffs, points and eval point. 
+ fn get_wires( + gate: &LowDegreeInterpolationGate, + shift: F, + coeffs: PolynomialCoeffs, + eval_point: FF, + ) -> Vec { + let points = gate.coset(shift); + let mut v = vec![shift]; + for x in points { + v.extend(coeffs.eval(x.into()).0); + } + v.extend(eval_point.0); + v.extend(coeffs.eval(eval_point).0); + for i in 0..coeffs.len() { + v.extend(coeffs.coeffs[i].0); + } + v.extend(shift.powers().skip(2).take(gate.num_points() - 2)); + v.extend( + eval_point + .powers() + .skip(2) + .take(gate.num_points() - 2) + .flat_map(|ff| ff.0), + ); + v.iter().map(|&x| x.into()).collect::>() + } + + // Get a working row for LowDegreeInterpolationGate. + let subgroup_bits = 4; + let shift = F::rand(); + let coeffs = PolynomialCoeffs::new(FF::rand_vec(1 << subgroup_bits)); + let eval_point = FF::rand(); + let gate = LowDegreeInterpolationGate::::new(subgroup_bits); + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(&gate, shift, coeffs, eval_point), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } +} diff --git a/src/gates/mod.rs b/src/gates/mod.rs index 76066285..54289733 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -1,8 +1,10 @@ // Gates have `new` methods that return `GateRef`s. 
#![allow(clippy::new_ret_no_self)] -pub mod arithmetic; +pub mod arithmetic_base; +pub mod arithmetic_extension; pub mod arithmetic_u32; +pub mod assert_le; pub mod base_sum; pub mod comparison; pub mod constant; @@ -12,12 +14,16 @@ pub mod gate_tree; pub mod gmimc; pub mod insertion; pub mod interpolation; +pub mod low_degree_interpolation; +pub mod multiplication_extension; pub mod noop; pub mod poseidon; pub(crate) mod poseidon_mds; pub(crate) mod public_input; pub mod random_access; pub mod reducing; +pub mod reducing_extension; +pub mod subtraction_u32; pub mod switch; #[cfg(test)] diff --git a/src/gates/multiplication_extension.rs b/src/gates/multiplication_extension.rs new file mode 100644 index 00000000..4c385b79 --- /dev/null +++ b/src/gates/multiplication_extension.rs @@ -0,0 +1,204 @@ +use std::ops::Range; + +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::extension_field::FieldExtension; +use crate::field::field_types::RichField; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::CircuitConfig; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// A gate which can perform a weighted multiplication, i.e. `result = c0 x y`. If the config +/// supports enough routed wires, it can support several such operations in one gate. +#[derive(Debug)] +pub struct MulExtensionGate { + /// Number of multiplications performed by the gate. + pub num_ops: usize, +} + +impl MulExtensionGate { + pub fn new_from_config(config: &CircuitConfig) -> Self { + Self { + num_ops: Self::num_ops(config), + } + } + + /// Determine the maximum number of operations that can fit in one gate for the given config. 
+ pub(crate) fn num_ops(config: &CircuitConfig) -> usize { + let wires_per_op = 3 * D; + config.num_routed_wires / wires_per_op + } + + pub fn wires_ith_multiplicand_0(i: usize) -> Range { + 3 * D * i..3 * D * i + D + } + pub fn wires_ith_multiplicand_1(i: usize) -> Range { + 3 * D * i + D..3 * D * i + 2 * D + } + pub fn wires_ith_output(i: usize) -> Range { + 3 * D * i + 2 * D..3 * D * i + 3 * D + } +} + +impl, const D: usize> Gate for MulExtensionGate { + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let const_0 = vars.local_constants[0]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i)); + let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i)); + let output = vars.get_local_ext_algebra(Self::wires_ith_output(i)); + let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(const_0); + + constraints.extend((output - computed_output).to_basefield_array()); + } + + constraints + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let const_0 = vars.local_constants[0]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.get_local_ext(Self::wires_ith_multiplicand_0(i)); + let multiplicand_1 = vars.get_local_ext(Self::wires_ith_multiplicand_1(i)); + let output = vars.get_local_ext(Self::wires_ith_output(i)); + let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(const_0); + + constraints.extend((output - computed_output).to_basefield_array()); + } + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let const_0 = vars.local_constants[0]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i)); + let 
multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i)); + let output = vars.get_local_ext_algebra(Self::wires_ith_output(i)); + let computed_output = { + let mul = builder.mul_ext_algebra(multiplicand_0, multiplicand_1); + builder.scalar_mul_ext_algebra(const_0, mul) + }; + + let diff = builder.sub_ext_algebra(output, computed_output); + constraints.extend(diff.to_ext_target_array()); + } + + constraints + } + + fn generators( + &self, + gate_index: usize, + local_constants: &[F], + ) -> Vec>> { + (0..self.num_ops) + .map(|i| { + let g: Box> = Box::new( + MulExtensionGenerator { + gate_index, + const_0: local_constants[0], + i, + } + .adapter(), + ); + g + }) + .collect::>() + } + + fn num_wires(&self) -> usize { + self.num_ops * 3 * D + } + + fn num_constants(&self) -> usize { + 1 + } + + fn degree(&self) -> usize { + 3 + } + + fn num_constraints(&self) -> usize { + self.num_ops * D + } +} + +#[derive(Clone, Debug)] +struct MulExtensionGenerator, const D: usize> { + gate_index: usize, + const_0: F, + i: usize, +} + +impl, const D: usize> SimpleGenerator + for MulExtensionGenerator +{ + fn dependencies(&self) -> Vec { + MulExtensionGate::::wires_ith_multiplicand_0(self.i) + .chain(MulExtensionGate::::wires_ith_multiplicand_1(self.i)) + .map(|i| Target::wire(self.gate_index, i)) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let extract_extension = |range: Range| -> F::Extension { + let t = ExtensionTarget::from_range(self.gate_index, range); + witness.get_extension_target(t) + }; + + let multiplicand_0 = + extract_extension(MulExtensionGate::::wires_ith_multiplicand_0(self.i)); + let multiplicand_1 = + extract_extension(MulExtensionGate::::wires_ith_multiplicand_1(self.i)); + + let output_target = ExtensionTarget::from_range( + self.gate_index, + MulExtensionGate::::wires_ith_output(self.i), + ); + + let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(self.const_0); + 
+ out_buffer.set_extension_target(output_target, computed_output) + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::multiplication_extension::MulExtensionGate; + use crate::plonk::circuit_data::CircuitConfig; + + #[test] + fn low_degree() { + let gate = MulExtensionGate::new_from_config(&CircuitConfig::standard_recursion_config()); + test_low_degree::(gate); + } + + #[test] + fn eval_fns() -> Result<()> { + let gate = MulExtensionGate::new_from_config(&CircuitConfig::standard_recursion_config()); + test_eval_fns::(gate) + } +} diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 31b29bda..6d9d97d5 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::marker::PhantomData; use crate::field::extension_field::target::ExtensionTarget; @@ -47,44 +46,59 @@ impl, const D: usize> PoseidonGate { /// is useful for ordering hashes in Merkle proofs. Otherwise, this should be set to 0. pub const WIRE_SWAP: usize = 2 * SPONGE_WIDTH; + const START_DELTA: usize = 2 * WIDTH + 1; + + /// A wire which stores `swap * (input[i + 4] - input[i])`; used to compute the swapped inputs. + fn wire_delta(i: usize) -> usize { + assert!(i < 4); + Self::START_DELTA + i + } + + const START_FULL_0: usize = Self::START_DELTA + 4; + /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the first set /// of full rounds. 
fn wire_full_sbox_0(round: usize, i: usize) -> usize { + debug_assert!( + round != 0, + "First round S-box inputs are not stored as wires" + ); debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); - debug_assert!(i < SPONGE_WIDTH); - 2 * SPONGE_WIDTH + 1 + SPONGE_WIDTH * round + i + debug_assert!(i < WIDTH); + Self::START_FULL_0 + WIDTH * (round - 1) + i } + const START_PARTIAL: usize = Self::START_FULL_0 + WIDTH * (poseidon::HALF_N_FULL_ROUNDS - 1); + /// A wire which stores the input of the S-box of the `round`-th round of the partial rounds. fn wire_partial_sbox(round: usize) -> usize { debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); - 2 * SPONGE_WIDTH + 1 + SPONGE_WIDTH * poseidon::HALF_N_FULL_ROUNDS + round + Self::START_PARTIAL + round } + const START_FULL_1: usize = Self::START_PARTIAL + poseidon::N_PARTIAL_ROUNDS; + /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the second set /// of full rounds. fn wire_full_sbox_1(round: usize, i: usize) -> usize { debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); - debug_assert!(i < SPONGE_WIDTH); - 2 * SPONGE_WIDTH - + 1 - + SPONGE_WIDTH * (poseidon::HALF_N_FULL_ROUNDS + round) - + poseidon::N_PARTIAL_ROUNDS - + i + debug_assert!(i < WIDTH); + Self::START_FULL_1 + WIDTH * round + i } /// End of wire indices, exclusive. 
fn end() -> usize { - 2 * SPONGE_WIDTH - + 1 - + SPONGE_WIDTH * poseidon::N_FULL_ROUNDS_TOTAL - + poseidon::N_PARTIAL_ROUNDS + Self::START_FULL_1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS } } -impl, const D: usize> Gate for PoseidonGate { +impl + Poseidon, const D: usize, const WIDTH: usize> Gate + for PoseidonGate +where + [(); WIDTH - 1]:, +{ fn id(&self) -> String { - format!("{:?}", self, SPONGE_WIDTH) + format!("{:?}", self, WIDTH) } fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { @@ -94,69 +108,79 @@ impl, const D: usize> Gate for PoseidonGate { let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(swap * (swap - F::Extension::ONE)); - let mut state = Vec::with_capacity(SPONGE_WIDTH); + // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. for i in 0..4 { - let a = vars.local_wires[i]; - let b = vars.local_wires[i + 4]; - state.push(a + swap * (b - a)); - } - for i in 0..4 { - let a = vars.local_wires[i + 4]; - let b = vars.local_wires[i]; - state.push(a + swap * (b - a)); - } - for i in 8..SPONGE_WIDTH { - state.push(vars.local_wires[i]); + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + let delta_i = vars.local_wires[Self::wire_delta(i)]; + constraints.push(swap * (input_rhs - input_lhs) - delta_i); + } + + // Compute the possibly-swapped input layer. + let mut state = [F::Extension::ZERO; WIDTH]; + for i in 0..4 { + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let input_lhs = Self::wire_input(i); + let input_rhs = Self::wire_input(i + 4); + state[i] = vars.local_wires[input_lhs] + delta_i; + state[i + 4] = vars.local_wires[input_rhs] - delta_i; + } + for i in 8..WIDTH { + state[i] = vars.local_wires[Self::wire_input(i)]; } - let mut state: [F::Extension; SPONGE_WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; // First set of full rounds. 
for r in 0..poseidon::HALF_N_FULL_ROUNDS { - ::constant_layer_field(&mut state, round_ctr); - for i in 0..SPONGE_WIDTH { - let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; - constraints.push(state[i] - sbox_in); - state[i] = sbox_in; + >::constant_layer_field(&mut state, round_ctr); + if r != 0 { + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } } - ::sbox_layer_field(&mut state); - state = ::mds_layer_field(&state); + >::sbox_layer_field(&mut state); + state = >::mds_layer_field(&state); round_ctr += 1; } // Partial rounds. - ::partial_first_constant_layer(&mut state); - state = ::mds_partial_layer_init(&mut state); + >::partial_first_constant_layer(&mut state); + state = >::mds_partial_layer_init(&state); for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; constraints.push(state[0] - sbox_in); - state[0] = ::sbox_monomial(sbox_in); - state[0] += - F::Extension::from_canonical_u64(::FAST_PARTIAL_ROUND_CONSTANTS[r]); - state = ::mds_partial_layer_fast_field(&state, r); + state[0] = >::sbox_monomial(sbox_in); + state[0] += F::Extension::from_canonical_u64( + >::FAST_PARTIAL_ROUND_CONSTANTS[r], + ); + state = >::mds_partial_layer_fast_field(&state, r); } let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)]; constraints.push(state[0] - sbox_in); - state[0] = ::sbox_monomial(sbox_in); - state = - ::mds_partial_layer_fast_field(&state, poseidon::N_PARTIAL_ROUNDS - 1); + state[0] = >::sbox_monomial(sbox_in); + state = >::mds_partial_layer_fast_field( + &state, + poseidon::N_PARTIAL_ROUNDS - 1, + ); round_ctr += poseidon::N_PARTIAL_ROUNDS; // Second set of full rounds. 
for r in 0..poseidon::HALF_N_FULL_ROUNDS { - ::constant_layer_field(&mut state, round_ctr); - for i in 0..SPONGE_WIDTH { + >::constant_layer_field(&mut state, round_ctr); + for i in 0..WIDTH { let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; constraints.push(state[i] - sbox_in); state[i] = sbox_in; } - ::sbox_layer_field(&mut state); - state = ::mds_layer_field(&state); + >::sbox_layer_field(&mut state); + state = >::mds_layer_field(&state); round_ctr += 1; } - for i in 0..SPONGE_WIDTH { + for i in 0..WIDTH { constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]); } @@ -170,67 +194,76 @@ impl, const D: usize> Gate for PoseidonGate { let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(swap * swap.sub_one()); - let mut state = Vec::with_capacity(SPONGE_WIDTH); + // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. for i in 0..4 { - let a = vars.local_wires[i]; - let b = vars.local_wires[i + 4]; - state.push(a + swap * (b - a)); - } - for i in 0..4 { - let a = vars.local_wires[i + 4]; - let b = vars.local_wires[i]; - state.push(a + swap * (b - a)); - } - for i in 8..SPONGE_WIDTH { - state.push(vars.local_wires[i]); + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + let delta_i = vars.local_wires[Self::wire_delta(i)]; + constraints.push(swap * (input_rhs - input_lhs) - delta_i); + } + + // Compute the possibly-swapped input layer. + let mut state = [F::ZERO; WIDTH]; + for i in 0..4 { + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let input_lhs = Self::wire_input(i); + let input_rhs = Self::wire_input(i + 4); + state[i] = vars.local_wires[input_lhs] + delta_i; + state[i + 4] = vars.local_wires[input_rhs] - delta_i; + } + for i in 8..WIDTH { + state[i] = vars.local_wires[Self::wire_input(i)]; } - let mut state: [F; SPONGE_WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; // First set of full rounds. 
for r in 0..poseidon::HALF_N_FULL_ROUNDS { - ::constant_layer(&mut state, round_ctr); - for i in 0..SPONGE_WIDTH { - let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; - constraints.push(state[i] - sbox_in); - state[i] = sbox_in; + >::constant_layer(&mut state, round_ctr); + if r != 0 { + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } } - ::sbox_layer(&mut state); - state = ::mds_layer(&state); + >::sbox_layer(&mut state); + state = >::mds_layer(&state); round_ctr += 1; } // Partial rounds. - ::partial_first_constant_layer(&mut state); - state = ::mds_partial_layer_init(&mut state); + >::partial_first_constant_layer(&mut state); + state = >::mds_partial_layer_init(&state); for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; constraints.push(state[0] - sbox_in); - state[0] = ::sbox_monomial(sbox_in); - state[0] += F::from_canonical_u64(::FAST_PARTIAL_ROUND_CONSTANTS[r]); - state = ::mds_partial_layer_fast(&state, r); + state[0] = >::sbox_monomial(sbox_in); + state[0] += + F::from_canonical_u64(>::FAST_PARTIAL_ROUND_CONSTANTS[r]); + state = >::mds_partial_layer_fast(&state, r); } let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)]; constraints.push(state[0] - sbox_in); - state[0] = ::sbox_monomial(sbox_in); - state = ::mds_partial_layer_fast(&state, poseidon::N_PARTIAL_ROUNDS - 1); + state[0] = >::sbox_monomial(sbox_in); + state = + >::mds_partial_layer_fast(&state, poseidon::N_PARTIAL_ROUNDS - 1); round_ctr += poseidon::N_PARTIAL_ROUNDS; // Second set of full rounds. 
for r in 0..poseidon::HALF_N_FULL_ROUNDS { - ::constant_layer(&mut state, round_ctr); - for i in 0..SPONGE_WIDTH { + >::constant_layer(&mut state, round_ctr); + for i in 0..WIDTH { let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; constraints.push(state[i] - sbox_in); state[i] = sbox_in; } - ::sbox_layer(&mut state); - state = ::mds_layer(&state); + >::sbox_layer(&mut state); + state = >::mds_layer(&state); round_ctr += 1; } - for i in 0..SPONGE_WIDTH { + for i in 0..WIDTH { constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]); } @@ -244,7 +277,7 @@ impl, const D: usize> Gate for PoseidonGate { ) -> Vec> { // The naive method is more efficient if we have enough routed wires for PoseidonMdsGate. let use_mds_gate = - builder.config.num_routed_wires >= PoseidonMdsGate::::new().num_wires(); + builder.config.num_routed_wires >= PoseidonMdsGate::::new().num_wires(); let mut constraints = Vec::with_capacity(self.num_constraints()); @@ -252,71 +285,73 @@ impl, const D: usize> Gate for PoseidonGate { let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(builder.mul_sub_extension(swap, swap, swap)); - let mut state = Vec::with_capacity(SPONGE_WIDTH); - // We need to compute both `if swap {b} else {a}` and `if swap {a} else {b}`. - // We will arithmetize them as - // swap (b - a) + a - // -swap (b - a) + b - // so that `b - a` can be used for both. - let mut state_first_4 = vec![]; - let mut state_next_4 = vec![]; + // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. 
for i in 0..4 { - let a = vars.local_wires[i]; - let b = vars.local_wires[i + 4]; - let delta = builder.sub_extension(b, a); - state_first_4.push(builder.mul_add_extension(swap, delta, a)); - state_next_4.push(builder.arithmetic_extension(F::NEG_ONE, F::ONE, swap, delta, b)); + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let diff = builder.sub_extension(input_rhs, input_lhs); + constraints.push(builder.mul_sub_extension(swap, diff, delta_i)); } - state.extend(state_first_4); - state.extend(state_next_4); - for i in 8..SPONGE_WIDTH { - state.push(vars.local_wires[i]); + // Compute the possibly-swapped input layer. + let mut state = [builder.zero_extension(); WIDTH]; + for i in 0..4 { + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + state[i] = builder.add_extension(input_lhs, delta_i); + state[i + 4] = builder.sub_extension(input_rhs, delta_i); + } + for i in 8..WIDTH { + state[i] = vars.local_wires[Self::wire_input(i)]; } - let mut state: [ExtensionTarget; SPONGE_WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; // First set of full rounds. 
for r in 0..poseidon::HALF_N_FULL_ROUNDS { - ::constant_layer_recursive(builder, &mut state, round_ctr); - for i in 0..SPONGE_WIDTH { - let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; - constraints.push(builder.sub_extension(state[i], sbox_in)); - state[i] = sbox_in; + >::constant_layer_recursive(builder, &mut state, round_ctr); + if r != 0 { + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(builder.sub_extension(state[i], sbox_in)); + state[i] = sbox_in; + } } - ::sbox_layer_recursive(builder, &mut state); - state = ::mds_layer_recursive(builder, &state); + >::sbox_layer_recursive(builder, &mut state); + state = >::mds_layer_recursive(builder, &state); round_ctr += 1; } // Partial rounds. if use_mds_gate { for r in 0..poseidon::N_PARTIAL_ROUNDS { - ::constant_layer_recursive(builder, &mut state, round_ctr); + >::constant_layer_recursive(builder, &mut state, round_ctr); let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; constraints.push(builder.sub_extension(state[0], sbox_in)); - state[0] = ::sbox_monomial_recursive(builder, sbox_in); - state = ::mds_layer_recursive(builder, &state); + state[0] = >::sbox_monomial_recursive(builder, sbox_in); + state = >::mds_layer_recursive(builder, &state); round_ctr += 1; } } else { - ::partial_first_constant_layer_recursive(builder, &mut state); - state = ::mds_partial_layer_init_recursive(builder, &mut state); + >::partial_first_constant_layer_recursive(builder, &mut state); + state = >::mds_partial_layer_init_recursive(builder, &state); for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; constraints.push(builder.sub_extension(state[0], sbox_in)); - state[0] = ::sbox_monomial_recursive(builder, sbox_in); - state[0] = builder.add_const_extension( - state[0], - F::from_canonical_u64(::FAST_PARTIAL_ROUND_CONSTANTS[r]), - ); - state = ::mds_partial_layer_fast_recursive(builder, &state, r); + state[0] 
= >::sbox_monomial_recursive(builder, sbox_in); + let c = >::FAST_PARTIAL_ROUND_CONSTANTS[r]; + let c = F::Extension::from_canonical_u64(c); + let c = builder.constant_extension(c); + state[0] = builder.add_extension(state[0], c); + state = + >::mds_partial_layer_fast_recursive(builder, &state, r); } let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)]; constraints.push(builder.sub_extension(state[0], sbox_in)); - state[0] = ::sbox_monomial_recursive(builder, sbox_in); - state = ::mds_partial_layer_fast_recursive( + state[0] = >::sbox_monomial_recursive(builder, sbox_in); + state = >::mds_partial_layer_fast_recursive( builder, &state, poseidon::N_PARTIAL_ROUNDS - 1, @@ -326,18 +361,18 @@ impl, const D: usize> Gate for PoseidonGate { // Second set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { - ::constant_layer_recursive(builder, &mut state, round_ctr); - for i in 0..SPONGE_WIDTH { + >::constant_layer_recursive(builder, &mut state, round_ctr); + for i in 0..WIDTH { let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; constraints.push(builder.sub_extension(state[i], sbox_in)); state[i] = sbox_in; } - ::sbox_layer_recursive(builder, &mut state); - state = ::mds_layer_recursive(builder, &state); + >::sbox_layer_recursive(builder, &mut state); + state = >::mds_layer_recursive(builder, &state); round_ctr += 1; } - for i in 0..SPONGE_WIDTH { + for i in 0..WIDTH { constraints .push(builder.sub_extension(state[i], vars.local_wires[Self::wire_output(i)])); } @@ -350,7 +385,7 @@ impl, const D: usize> Gate for PoseidonGate { gate_index: usize, _local_constants: &[F], ) -> Vec>> { - let gen = PoseidonGenerator:: { + let gen = PoseidonGenerator:: { gate_index, _phantom: PhantomData, }; @@ -370,23 +405,31 @@ impl, const D: usize> Gate for PoseidonGate { } fn num_constraints(&self) -> usize { - SPONGE_WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + SPONGE_WIDTH + 1 + WIDTH * 
(poseidon::N_FULL_ROUNDS_TOTAL - 1) + poseidon::N_PARTIAL_ROUNDS + WIDTH + 1 + 4 } } #[derive(Debug)] -struct PoseidonGenerator + Poseidon, const D: usize> { +struct PoseidonGenerator< + F: RichField + Extendable + Poseidon, + const D: usize, + const WIDTH: usize, +> where + [(); WIDTH - 1]:, +{ gate_index: usize, _phantom: PhantomData, } -impl + Poseidon, const D: usize> SimpleGenerator - for PoseidonGenerator +impl + Poseidon, const D: usize, const WIDTH: usize> + SimpleGenerator for PoseidonGenerator +where + [(); WIDTH - 1]:, { fn dependencies(&self) -> Vec { - (0..SPONGE_WIDTH) - .map(|i| PoseidonGate::::wire_input(i)) - .chain(Some(PoseidonGate::::WIRE_SWAP)) + (0..WIDTH) + .map(|i| PoseidonGate::::wire_input(i)) + .chain(Some(PoseidonGate::::WIRE_SWAP)) .map(|input| Target::wire(self.gate_index, input)) .collect() } @@ -397,87 +440,94 @@ impl + Poseidon, const D: usize> SimpleGenerator input, }; - let mut state = (0..SPONGE_WIDTH) - .map(|i| { - witness.get_wire(Wire { - gate: self.gate_index, - input: PoseidonGate::::wire_input(i), - }) - }) + let mut state = (0..WIDTH) + .map(|i| witness.get_wire(local_wire(PoseidonGate::::wire_input(i)))) .collect::>(); - let swap_value = witness.get_wire(Wire { - gate: self.gate_index, - input: PoseidonGate::::WIRE_SWAP, - }); + let swap_value = witness.get_wire(local_wire(PoseidonGate::::WIRE_SWAP)); debug_assert!(swap_value == F::ZERO || swap_value == F::ONE); + + for i in 0..4 { + let delta_i = swap_value * (state[i + 4] - state[i]); + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_delta(i)), + delta_i, + ); + } + if swap_value == F::ONE { for i in 0..4 { state.swap(i, 4 + i); } } - let mut state: [F; SPONGE_WIDTH] = state.try_into().unwrap(); + let mut state: [F; WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; for r in 0..poseidon::HALF_N_FULL_ROUNDS { - ::constant_layer_field(&mut state, round_ctr); - for i in 0..SPONGE_WIDTH { - out_buffer.set_wire( - 
local_wire(PoseidonGate::::wire_full_sbox_0(r, i)), - state[i], - ); + >::constant_layer_field(&mut state, round_ctr); + if r != 0 { + for i in 0..WIDTH { + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_full_sbox_0(r, i)), + state[i], + ); + } } - ::sbox_layer_field(&mut state); - state = ::mds_layer_field(&state); + >::sbox_layer_field(&mut state); + state = >::mds_layer_field(&state); round_ctr += 1; } - ::partial_first_constant_layer(&mut state); - state = ::mds_partial_layer_init(&mut state); + >::partial_first_constant_layer(&mut state); + state = >::mds_partial_layer_init(&state); for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { out_buffer.set_wire( - local_wire(PoseidonGate::::wire_partial_sbox(r)), + local_wire(PoseidonGate::::wire_partial_sbox(r)), state[0], ); - state[0] = ::sbox_monomial(state[0]); - state[0] += F::from_canonical_u64(::FAST_PARTIAL_ROUND_CONSTANTS[r]); - state = ::mds_partial_layer_fast_field(&state, r); + state[0] = >::sbox_monomial(state[0]); + state[0] += + F::from_canonical_u64(>::FAST_PARTIAL_ROUND_CONSTANTS[r]); + state = >::mds_partial_layer_fast_field(&state, r); } out_buffer.set_wire( - local_wire(PoseidonGate::::wire_partial_sbox( + local_wire(PoseidonGate::::wire_partial_sbox( poseidon::N_PARTIAL_ROUNDS - 1, )), state[0], ); - state[0] = ::sbox_monomial(state[0]); - state = - ::mds_partial_layer_fast_field(&state, poseidon::N_PARTIAL_ROUNDS - 1); + state[0] = >::sbox_monomial(state[0]); + state = >::mds_partial_layer_fast_field( + &state, + poseidon::N_PARTIAL_ROUNDS - 1, + ); round_ctr += poseidon::N_PARTIAL_ROUNDS; for r in 0..poseidon::HALF_N_FULL_ROUNDS { - ::constant_layer_field(&mut state, round_ctr); - for i in 0..SPONGE_WIDTH { + >::constant_layer_field(&mut state, round_ctr); + for i in 0..WIDTH { out_buffer.set_wire( - local_wire(PoseidonGate::::wire_full_sbox_1(r, i)), + local_wire(PoseidonGate::::wire_full_sbox_1(r, i)), state[i], ); } - ::sbox_layer_field(&mut state); - state = ::mds_layer_field(&state); + 
>::sbox_layer_field(&mut state); + state = >::mds_layer_field(&state); round_ctr += 1; } - for i in 0..SPONGE_WIDTH { - out_buffer.set_wire(local_wire(PoseidonGate::::wire_output(i)), state[i]); + for i in 0..WIDTH { + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_output(i)), + state[i], + ); } } } #[cfg(test)] mod tests { - use std::convert::TryInto; - use anyhow::Result; use crate::field::field_types::Field; @@ -493,6 +543,29 @@ mod tests { use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + #[test] + fn wire_indices() { + type F = GoldilocksField; + const WIDTH: usize = 12; + type Gate = PoseidonGate; + + assert_eq!(Gate::wire_input(0), 0); + assert_eq!(Gate::wire_input(11), 11); + assert_eq!(Gate::wire_output(0), 12); + assert_eq!(Gate::wire_output(11), 23); + assert_eq!(Gate::WIRE_SWAP, 24); + assert_eq!(Gate::wire_delta(0), 25); + assert_eq!(Gate::wire_delta(3), 28); + assert_eq!(Gate::wire_full_sbox_0(1, 0), 29); + assert_eq!(Gate::wire_full_sbox_0(3, 0), 53); + assert_eq!(Gate::wire_full_sbox_0(3, 11), 64); + assert_eq!(Gate::wire_partial_sbox(0), 65); + assert_eq!(Gate::wire_partial_sbox(21), 86); + assert_eq!(Gate::wire_full_sbox_1(0, 0), 87); + assert_eq!(Gate::wire_full_sbox_1(3, 0), 123); + assert_eq!(Gate::wire_full_sbox_1(3, 11), 134); + } + #[test] fn generated_output() { const D: usize = 2; diff --git a/src/gates/poseidon_mds.rs b/src/gates/poseidon_mds.rs index 32fb1da1..1abbe71f 100644 --- a/src/gates/poseidon_mds.rs +++ b/src/gates/poseidon_mds.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::marker::PhantomData; use std::ops::Range; @@ -6,9 +5,8 @@ use crate::field::extension_field::algebra::ExtensionAlgebra; use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::field::extension_field::Extendable; use crate::field::extension_field::FieldExtension; -use crate::field::field_types::Field; +use 
crate::field::field_types::{Field, RichField}; use crate::gates::gate::Gate; -use crate::hash::hashing::SPONGE_WIDTH; use crate::hash::poseidon::Poseidon; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; @@ -17,11 +15,21 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; #[derive(Debug)] -pub struct PoseidonMdsGate + Poseidon, const D: usize> { +pub struct PoseidonMdsGate< + F: RichField + Extendable + Poseidon, + const D: usize, + const WIDTH: usize, +> where + [(); WIDTH - 1]:, +{ _phantom: PhantomData, } -impl + Poseidon, const D: usize> PoseidonMdsGate { +impl + Poseidon, const D: usize, const WIDTH: usize> + PoseidonMdsGate +where + [(); WIDTH - 1]:, +{ pub fn new() -> Self { PoseidonMdsGate { _phantom: PhantomData, @@ -29,13 +37,13 @@ impl + Poseidon, const D: usize> PoseidonMdsGate { } pub fn wires_input(i: usize) -> Range { - assert!(i < SPONGE_WIDTH); + assert!(i < WIDTH); i * D..(i + 1) * D } pub fn wires_output(i: usize) -> Range { - assert!(i < SPONGE_WIDTH); - (SPONGE_WIDTH + i) * D..(SPONGE_WIDTH + i + 1) * D + assert!(i < WIDTH); + (WIDTH + i) * D..(WIDTH + i + 1) * D } // Following are methods analogous to ones in `Poseidon`, but for extension algebras. @@ -43,14 +51,15 @@ impl + Poseidon, const D: usize> PoseidonMdsGate { /// Same as `mds_row_shf` for an extension algebra of `F`. 
fn mds_row_shf_algebra( r: usize, - v: &[ExtensionAlgebra; SPONGE_WIDTH], + v: &[ExtensionAlgebra; WIDTH], ) -> ExtensionAlgebra { - debug_assert!(r < SPONGE_WIDTH); + debug_assert!(r < WIDTH); let mut res = ExtensionAlgebra::ZERO; - for i in 0..SPONGE_WIDTH { - let coeff = F::Extension::from_canonical_u64(1 << ::MDS_MATRIX_EXPS[i]); - res += v[(i + r) % SPONGE_WIDTH].scalar_mul(coeff); + for i in 0..WIDTH { + let coeff = + F::Extension::from_canonical_u64(1 << >::MDS_MATRIX_EXPS[i]); + res += v[(i + r) % WIDTH].scalar_mul(coeff); } res @@ -60,16 +69,16 @@ impl + Poseidon, const D: usize> PoseidonMdsGate { fn mds_row_shf_algebra_recursive( builder: &mut CircuitBuilder, r: usize, - v: &[ExtensionAlgebraTarget; SPONGE_WIDTH], + v: &[ExtensionAlgebraTarget; WIDTH], ) -> ExtensionAlgebraTarget { - debug_assert!(r < SPONGE_WIDTH); + debug_assert!(r < WIDTH); let mut res = builder.zero_ext_algebra(); - for i in 0..SPONGE_WIDTH { + for i in 0..WIDTH { let coeff = builder.constant_extension(F::Extension::from_canonical_u64( - 1 << ::MDS_MATRIX_EXPS[i], + 1 << >::MDS_MATRIX_EXPS[i], )); - res = builder.scalar_mul_add_ext_algebra(coeff, v[(i + r) % SPONGE_WIDTH], res); + res = builder.scalar_mul_add_ext_algebra(coeff, v[(i + r) % WIDTH], res); } res @@ -77,11 +86,11 @@ impl + Poseidon, const D: usize> PoseidonMdsGate { /// Same as `mds_layer` for an extension algebra of `F`. fn mds_layer_algebra( - state: &[ExtensionAlgebra; SPONGE_WIDTH], - ) -> [ExtensionAlgebra; SPONGE_WIDTH] { - let mut result = [ExtensionAlgebra::ZERO; SPONGE_WIDTH]; + state: &[ExtensionAlgebra; WIDTH], + ) -> [ExtensionAlgebra; WIDTH] { + let mut result = [ExtensionAlgebra::ZERO; WIDTH]; - for r in 0..SPONGE_WIDTH { + for r in 0..WIDTH { result[r] = Self::mds_row_shf_algebra(r, state); } @@ -91,11 +100,11 @@ impl + Poseidon, const D: usize> PoseidonMdsGate { /// Same as `mds_layer_recursive` for an extension algebra of `F`. 
fn mds_layer_algebra_recursive( builder: &mut CircuitBuilder, - state: &[ExtensionAlgebraTarget; SPONGE_WIDTH], - ) -> [ExtensionAlgebraTarget; SPONGE_WIDTH] { - let mut result = [builder.zero_ext_algebra(); SPONGE_WIDTH]; + state: &[ExtensionAlgebraTarget; WIDTH], + ) -> [ExtensionAlgebraTarget; WIDTH] { + let mut result = [builder.zero_ext_algebra(); WIDTH]; - for r in 0..SPONGE_WIDTH { + for r in 0..WIDTH { result[r] = Self::mds_row_shf_algebra_recursive(builder, r, state); } @@ -103,13 +112,17 @@ impl + Poseidon, const D: usize> PoseidonMdsGate { } } -impl + Poseidon, const D: usize> Gate for PoseidonMdsGate { +impl + Poseidon, const D: usize, const WIDTH: usize> Gate + for PoseidonMdsGate +where + [(); WIDTH - 1]:, +{ fn id(&self) -> String { - format!("{:?}", self, SPONGE_WIDTH) + format!("{:?}", self, WIDTH) } fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH) + let inputs: [_; WIDTH] = (0..WIDTH) .map(|i| vars.get_local_ext_algebra(Self::wires_input(i))) .collect::>() .try_into() @@ -117,7 +130,7 @@ impl + Poseidon, const D: usize> Gate for PoseidonMdsGate let computed_outputs = Self::mds_layer_algebra(&inputs); - (0..SPONGE_WIDTH) + (0..WIDTH) .map(|i| vars.get_local_ext_algebra(Self::wires_output(i))) .zip(computed_outputs) .flat_map(|(out, computed_out)| (out - computed_out).to_basefield_array()) @@ -125,7 +138,7 @@ impl + Poseidon, const D: usize> Gate for PoseidonMdsGate } fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { - let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH) + let inputs: [_; WIDTH] = (0..WIDTH) .map(|i| vars.get_local_ext(Self::wires_input(i))) .collect::>() .try_into() @@ -133,7 +146,7 @@ impl + Poseidon, const D: usize> Gate for PoseidonMdsGate let computed_outputs = F::mds_layer_field(&inputs); - (0..SPONGE_WIDTH) + (0..WIDTH) .map(|i| vars.get_local_ext(Self::wires_output(i))) .zip(computed_outputs) .flat_map(|(out, computed_out)| (out - 
computed_out).to_basefield_array()) @@ -145,7 +158,7 @@ impl + Poseidon, const D: usize> Gate for PoseidonMdsGate builder: &mut CircuitBuilder, vars: EvaluationTargets, ) -> Vec> { - let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH) + let inputs: [_; WIDTH] = (0..WIDTH) .map(|i| vars.get_local_ext_algebra(Self::wires_input(i))) .collect::>() .try_into() @@ -153,7 +166,7 @@ impl + Poseidon, const D: usize> Gate for PoseidonMdsGate let computed_outputs = Self::mds_layer_algebra_recursive(builder, &inputs); - (0..SPONGE_WIDTH) + (0..WIDTH) .map(|i| vars.get_local_ext_algebra(Self::wires_output(i))) .zip(computed_outputs) .flat_map(|(out, computed_out)| { @@ -169,12 +182,12 @@ impl + Poseidon, const D: usize> Gate for PoseidonMdsGate gate_index: usize, _local_constants: &[F], ) -> Vec>> { - let gen = PoseidonMdsGenerator:: { gate_index }; + let gen = PoseidonMdsGenerator:: { gate_index }; vec![Box::new(gen.adapter())] } fn num_wires(&self) -> usize { - 2 * D * SPONGE_WIDTH + 2 * D * WIDTH } fn num_constants(&self) -> usize { @@ -186,20 +199,30 @@ impl + Poseidon, const D: usize> Gate for PoseidonMdsGate } fn num_constraints(&self) -> usize { - SPONGE_WIDTH * D + WIDTH * D } } #[derive(Clone, Debug)] -struct PoseidonMdsGenerator { +struct PoseidonMdsGenerator +where + [(); WIDTH - 1]:, +{ gate_index: usize, } -impl + Poseidon, const D: usize> SimpleGenerator for PoseidonMdsGenerator { +impl + Poseidon, const D: usize, const WIDTH: usize> + SimpleGenerator for PoseidonMdsGenerator +where + [(); WIDTH - 1]:, +{ fn dependencies(&self) -> Vec { - (0..SPONGE_WIDTH) + (0..WIDTH) .flat_map(|i| { - Target::wires_from_range(self.gate_index, PoseidonMdsGate::::wires_input(i)) + Target::wires_from_range( + self.gate_index, + PoseidonMdsGate::::wires_input(i), + ) }) .collect() } @@ -210,8 +233,8 @@ impl + Poseidon, const D: usize> SimpleGenerator for Poseido let get_local_ext = |wire_range| witness.get_extension_target(get_local_get_target(wire_range)); - let inputs: [_; 
SPONGE_WIDTH] = (0..SPONGE_WIDTH) - .map(|i| get_local_ext(PoseidonMdsGate::::wires_input(i))) + let inputs: [_; WIDTH] = (0..WIDTH) + .map(|i| get_local_ext(PoseidonMdsGate::::wires_input(i))) .collect::>() .try_into() .unwrap(); @@ -220,7 +243,7 @@ impl + Poseidon, const D: usize> SimpleGenerator for Poseido for (i, &out) in outputs.iter().enumerate() { out_buffer.set_extension_target( - get_local_get_target(PoseidonMdsGate::::wires_output(i)), + get_local_get_target(PoseidonMdsGate::::wires_output(i)), out, ); } @@ -232,21 +255,19 @@ mod tests { use crate::field::goldilocks_field::GoldilocksField; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::gates::poseidon_mds::PoseidonMdsGate; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use crate::hash::hashing::SPONGE_WIDTH; #[test] fn low_degree() { type F = GoldilocksField; - let gate = PoseidonMdsGate::::new(); + let gate = PoseidonMdsGate::::new(); test_low_degree(gate) } #[test] fn eval_fns() -> anyhow::Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - let gate = PoseidonMdsGate::::new(); - test_eval_fns::(gate) + type F = GoldilocksField; + let gate = PoseidonMdsGate::::new(); + test_eval_fns(gate) } } diff --git a/src/gates/random_access.rs b/src/gates/random_access.rs index d29cbed6..41a14288 100644 --- a/src/gates/random_access.rs +++ b/src/gates/random_access.rs @@ -1,5 +1,7 @@ use std::marker::PhantomData; +use itertools::Itertools; + use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::Field; @@ -14,76 +16,65 @@ use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; /// A gate for checking that a particular element of a list matches a given value. 
#[derive(Copy, Clone, Debug)] -pub(crate) struct RandomAccessGate, const D: usize> { - pub vec_size: usize, +pub(crate) struct RandomAccessGate, const D: usize> { + pub bits: usize, pub num_copies: usize, _phantom: PhantomData, } -impl, const D: usize> RandomAccessGate { - pub fn new(num_copies: usize, vec_size: usize) -> Self { +impl, const D: usize> RandomAccessGate { + fn new(num_copies: usize, bits: usize) -> Self { Self { - vec_size, + bits, num_copies, _phantom: PhantomData, } } - pub fn new_from_config(config: &CircuitConfig, vec_size: usize) -> Self { - let num_copies = Self::max_num_copies(config.num_routed_wires, config.num_wires, vec_size); - Self::new(num_copies, vec_size) + pub fn new_from_config(config: &CircuitConfig, bits: usize) -> Self { + let vec_size = 1 << bits; + // Need `(2 + vec_size) * num_copies` routed wires + let max_copies = (config.num_routed_wires / (2 + vec_size)).min( + // Need `(2 + vec_size + bits) * num_copies` wires + config.num_wires / (2 + vec_size + bits), + ); + Self::new(max_copies, bits) } - pub fn max_num_copies(num_routed_wires: usize, num_wires: usize, vec_size: usize) -> usize { - // Need `(2 + vec_size) * num_copies` routed wires - (num_routed_wires / (2 + vec_size)).min( - // Need `(2 + 3*vec_size) * num_copies` wires - num_wires / (2 + 3 * vec_size), - ) + fn vec_size(&self) -> usize { + 1 << self.bits } pub fn wire_access_index(&self, copy: usize) -> usize { debug_assert!(copy < self.num_copies); - (2 + self.vec_size) * copy + (2 + self.vec_size()) * copy } pub fn wire_claimed_element(&self, copy: usize) -> usize { debug_assert!(copy < self.num_copies); - (2 + self.vec_size) * copy + 1 + (2 + self.vec_size()) * copy + 1 } pub fn wire_list_item(&self, i: usize, copy: usize) -> usize { - debug_assert!(i < self.vec_size); + debug_assert!(i < self.vec_size()); debug_assert!(copy < self.num_copies); - (2 + self.vec_size) * copy + 2 + i + (2 + self.vec_size()) * copy + 2 + i } fn start_of_intermediate_wires(&self) -> 
usize { - (2 + self.vec_size) * self.num_copies + (2 + self.vec_size()) * self.num_copies } pub(crate) fn num_routed_wires(&self) -> usize { self.start_of_intermediate_wires() } - /// An intermediate wire for a dummy variable used to show equality. - /// The prover sets this to 1/(x-y) if x != y, or to an arbitrary value if - /// x == y. - pub fn wire_equality_dummy_for_index(&self, i: usize, copy: usize) -> usize { - debug_assert!(i < self.vec_size); + /// An intermediate wire where the prover gives the (purported) binary decomposition of the + /// index. + pub fn wire_bit(&self, i: usize, copy: usize) -> usize { + debug_assert!(i < self.bits); debug_assert!(copy < self.num_copies); - self.start_of_intermediate_wires() + copy * self.vec_size + i - } - - /// An intermediate wire for the "index_matches" variable (1 if the current index is the index at - /// which to compare, 0 otherwise). - pub fn wire_index_matches_for_index(&self, i: usize, copy: usize) -> usize { - debug_assert!(i < self.vec_size); - debug_assert!(copy < self.num_copies); - self.start_of_intermediate_wires() - + self.vec_size * self.num_copies - + self.vec_size * copy - + i + self.start_of_intermediate_wires() + copy * self.bits + i } } @@ -97,23 +88,38 @@ impl, const D: usize> Gate for RandomAccessGate { for copy in 0..self.num_copies { let access_index = vars.local_wires[self.wire_access_index(copy)]; - let list_items = (0..self.vec_size) + let mut list_items = (0..self.vec_size()) .map(|i| vars.local_wires[self.wire_list_item(i, copy)]) .collect::>(); let claimed_element = vars.local_wires[self.wire_claimed_element(copy)]; + let bits = (0..self.bits) + .map(|i| vars.local_wires[self.wire_bit(i, copy)]) + .collect::>(); - for i in 0..self.vec_size { - let cur_index = F::Extension::from_canonical_usize(i); - let difference = cur_index - access_index; - let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i, copy)]; - let index_matches = 
vars.local_wires[self.wire_index_matches_for_index(i, copy)]; - - // The two index equality constraints. - constraints.push(difference * equality_dummy - (F::Extension::ONE - index_matches)); - constraints.push(index_matches * difference); - // Value equality constraint. - constraints.push((list_items[i] - claimed_element) * index_matches); + // Assert that each bit wire value is indeed boolean. + for &b in &bits { + constraints.push(b * (b - F::Extension::ONE)); } + + // Assert that the binary decomposition was correct. + let reconstructed_index = bits + .iter() + .rev() + .fold(F::Extension::ZERO, |acc, &b| acc.double() + b); + constraints.push(reconstructed_index - access_index); + + // Repeatedly fold the list, selecting the left or right item from each pair based on + // the corresponding bit. + for b in bits { + list_items = list_items + .iter() + .tuples() + .map(|(&x, &y)| x + b * (y - x)) + .collect() + } + + debug_assert_eq!(list_items.len(), 1); + constraints.push(list_items[0] - claimed_element); } constraints @@ -124,23 +130,35 @@ impl, const D: usize> Gate for RandomAccessGate { for copy in 0..self.num_copies { let access_index = vars.local_wires[self.wire_access_index(copy)]; - let list_items = (0..self.vec_size) + let mut list_items = (0..self.vec_size()) .map(|i| vars.local_wires[self.wire_list_item(i, copy)]) .collect::>(); let claimed_element = vars.local_wires[self.wire_claimed_element(copy)]; + let bits = (0..self.bits) + .map(|i| vars.local_wires[self.wire_bit(i, copy)]) + .collect::>(); - for i in 0..self.vec_size { - let cur_index = F::from_canonical_usize(i); - let difference = cur_index - access_index; - let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i, copy)]; - let index_matches = vars.local_wires[self.wire_index_matches_for_index(i, copy)]; - - // The two index equality constraints. 
- constraints.push(difference * equality_dummy - (F::ONE - index_matches)); - constraints.push(index_matches * difference); - // Value equality constraint. - constraints.push((list_items[i] - claimed_element) * index_matches); + // Assert that each bit wire value is indeed boolean. + for &b in &bits { + constraints.push(b * (b - F::ONE)); } + + // Assert that the binary decomposition was correct. + let reconstructed_index = bits.iter().rev().fold(F::ZERO, |acc, &b| acc.double() + b); + constraints.push(reconstructed_index - access_index); + + // Repeatedly fold the list, selecting the left or right item from each pair based on + // the corresponding bit. + for b in bits { + list_items = list_items + .iter() + .tuples() + .map(|(&x, &y)| x + b * (y - x)) + .collect() + } + + debug_assert_eq!(list_items.len(), 1); + constraints.push(list_items[0] - claimed_element); } constraints @@ -151,36 +169,44 @@ impl, const D: usize> Gate for RandomAccessGate { builder: &mut CircuitBuilder, vars: EvaluationTargets, ) -> Vec> { + let zero = builder.zero_extension(); + let two = builder.two_extension(); let mut constraints = Vec::with_capacity(self.num_constraints()); for copy in 0..self.num_copies { let access_index = vars.local_wires[self.wire_access_index(copy)]; - let list_items = (0..self.vec_size) + let mut list_items = (0..self.vec_size()) .map(|i| vars.local_wires[self.wire_list_item(i, copy)]) .collect::>(); let claimed_element = vars.local_wires[self.wire_claimed_element(copy)]; + let bits = (0..self.bits) + .map(|i| vars.local_wires[self.wire_bit(i, copy)]) + .collect::>(); - for i in 0..self.vec_size { - let cur_index_ext = F::Extension::from_canonical_usize(i); - let cur_index = builder.constant_extension(cur_index_ext); - let difference = builder.sub_extension(cur_index, access_index); - let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i, copy)]; - let index_matches = vars.local_wires[self.wire_index_matches_for_index(i, copy)]; - - let one = 
builder.one_extension(); - let not_index_matches = builder.sub_extension(one, index_matches); - let first_equality_constraint = - builder.mul_sub_extension(difference, equality_dummy, not_index_matches); - constraints.push(first_equality_constraint); - - let second_equality_constraint = builder.mul_extension(index_matches, difference); - constraints.push(second_equality_constraint); - - // Output constraint. - let diff = builder.sub_extension(list_items[i], claimed_element); - let conditional_diff = builder.mul_extension(index_matches, diff); - constraints.push(conditional_diff); + // Assert that each bit wire value is indeed boolean. + for &b in &bits { + constraints.push(builder.mul_sub_extension(b, b, b)); } + + // Assert that the binary decomposition was correct. + let reconstructed_index = bits + .iter() + .rev() + .fold(zero, |acc, &b| builder.mul_add_extension(acc, two, b)); + constraints.push(builder.sub_extension(reconstructed_index, access_index)); + + // Repeatedly fold the list, selecting the left or right item from each pair based on + // the corresponding bit. 
+ for b in bits { + list_items = list_items + .iter() + .tuples() + .map(|(&x, &y)| builder.select_ext_generalized(b, y, x)) + .collect() + } + + debug_assert_eq!(list_items.len(), 1); + constraints.push(builder.sub_extension(list_items[0], claimed_element)); } constraints @@ -207,7 +233,7 @@ impl, const D: usize> Gate for RandomAccessGate { } fn num_wires(&self) -> usize { - self.wire_index_matches_for_index(self.vec_size - 1, self.num_copies - 1) + 1 + self.wire_bit(self.bits - 1, self.num_copies - 1) + 1 } fn num_constants(&self) -> usize { @@ -215,11 +241,12 @@ impl, const D: usize> Gate for RandomAccessGate { } fn degree(&self) -> usize { - 2 + self.bits + 1 } fn num_constraints(&self) -> usize { - 3 * self.num_copies * self.vec_size + let constraints_per_copy = self.bits + 2; + self.num_copies * constraints_per_copy } } @@ -234,10 +261,8 @@ impl, const D: usize> SimpleGenerator for RandomAccessGenera fn dependencies(&self) -> Vec { let local_target = |input| Target::wire(self.gate_index, input); - let mut deps = Vec::new(); - deps.push(local_target(self.gate.wire_access_index(self.copy))); - deps.push(local_target(self.gate.wire_claimed_element(self.copy))); - for i in 0..self.gate.vec_size { + let mut deps = vec![local_target(self.gate.wire_access_index(self.copy))]; + for i in 0..self.gate.vec_size() { deps.push(local_target(self.gate.wire_list_item(i, self.copy))); } deps @@ -250,11 +275,12 @@ impl, const D: usize> SimpleGenerator for RandomAccessGenera }; let get_local_wire = |input| witness.get_wire(local_wire(input)); + let mut set_local_wire = |input, value| out_buffer.set_wire(local_wire(input), value); - // Compute the new vector and the values for equality_dummy and index_matches - let vec_size = self.gate.vec_size; - let access_index_f = get_local_wire(self.gate.wire_access_index(self.copy)); + let copy = self.copy; + let vec_size = self.gate.vec_size(); + let access_index_f = get_local_wire(self.gate.wire_access_index(copy)); let access_index = 
access_index_f.to_canonical_u64() as usize; debug_assert!( access_index < vec_size, @@ -263,22 +289,14 @@ impl, const D: usize> SimpleGenerator for RandomAccessGenera vec_size ); - for i in 0..vec_size { - let equality_dummy_wire = - local_wire(self.gate.wire_equality_dummy_for_index(i, self.copy)); - let index_matches_wire = - local_wire(self.gate.wire_index_matches_for_index(i, self.copy)); + set_local_wire( + self.gate.wire_claimed_element(copy), + get_local_wire(self.gate.wire_list_item(access_index, copy)), + ); - if i == access_index { - out_buffer.set_wire(equality_dummy_wire, F::ONE); - out_buffer.set_wire(index_matches_wire, F::ONE); - } else { - out_buffer.set_wire( - equality_dummy_wire, - (F::from_canonical_usize(i) - F::from_canonical_usize(access_index)).inverse(), - ); - out_buffer.set_wire(index_matches_wire, F::ZERO); - } + for i in 0..self.gate.bits { + let bit = F::from_bool(((access_index >> i) & 1) != 0); + set_local_wire(self.gate.wire_bit(i, copy), bit); } } } @@ -322,6 +340,7 @@ mod tests { /// Returns the local wires for a random access gate given the vectors, elements to compare, /// and indices. 
fn get_wires( + bits: usize, lists: Vec>, access_indices: Vec, claimed_elements: Vec, @@ -330,8 +349,7 @@ mod tests { let vec_size = lists[0].len(); let mut v = Vec::new(); - let mut equality_dummy_vals = Vec::new(); - let mut index_matches_vals = Vec::new(); + let mut bit_vals = Vec::new(); for copy in 0..num_copies { let access_index = access_indices[copy]; v.push(F::from_canonical_usize(access_index)); @@ -340,26 +358,17 @@ mod tests { v.push(lists[copy][j]); } - for i in 0..vec_size { - if i == access_index { - equality_dummy_vals.push(F::ONE); - index_matches_vals.push(F::ONE); - } else { - equality_dummy_vals.push( - (F::from_canonical_usize(i) - F::from_canonical_usize(access_index)) - .inverse(), - ); - index_matches_vals.push(F::ZERO); - } + for i in 0..bits { + bit_vals.push(F::from_bool(((access_index >> i) & 1) != 0)); } } - v.extend(equality_dummy_vals); - v.extend(index_matches_vals); + v.extend(bit_vals); - v.iter().map(|&x| x.into()).collect::>() + v.iter().map(|&x| x.into()).collect() } - let vec_size = 3; + let bits = 3; + let vec_size = 1 << bits; let num_copies = 4; let lists = (0..num_copies) .map(|_| F::rand_vec(vec_size)) @@ -368,7 +377,7 @@ mod tests { .map(|_| thread_rng().gen_range(0..vec_size)) .collect::>(); let gate = RandomAccessGate:: { - vec_size, + bits, num_copies, _phantom: PhantomData, }; @@ -380,13 +389,18 @@ mod tests { .collect(); let good_vars = EvaluationVars { local_constants: &[], - local_wires: &get_wires(lists.clone(), access_indices.clone(), good_claimed_elements), + local_wires: &get_wires( + bits, + lists.clone(), + access_indices.clone(), + good_claimed_elements, + ), public_inputs_hash: &HashOut::rand(), }; let bad_claimed_elements = F::rand_vec(4); let bad_vars = EvaluationVars { local_constants: &[], - local_wires: &get_wires(lists, access_indices, bad_claimed_elements), + local_wires: &get_wires(bits, lists, access_indices, bad_claimed_elements), public_inputs_hash: &HashOut::rand(), }; diff --git 
a/src/gates/reducing_extension.rs b/src/gates/reducing_extension.rs new file mode 100644 index 00000000..532b484f --- /dev/null +++ b/src/gates/reducing_extension.rs @@ -0,0 +1,222 @@ +use std::ops::Range; + +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::extension_field::FieldExtension; +use crate::field::field_types::RichField; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// Computes `sum alpha^i c_i` for a vector `c_i` of `num_coeffs` elements of the extension field. +#[derive(Debug, Clone)] +pub struct ReducingExtensionGate { + pub num_coeffs: usize, +} + +impl ReducingExtensionGate { + pub fn new(num_coeffs: usize) -> Self { + Self { num_coeffs } + } + + pub fn max_coeffs_len(num_wires: usize, num_routed_wires: usize) -> usize { + // `3*D` routed wires are used for the output, alpha and old accumulator. + // Need `num_coeffs*D` routed wires for coeffs, and `(num_coeffs-1)*D` wires for accumulators. + ((num_routed_wires - 3 * D) / D).min((num_wires - 2 * D) / (D * 2)) + } + + pub fn wires_output() -> Range { + 0..D + } + pub fn wires_alpha() -> Range { + D..2 * D + } + pub fn wires_old_acc() -> Range { + 2 * D..3 * D + } + const START_COEFFS: usize = 3 * D; + pub fn wires_coeff(i: usize) -> Range { + Self::START_COEFFS + i * D..Self::START_COEFFS + (i + 1) * D + } + fn start_accs(&self) -> usize { + Self::START_COEFFS + self.num_coeffs * D + } + fn wires_accs(&self, i: usize) -> Range { + debug_assert!(i < self.num_coeffs); + if i == self.num_coeffs - 1 { + // The last accumulator is the output. 
+ return Self::wires_output(); + } + self.start_accs() + D * i..self.start_accs() + D * (i + 1) + } +} + +impl, const D: usize> Gate for ReducingExtensionGate { + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let alpha = vars.get_local_ext_algebra(Self::wires_alpha()); + let old_acc = vars.get_local_ext_algebra(Self::wires_old_acc()); + let coeffs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext_algebra(Self::wires_coeff(i))) + .collect::>(); + let accs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext_algebra(self.wires_accs(i))) + .collect::>(); + + let mut constraints = Vec::with_capacity(>::num_constraints(self)); + let mut acc = old_acc; + for i in 0..self.num_coeffs { + constraints.push(acc * alpha + coeffs[i] - accs[i]); + acc = accs[i]; + } + + constraints + .into_iter() + .flat_map(|alg| alg.to_basefield_array()) + .collect() + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let alpha = vars.get_local_ext(Self::wires_alpha()); + let old_acc = vars.get_local_ext(Self::wires_old_acc()); + let coeffs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext(Self::wires_coeff(i))) + .collect::>(); + let accs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext(self.wires_accs(i))) + .collect::>(); + + let mut constraints = Vec::with_capacity(>::num_constraints(self)); + let mut acc = old_acc; + for i in 0..self.num_coeffs { + constraints.extend((acc * alpha + coeffs[i] - accs[i]).to_basefield_array()); + acc = accs[i]; + } + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let alpha = vars.get_local_ext_algebra(Self::wires_alpha()); + let old_acc = vars.get_local_ext_algebra(Self::wires_old_acc()); + let coeffs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext_algebra(Self::wires_coeff(i))) + .collect::>(); + let accs = (0..self.num_coeffs) + .map(|i| 
vars.get_local_ext_algebra(self.wires_accs(i))) + .collect::>(); + + let mut constraints = Vec::with_capacity(>::num_constraints(self)); + let mut acc = old_acc; + for i in 0..self.num_coeffs { + let coeff = coeffs[i]; + let mut tmp = builder.mul_add_ext_algebra(acc, alpha, coeff); + tmp = builder.sub_ext_algebra(tmp, accs[i]); + constraints.push(tmp); + acc = accs[i]; + } + + constraints + .into_iter() + .flat_map(|alg| alg.to_ext_target_array()) + .collect() + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + vec![Box::new( + ReducingGenerator { + gate_index, + gate: self.clone(), + } + .adapter(), + )] + } + + fn num_wires(&self) -> usize { + 2 * D + 2 * D * self.num_coeffs + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 2 + } + + fn num_constraints(&self) -> usize { + D * self.num_coeffs + } +} + +#[derive(Debug)] +struct ReducingGenerator { + gate_index: usize, + gate: ReducingExtensionGate, +} + +impl, const D: usize> SimpleGenerator for ReducingGenerator { + fn dependencies(&self) -> Vec { + ReducingExtensionGate::::wires_alpha() + .chain(ReducingExtensionGate::::wires_old_acc()) + .chain((0..self.gate.num_coeffs).flat_map(ReducingExtensionGate::::wires_coeff)) + .map(|i| Target::wire(self.gate_index, i)) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_extension = |range: Range| -> F::Extension { + let t = ExtensionTarget::from_range(self.gate_index, range); + witness.get_extension_target(t) + }; + + let alpha = local_extension(ReducingExtensionGate::::wires_alpha()); + let old_acc = local_extension(ReducingExtensionGate::::wires_old_acc()); + let coeffs = (0..self.gate.num_coeffs) + .map(|i| local_extension(ReducingExtensionGate::::wires_coeff(i))) + .collect::>(); + let accs = (0..self.gate.num_coeffs) + .map(|i| ExtensionTarget::from_range(self.gate_index, self.gate.wires_accs(i))) + .collect::>(); + + let 
mut acc = old_acc; + for i in 0..self.gate.num_coeffs { + let computed_acc = acc * alpha + coeffs[i]; + out_buffer.set_extension_target(accs[i], computed_acc); + acc = computed_acc; + } + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::reducing_extension::ReducingExtensionGate; + + #[test] + fn low_degree() { + test_low_degree::(ReducingExtensionGate::new(22)); + } + + #[test] + fn eval_fns() -> Result<()> { + test_eval_fns::(ReducingExtensionGate::new(22)) + } +} diff --git a/src/gates/subtraction_u32.rs b/src/gates/subtraction_u32.rs new file mode 100644 index 00000000..fc4cd646 --- /dev/null +++ b/src/gates/subtraction_u32.rs @@ -0,0 +1,423 @@ +use std::marker::PhantomData; + +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::field_types::{Field, RichField}; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::wire::Wire; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::CircuitConfig; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// A gate to perform a subtraction on 32-bit limbs: given `x`, `y`, and `borrow`, it returns +/// the result `x - y - borrow` and, if this underflows, a new `borrow`. Inputs are not range-checked. 
+#[derive(Copy, Clone, Debug)] +pub struct U32SubtractionGate, const D: usize> { + pub num_ops: usize, + _phantom: PhantomData, +} + +impl, const D: usize> U32SubtractionGate { + pub fn new_from_config(config: &CircuitConfig) -> Self { + Self { + num_ops: Self::num_ops(config), + _phantom: PhantomData, + } + } + + pub(crate) fn num_ops(config: &CircuitConfig) -> usize { + let wires_per_op = 5 + Self::num_limbs(); + let routed_wires_per_op = 5; + (config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op) + } + + pub fn wire_ith_input_x(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + 5 * i + } + pub fn wire_ith_input_y(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + 5 * i + 1 + } + pub fn wire_ith_input_borrow(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + 5 * i + 2 + } + + pub fn wire_ith_output_result(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + 5 * i + 3 + } + pub fn wire_ith_output_borrow(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + 5 * i + 4 + } + + pub fn limb_bits() -> usize { + 2 + } + // We have limbs for the 32 bits of `output_result`. 
+ pub fn num_limbs() -> usize { + 32 / Self::limb_bits() + } + + pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize { + debug_assert!(i < self.num_ops); + debug_assert!(j < Self::num_limbs()); + 5 * self.num_ops + Self::num_limbs() * i + j + } +} + +impl, const D: usize> Gate for U32SubtractionGate { + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + for i in 0..self.num_ops { + let input_x = vars.local_wires[self.wire_ith_input_x(i)]; + let input_y = vars.local_wires[self.wire_ith_input_y(i)]; + let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)]; + + let result_initial = input_x - input_y - input_borrow; + let base = F::Extension::from_canonical_u64(1 << 32u64); + + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)]; + + constraints.push(output_result - (result_initial + base * output_borrow)); + + // Range-check output_result to be at most 32 bits. + let mut combined_limbs = F::Extension::ZERO; + let limb_base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits()); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let product = (0..max_limb) + .map(|x| this_limb - F::Extension::from_canonical_usize(x)) + .product(); + constraints.push(product); + + combined_limbs = limb_base * combined_limbs + this_limb; + } + constraints.push(combined_limbs - output_result); + + // Range-check output_borrow to be one bit. 
+ constraints.push(output_borrow * (F::Extension::ONE - output_borrow)); + } + + constraints + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + for i in 0..self.num_ops { + let input_x = vars.local_wires[self.wire_ith_input_x(i)]; + let input_y = vars.local_wires[self.wire_ith_input_y(i)]; + let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)]; + + let result_initial = input_x - input_y - input_borrow; + let base = F::from_canonical_u64(1 << 32u64); + + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)]; + + constraints.push(output_result - (result_initial + base * output_borrow)); + + // Range-check output_result to be at most 32 bits. + let mut combined_limbs = F::ZERO; + let limb_base = F::from_canonical_u64(1u64 << Self::limb_bits()); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let product = (0..max_limb) + .map(|x| this_limb - F::from_canonical_usize(x)) + .product(); + constraints.push(product); + + combined_limbs = limb_base * combined_limbs + this_limb; + } + constraints.push(combined_limbs - output_result); + + // Range-check output_borrow to be one bit. 
+ constraints.push(output_borrow * (F::ONE - output_borrow)); + } + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + for i in 0..self.num_ops { + let input_x = vars.local_wires[self.wire_ith_input_x(i)]; + let input_y = vars.local_wires[self.wire_ith_input_y(i)]; + let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)]; + + let diff = builder.sub_extension(input_x, input_y); + let result_initial = builder.sub_extension(diff, input_borrow); + let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << 32u64)); + + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)]; + + let computed_output = builder.mul_add_extension(base, output_borrow, result_initial); + constraints.push(builder.sub_extension(output_result, computed_output)); + + // Range-check output_result to be at most 32 bits. + let mut combined_limbs = builder.zero_extension(); + let limb_base = builder + .constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let mut product = builder.one_extension(); + for x in 0..max_limb { + let x_target = + builder.constant_extension(F::Extension::from_canonical_usize(x)); + let diff = builder.sub_extension(this_limb, x_target); + product = builder.mul_extension(product, diff); + } + constraints.push(product); + + combined_limbs = builder.mul_add_extension(limb_base, combined_limbs, this_limb); + } + constraints.push(builder.sub_extension(combined_limbs, output_result)); + + // Range-check output_borrow to be one bit. 
+ let one = builder.one_extension(); + let not_borrow = builder.sub_extension(one, output_borrow); + constraints.push(builder.mul_extension(output_borrow, not_borrow)); + } + + constraints + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + (0..self.num_ops) + .map(|i| { + let g: Box> = Box::new( + U32SubtractionGenerator { + gate: *self, + gate_index, + i, + _phantom: PhantomData, + } + .adapter(), + ); + g + }) + .collect() + } + + fn num_wires(&self) -> usize { + self.num_ops * (5 + Self::num_limbs()) + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 1 << Self::limb_bits() + } + + fn num_constraints(&self) -> usize { + self.num_ops * (3 + Self::num_limbs()) + } +} + +#[derive(Clone, Debug)] +struct U32SubtractionGenerator, const D: usize> { + gate: U32SubtractionGate, + gate_index: usize, + i: usize, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for U32SubtractionGenerator +{ + fn dependencies(&self) -> Vec { + let local_target = |input| Target::wire(self.gate_index, input); + + vec![ + local_target(self.gate.wire_ith_input_x(self.i)), + local_target(self.gate.wire_ith_input_y(self.i)), + local_target(self.gate.wire_ith_input_borrow(self.i)), + ] + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |input| Wire { + gate: self.gate_index, + input, + }; + + let get_local_wire = |input| witness.get_wire(local_wire(input)); + + let input_x = get_local_wire(self.gate.wire_ith_input_x(self.i)); + let input_y = get_local_wire(self.gate.wire_ith_input_y(self.i)); + let input_borrow = get_local_wire(self.gate.wire_ith_input_borrow(self.i)); + + let result_initial = input_x - input_y - input_borrow; + let result_initial_u64 = result_initial.to_canonical_u64(); + let output_borrow = if result_initial_u64 > 1 << 32u64 { + F::ONE + } else { + F::ZERO + }; + + let base = F::from_canonical_u64(1 << 32u64); + let 
output_result = result_initial + base * output_borrow; + + let output_result_wire = local_wire(self.gate.wire_ith_output_result(self.i)); + let output_borrow_wire = local_wire(self.gate.wire_ith_output_borrow(self.i)); + + out_buffer.set_wire(output_result_wire, output_result); + out_buffer.set_wire(output_borrow_wire, output_borrow); + + let output_result_u64 = output_result.to_canonical_u64(); + + let num_limbs = U32SubtractionGate::::num_limbs(); + let limb_base = 1 << U32SubtractionGate::::limb_bits(); + let output_limbs: Vec<_> = (0..num_limbs) + .scan(output_result_u64, |acc, _| { + let tmp = *acc % limb_base; + *acc /= limb_base; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + + for j in 0..num_limbs { + let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j)); + out_buffer.set_wire(wire, output_limbs[j]); + } + } +} + +#[cfg(test)] +mod tests { + use std::marker::PhantomData; + + use anyhow::Result; + use rand::Rng; + + use crate::field::extension_field::quartic::QuarticExtension; + use crate::field::field_types::{Field, PrimeField}; + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::gate::Gate; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::subtraction_u32::U32SubtractionGate; + use crate::hash::hash_types::HashOut; + use crate::plonk::vars::EvaluationVars; + + #[test] + fn low_degree() { + test_low_degree::(U32SubtractionGate:: { + num_ops: 3, + _phantom: PhantomData, + }) + } + + #[test] + fn eval_fns() -> Result<()> { + test_eval_fns::(U32SubtractionGate:: { + num_ops: 3, + _phantom: PhantomData, + }) + } + + #[test] + fn test_gate_constraint() { + type F = GoldilocksField; + type FF = QuarticExtension; + const D: usize = 4; + const NUM_U32_SUBTRACTION_OPS: usize = 3; + + fn get_wires(inputs_x: Vec, inputs_y: Vec, borrows: Vec) -> Vec { + let mut v0 = Vec::new(); + let mut v1 = Vec::new(); + + let limb_bits = U32SubtractionGate::::limb_bits(); + let num_limbs = 
U32SubtractionGate::::num_limbs(); + let limb_base = 1 << limb_bits; + for c in 0..NUM_U32_SUBTRACTION_OPS { + let input_x = F::from_canonical_u64(inputs_x[c]); + let input_y = F::from_canonical_u64(inputs_y[c]); + let input_borrow = F::from_canonical_u64(borrows[c]); + + let result_initial = input_x - input_y - input_borrow; + let result_initial_u64 = result_initial.to_canonical_u64(); + let output_borrow = if result_initial_u64 > 1 << 32u64 { + F::ONE + } else { + F::ZERO + }; + + let base = F::from_canonical_u64(1 << 32u64); + let output_result = result_initial + base * output_borrow; + + let output_result_u64 = output_result.to_canonical_u64(); + + let mut output_limbs: Vec<_> = (0..num_limbs) + .scan(output_result_u64, |acc, _| { + let tmp = *acc % limb_base; + *acc /= limb_base; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + + v0.push(input_x); + v0.push(input_y); + v0.push(input_borrow); + v0.push(output_result); + v0.push(output_borrow); + v1.append(&mut output_limbs); + } + + v0.iter() + .chain(v1.iter()) + .map(|&x| x.into()) + .collect::>() + } + + let mut rng = rand::thread_rng(); + let inputs_x = (0..NUM_U32_SUBTRACTION_OPS) + .map(|_| rng.gen::() as u64) + .collect(); + let inputs_y = (0..NUM_U32_SUBTRACTION_OPS) + .map(|_| rng.gen::() as u64) + .collect(); + let borrows = (0..NUM_U32_SUBTRACTION_OPS) + .map(|_| (rng.gen::() % 2) as u64) + .collect(); + + let gate = U32SubtractionGate:: { + num_ops: NUM_U32_SUBTRACTION_OPS, + _phantom: PhantomData, + }; + + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(inputs_x, inputs_y, borrows), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Gate constraints are not satisfied." 
+ ); + } +} diff --git a/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs b/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs index f122e0ef..6437818b 100644 --- a/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs +++ b/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs @@ -1,5 +1,7 @@ +#![allow(clippy::assertions_on_constants)] + use std::arch::aarch64::*; -use std::convert::TryInto; +use std::arch::asm; use static_assertions::const_assert; use unroll::unroll_for_loops; @@ -172,9 +174,7 @@ unsafe fn multiply(x: u64, y: u64) -> u64 { let xy_hi_lo_mul_epsilon = mul_epsilon(xy_hi); // add_with_wraparound is safe, as xy_hi_lo_mul_epsilon <= 0xfffffffe00000001 <= ORDER. - let res1 = add_with_wraparound(res0, xy_hi_lo_mul_epsilon); - - res1 + add_with_wraparound(res0, xy_hi_lo_mul_epsilon) } // ==================================== STANDALONE CONST LAYER ===================================== @@ -267,9 +267,7 @@ unsafe fn mds_reduce( // Multiply by EPSILON and accumulate. let res_unadj = vmlal_laneq_u32::<0>(res_lo, res_hi_hi, mds_consts0); let res_adj = vcgtq_u64(res_lo, res_unadj); - let res = vsraq_n_u64::<32>(res_unadj, res_adj); - - res + vsraq_n_u64::<32>(res_unadj, res_adj) } #[inline(always)] @@ -969,8 +967,7 @@ unsafe fn partial_round( #[inline(always)] unsafe fn full_round(state: [u64; 12], round_constants: &[u64; WIDTH]) -> [u64; 12] { let state = sbox_layer_full(state); - let state = mds_const_layers_full(state, round_constants); - state + mds_const_layers_full(state, round_constants) } #[inline] diff --git a/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs b/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs index 6f257f56..0467d1e5 100644 --- a/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs +++ b/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs @@ -1,5 +1,4 @@ use core::arch::x86_64::*; -use std::convert::TryInto; use std::mem::size_of; use static_assertions::const_assert; diff --git a/src/hash/hash_types.rs b/src/hash/hash_types.rs index 
378431fc..7106a616 100644 --- a/src/hash/hash_types.rs +++ b/src/hash/hash_types.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - use rand::Rng; use serde::{Deserialize, Deserializer, Serialize, Serializer}; diff --git a/src/hash/hashing.rs b/src/hash/hashing.rs index d2db82ca..44e7e831 100644 --- a/src/hash/hashing.rs +++ b/src/hash/hashing.rs @@ -1,7 +1,5 @@ //! Concrete instantiation of a hash function. -use std::convert::TryInto; - use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; use crate::hash::hash_types::{HashOut, HashOutTarget}; @@ -131,10 +129,8 @@ pub fn hash_n_to_m>( // Absorb all input chunks. for input_chunk in inputs.chunks(SPONGE_RATE) { - for i in 0..input_chunk.len() { - state[i] = input_chunk[i]; - } - state = P::permute(state); + state[..input_chunk.len()].copy_from_slice(input_chunk); + state = permute(state); } // Squeeze until we have the desired number of outputs. diff --git a/src/hash/merkle_proofs.rs b/src/hash/merkle_proofs.rs index 29131f26..a2ec2a31 100644 --- a/src/hash/merkle_proofs.rs +++ b/src/hash/merkle_proofs.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - use anyhow::{ensure, Result}; use serde::{Deserialize, Serialize}; @@ -55,6 +53,7 @@ pub(crate) fn verify_merkle_proof>( impl, const D: usize> CircuitBuilder { /// Verifies that the given leaf data is present at the given index in the Merkle tree with the /// given cap. The index is given by it's little-endian bits. 
+ #[cfg(test)] pub(crate) fn verify_merkle_proof>( &mut self, leaf_data: Vec, @@ -94,7 +93,7 @@ impl, const D: usize> CircuitBuilder { proof: &MerkleProofTarget, ) { let zero = self.zero(); - let mut state: HashOutTarget = self.hash_or_noop::(leaf_data); + let mut state:HashOutTarget = self.hash_or_noop(leaf_data); for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) { let mut perm_inputs = [zero; SPONGE_WIDTH]; @@ -116,7 +115,7 @@ impl, const D: usize> CircuitBuilder { } } - pub fn assert_hashes_equal(&mut self, x: HashOutTarget, y: HashOutTarget) { + pub fn connect_hashes(&mut self, x: HashOutTarget, y: HashOutTarget) { for i in 0..4 { self.connect(x.elements[i], y.elements[i]); } diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index 6ac41fcb..1abda74b 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -1,8 +1,6 @@ //! Implementation of the Poseidon hash function, as described in //! https://eprint.iacr.org/2019/458.pdf -use std::convert::TryInto; - use unroll::unroll_for_loops; use crate::field::extension_field::target::ExtensionTarget; @@ -452,9 +450,10 @@ pub trait Poseidon: PrimeField { s0, ); for i in 1..WIDTH { - let t = ::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]; - let t = Self::from_canonical_u64(t); - d = builder.mul_const_add_extension(t, state[i], d); + let t = >::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]; + let t = Self::Extension::from_canonical_u64(t); + let t = builder.constant_extension(t); + d = builder.mul_add_extension(t, state[i], d); } let mut result = [builder.zero_extension(); WIDTH]; @@ -624,6 +623,7 @@ pub trait Poseidon: PrimeField { } } +#[cfg(test)] pub(crate) mod test_helpers { use crate::field::field_types::Field; use crate::hash::hashing::SPONGE_WIDTH; diff --git a/src/iop/challenger.rs b/src/iop/challenger.rs index 1e2cd293..b8801f79 100644 --- a/src/iop/challenger.rs +++ b/src/iop/challenger.rs @@ -1,7 +1,7 @@ +use crate::field::extension_field::target::ExtensionTarget; use std::convert::TryInto; use 
std::marker::PhantomData; -use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::RichField; use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget}; diff --git a/src/iop/generator.rs b/src/iop/generator.rs index e71ef8cd..6dfae5ef 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -1,9 +1,15 @@ use std::fmt::Debug; use std::marker::PhantomData; +use num::BigUint; + use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::Field; +use crate::field::field_types::{Field, RichField}; +use crate::gadgets::arithmetic_u32::U32Target; +use crate::gadgets::biguint::BigUintTarget; +use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::iop::target::Target; use crate::iop::wire::Wire; @@ -89,7 +95,7 @@ pub(crate) fn generate_partial_witness< assert_eq!( remaining_generators, 0, "{} generators weren't run", - remaining_generators + remaining_generators, ); witness @@ -156,6 +162,24 @@ impl GeneratedValues { self.target_values.push((target, value)) } + fn set_u32_target(&mut self, target: U32Target, value: u32) { + self.set_target(target.0, F::from_canonical_u32(value)) + } + + pub fn set_biguint_target(&mut self, target: BigUintTarget, value: BigUint) { + let mut limbs = value.to_u32_digits(); + assert!(target.num_limbs() >= limbs.len()); + + limbs.resize(target.num_limbs(), 0); + for i in 0..target.num_limbs() { + self.set_u32_target(target.get_limb(i), limbs[i]); + } + } + + pub fn set_nonnative_target(&mut self, target: NonNativeTarget, value: FF) { + self.set_biguint_target(target.value, value.to_biguint()) + } + pub fn set_hash_target(&mut self, ht: HashOutTarget, value: HashOut) { ht.elements .iter() diff --git a/src/iop/target.rs b/src/iop/target.rs index 8d4cbcfb..de3e4911 100644 --- 
a/src/iop/target.rs +++ b/src/iop/target.rs @@ -41,6 +41,7 @@ impl Target { /// A `Target` which has already been constrained such that it can only be 0 or 1. #[derive(Copy, Clone, Debug)] +#[allow(clippy::manual_non_exhaustive)] pub struct BoolTarget { pub target: Target, /// This private field is here to force all instantiations to go through `new_unsafe`. diff --git a/src/iop/witness.rs b/src/iop/witness.rs index 09e0a73e..a614ca4a 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -1,9 +1,14 @@ use std::collections::HashMap; -use std::convert::TryInto; + +use num::{BigUint, FromPrimitive, Zero}; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::{Field, RichField}; +use crate::field::field_types::Field; +use crate::gadgets::arithmetic_u32::U32Target; +use crate::gadgets::biguint::BigUintTarget; +use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::HashOutTarget; use crate::hash::hash_types::{HashOut, MerkleCapTarget}; use crate::hash::merkle_tree::MerkleCap; @@ -54,6 +59,24 @@ pub trait Witness { panic!("not a bool") } + fn get_biguint_target(&self, target: BigUintTarget) -> BigUint { + let mut result = BigUint::zero(); + + let limb_base = BigUint::from_u64(1 << 32u64).unwrap(); + for i in (0..target.num_limbs()).rev() { + let limb = target.get_limb(i); + result *= &limb_base; + result += self.get_target(limb.0).to_biguint(); + } + + result + } + + fn get_nonnative_target(&self, target: NonNativeTarget) -> FF { + let val = self.get_biguint_target(target.value); + FF::from_biguint(val) + } + fn get_hash_target(&self, ht: HashOutTarget) -> HashOut { HashOut { elements: self.get_targets(&ht.elements).try_into().unwrap(), @@ -122,6 +145,16 @@ pub trait Witness { self.set_target(target.target, F::from_bool(value)) } + fn set_u32_target(&mut self, target: U32Target, value: u32) { + self.set_target(target.0, 
F::from_canonical_u32(value)) + } + + fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) { + for (&lt, &l) in target.limbs.iter().zip(&value.to_u32_digits()) { + self.set_u32_target(lt, l); + } + } + fn set_wire(&mut self, wire: Wire, value: F) { self.set_target(Target::Wire(wire), value) } diff --git a/src/lib.rs b/src/lib.rs index c72f783c..291e6422 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,15 @@ -#![feature(asm)] -#![feature(destructuring_assignment)] +#![allow(incomplete_features)] +#![allow(const_evaluatable_unchecked)] +#![allow(clippy::new_without_default)] +#![allow(clippy::too_many_arguments)] +#![allow(clippy::len_without_is_empty)] +#![allow(clippy::needless_range_loop)] +#![feature(asm_sym)] #![feature(generic_const_exprs)] #![feature(specialization)] #![feature(stdsimd)] +pub mod curve; pub mod field; pub mod fri; pub mod gadgets; @@ -13,3 +19,11 @@ pub mod iop; pub mod plonk; pub mod polynomial; pub mod util; + +// Set up Jemalloc +#[cfg(not(target_env = "msvc"))] +use jemallocator::Jemalloc; + +#[cfg(not(target_env = "msvc"))] +#[global_allocator] +static GLOBAL: Jemalloc = Jemalloc; diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index b9533c3e..9398dc97 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -1,6 +1,5 @@ use std::cmp::max; use std::collections::{BTreeMap, HashMap, HashSet}; -use std::convert::TryInto; use std::time::Instant; use log::{debug, info, Level}; @@ -12,14 +11,20 @@ use crate::field::fft::fft_root_table; use crate::field::field_types::Field; use crate::fri::commitment::PolynomialBatchCommitment; use crate::fri::{FriConfig, FriParams}; -use crate::gadgets::arithmetic_extension::ArithmeticOperation; -use crate::gates::arithmetic::ArithmeticExtensionGate; +use crate::gadgets::arithmetic::BaseArithmeticOperation; +use crate::gadgets::arithmetic_extension::ExtensionArithmeticOperation; +use crate::gadgets::arithmetic_u32::U32Target; +use 
crate::gates::arithmetic_base::ArithmeticGate; +use crate::gates::arithmetic_extension::ArithmeticExtensionGate; +use crate::gates::arithmetic_u32::U32ArithmeticGate; use crate::gates::constant::ConstantGate; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; use crate::gates::gate_tree::Tree; +use crate::gates::multiplication_extension::MulExtensionGate; use crate::gates::noop::NoopGate; use crate::gates::public_input::PublicInputGate; use crate::gates::random_access::RandomAccessGate; +use crate::gates::subtraction_u32::U32SubtractionGate; use crate::gates::switch::SwitchGate; use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget}; use crate::iop::generator::{ @@ -35,7 +40,7 @@ use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::copy_constraint::CopyConstraint; use crate::plonk::permutation_argument::Forest; use crate::plonk::plonk_common::PlonkPolynomials; -use crate::polynomial::polynomial::PolynomialValues; +use crate::polynomial::PolynomialValues; use crate::util::context_tree::ContextTree; use crate::util::marking::{Markable, MarkedTargets}; use crate::util::partial_products::num_partial_products; @@ -71,24 +76,13 @@ pub struct CircuitBuilder, const D: usize> { constants_to_targets: HashMap, targets_to_constants: HashMap, + /// Memoized results of `arithmetic` calls. + pub(crate) base_arithmetic_results: HashMap, Target>, + /// Memoized results of `arithmetic_extension` calls. - pub(crate) arithmetic_results: HashMap, ExtensionTarget>, + pub(crate) arithmetic_results: HashMap, ExtensionTarget>, - /// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using - /// these constants with gate index `g` and already using `i` arithmetic operations. - pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>, - - /// A map `(c0, c1) -> (g, i)` from constants `vec_size` to an available arithmetic gate using - /// these constants with gate index `g` and already using `i` random accesses. 
- pub(crate) free_random_access: HashMap, - - // `current_switch_gates[chunk_size - 1]` contains None if we have no switch gates with the value - // chunk_size, and contains `(g, i, c)`, if the gate `g`, at index `i`, already contains `c` copies - // of switches - pub(crate) current_switch_gates: Vec, usize, usize)>>, - - /// An available `ConstantGate` instance, if any. - free_constant: Option<(usize, usize)>, + batched_gates: BatchedGates, } impl, const D: usize> CircuitBuilder { @@ -104,12 +98,10 @@ impl, const D: usize> CircuitBuilder { marked_targets: Vec::new(), generators: Vec::new(), constants_to_targets: HashMap::new(), + base_arithmetic_results: HashMap::new(), arithmetic_results: HashMap::new(), targets_to_constants: HashMap::new(), - free_arithmetic: HashMap::new(), - free_random_access: HashMap::new(), - current_switch_gates: Vec::new(), - free_constant: None, + batched_gates: BatchedGates::new(), }; builder.check_config(); builder @@ -216,6 +208,7 @@ impl, const D: usize> CircuitBuilder { gate_ref, constants, }); + index } @@ -260,6 +253,11 @@ impl, const D: usize> CircuitBuilder { self.connect(x, zero); } + pub fn assert_one(&mut self, x: Target) { + let one = self.one(); + self.connect(x, one); + } + pub fn add_generators(&mut self, generators: Vec>>) { self.generators.extend(generators); } @@ -313,26 +311,6 @@ impl, const D: usize> CircuitBuilder { target } - /// Returns the gate index and copy index of a free `ConstantGate` slot, potentially adding a - /// new `ConstantGate` if needed. - fn constant_gate_instance(&mut self) -> (usize, usize) { - if self.free_constant.is_none() { - let num_consts = self.config.constant_gate_size; - // We will fill this `ConstantGate` with zero constants initially. - // These will be overwritten by `constant` as the gate instances are filled. 
- let gate = self.add_gate(ConstantGate { num_consts }, vec![F::ZERO; num_consts]); - self.free_constant = Some((gate, 0)); - } - - let (gate, instance) = self.free_constant.unwrap(); - if instance + 1 < self.config.constant_gate_size { - self.free_constant = Some((gate, instance + 1)); - } else { - self.free_constant = None; - } - (gate, instance) - } - pub fn constants(&mut self, constants: &[F]) -> Vec { constants.iter().map(|&c| self.constant(c)).collect() } @@ -345,6 +323,11 @@ impl, const D: usize> CircuitBuilder { } } + /// Returns a U32Target for the value `c`, which is assumed to be at most 32 bits. + pub fn constant_u32(&mut self, c: u32) -> U32Target { + U32Target(self.constant(F::from_canonical_u32(c))) + } + /// If the given target is a constant (i.e. it was created by the `constant(F)` method), returns /// its constant value. Otherwise, returns `None`. pub fn target_as_constant(&self, target: Target) -> Option { @@ -396,6 +379,20 @@ impl, const D: usize> CircuitBuilder { } } + /// The number of (base field) `arithmetic` operations that can be performed in a single gate. + pub(crate) fn num_base_arithmetic_ops_per_gate(&self) -> usize { + if self.config.use_base_arithmetic_gate { + ArithmeticGate::new_from_config(&self.config).num_ops + } else { + self.num_ext_arithmetic_ops_per_gate() + } + } + + /// The number of `arithmetic_extension` operations that can be performed in a single gate. + pub(crate) fn num_ext_arithmetic_ops_per_gate(&self) -> usize { + ArithmeticExtensionGate::::new_from_config(&self.config).num_ops + } + /// The number of polynomial values that will be revealed per opening, both for the "regular" /// polynomials and for the Z polynomials. 
Because calculating these values involves a recursive /// dependence (the amount of blinding depends on the degree, which depends on the blinding), @@ -566,76 +563,6 @@ impl, const D: usize> CircuitBuilder { ) } - /// Fill the remaining unused arithmetic operations with zeros, so that all - /// `ArithmeticExtensionGenerator` are run. - fn fill_arithmetic_gates(&mut self) { - let zero = self.zero_extension(); - let remaining_arithmetic_gates = self.free_arithmetic.values().copied().collect::>(); - for (gate, i) in remaining_arithmetic_gates { - for j in i..ArithmeticExtensionGate::::num_ops(&self.config) { - let wires_multiplicand_0 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_ith_multiplicand_0(j), - ); - let wires_multiplicand_1 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_ith_multiplicand_1(j), - ); - let wires_addend = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_ith_addend(j), - ); - - self.connect_extension(zero, wires_multiplicand_0); - self.connect_extension(zero, wires_multiplicand_1); - self.connect_extension(zero, wires_addend); - } - } - } - - /// Fill the remaining unused random access operations with zeros, so that all - /// `RandomAccessGenerator`s are run. - fn fill_random_access_gates(&mut self) { - let zero = self.zero(); - for (vec_size, (_, i)) in self.free_random_access.clone() { - let max_copies = RandomAccessGate::::max_num_copies( - self.config.num_routed_wires, - self.config.num_wires, - vec_size, - ); - for _ in i..max_copies { - self.random_access(zero, zero, vec![zero; vec_size]); - } - } - } - - /// Fill the remaining unused switch gates with dummy values, so that all - /// `SwitchGenerator` are run. 
- fn fill_switch_gates(&mut self) { - let zero = self.zero(); - - for chunk_size in 1..=self.current_switch_gates.len() { - if let Some((gate, gate_index, mut copy)) = - self.current_switch_gates[chunk_size - 1].clone() - { - while copy < gate.num_copies { - for element in 0..chunk_size { - let wire_first_input = - Target::wire(gate_index, gate.wire_first_input(copy, element)); - let wire_second_input = - Target::wire(gate_index, gate.wire_second_input(copy, element)); - let wire_switch_bool = - Target::wire(gate_index, gate.wire_switch_bool(copy)); - self.connect(zero, wire_first_input); - self.connect(zero, wire_second_input); - self.connect(zero, wire_switch_bool); - } - copy += 1; - } - } - } - } - pub fn print_gate_counts(&self, min_delta: usize) { // Print gate counts for each context. self.context_log @@ -659,9 +586,7 @@ impl, const D: usize> CircuitBuilder { let mut timing = TimingTree::new("preprocess", Level::Trace); let start = Instant::now(); - self.fill_arithmetic_gates(); - self.fill_random_access_gates(); - self.fill_switch_gates(); + self.fill_batched_gates(); // Hash the public inputs, and route them to a `PublicInputGate` which will enforce that // those hash wires match the claimed public inputs. @@ -698,7 +623,7 @@ impl, const D: usize> CircuitBuilder { ..=1 << self.config.rate_bits) .min_by_key(|&q| num_partial_products(self.config.num_routed_wires, q).0 + q) .unwrap(); - info!("Quotient degree factor set to: {}.", quotient_degree_factor); + debug!("Quotient degree factor set to: {}.", quotient_degree_factor); let prefixed_gates = PrefixedGate::from_tree(gate_tree); let subgroup = F::two_adic_subgroup(degree_bits); @@ -710,7 +635,7 @@ impl, const D: usize> CircuitBuilder { // Precompute FFT roots. 
let max_fft_points = - 1 << degree_bits + max(self.config.rate_bits, log2_ceil(quotient_degree_factor)); + 1 << (degree_bits + max(self.config.rate_bits, log2_ceil(quotient_degree_factor))); let fft_root_table = fft_root_table(max_fft_points); let constants_sigmas_vecs = [constant_vecs, sigma_vecs.clone()].concat(); @@ -745,7 +670,7 @@ impl, const D: usize> CircuitBuilder { let watch_rep_index = forest.parents[watch_index]; generator_indices_by_watches .entry(watch_rep_index) - .or_insert(vec![]) + .or_insert_with(Vec::new) .push(i); } } @@ -801,7 +726,7 @@ impl, const D: usize> CircuitBuilder { circuit_digest, }; - info!("Building circuit took {}s", start.elapsed().as_secs_f32()); + debug!("Building circuit took {}s", start.elapsed().as_secs_f32()); CircuitData { prover_only, verifier_only, @@ -837,3 +762,386 @@ impl, const D: usize> CircuitBuilder { } } } + +/// Various gate types can contain multiple copies in a single Gate. This helper struct lets a +/// CircuitBuilder track such gates that are currently being "filled up." +pub struct BatchedGates, const D: usize> { + /// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using + /// these constants with gate index `g` and already using `i` arithmetic operations. + pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>, + pub(crate) free_base_arithmetic: HashMap<(F, F), (usize, usize)>, + + pub(crate) free_mul: HashMap, + + /// A map `b -> (g, i)` from `b` bits to an available random access gate of that size with gate + /// index `g` and already using `i` random accesses. 
+ pub(crate) free_random_access: HashMap, + + /// `current_switch_gates[chunk_size - 1]` contains None if we have no switch gates with the value + /// chunk_size, and contains `(g, i, c)`, if the gate `g`, at index `i`, already contains `c` copies + /// of switches + pub(crate) current_switch_gates: Vec, usize, usize)>>, + + /// The `U32ArithmeticGate` currently being filled (so new u32 arithmetic operations will be added to this gate before creating a new one) + pub(crate) current_u32_arithmetic_gate: Option<(usize, usize)>, + + /// The `U32SubtractionGate` currently being filled (so new u32 subtraction operations will be added to this gate before creating a new one) + pub(crate) current_u32_subtraction_gate: Option<(usize, usize)>, + + /// An available `ConstantGate` instance, if any. + pub(crate) free_constant: Option<(usize, usize)>, +} + +impl, const D: usize> BatchedGates { + pub fn new() -> Self { + Self { + free_arithmetic: HashMap::new(), + free_base_arithmetic: HashMap::new(), + free_mul: HashMap::new(), + free_random_access: HashMap::new(), + current_switch_gates: Vec::new(), + current_u32_arithmetic_gate: None, + current_u32_subtraction_gate: None, + free_constant: None, + } + } +} + +impl, const D: usize> CircuitBuilder { + /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. + /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index + /// `g` and the gate's `i`-th operation is available. + pub(crate) fn find_base_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_base_arithmetic + .get(&(const_0, const_1)) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + ArithmeticGate::new_from_config(&self.config), + vec![const_0, const_1], + ); + (gate, 0) + }); + + // Update `free_arithmetic` with new values. 
+ if i < ArithmeticGate::num_ops(&self.config) - 1 { + self.batched_gates + .free_base_arithmetic + .insert((const_0, const_1), (gate, i + 1)); + } else { + self.batched_gates + .free_base_arithmetic + .remove(&(const_0, const_1)); + } + + (gate, i) + } + + /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. + /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index + /// `g` and the gate's `i`-th operation is available. + pub(crate) fn find_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_arithmetic + .get(&(const_0, const_1)) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + ArithmeticExtensionGate::new_from_config(&self.config), + vec![const_0, const_1], + ); + (gate, 0) + }); + + // Update `free_arithmetic` with new values. + if i < ArithmeticExtensionGate::::num_ops(&self.config) - 1 { + self.batched_gates + .free_arithmetic + .insert((const_0, const_1), (gate, i + 1)); + } else { + self.batched_gates + .free_arithmetic + .remove(&(const_0, const_1)); + } + + (gate, i) + } + + /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. + /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index + /// `g` and the gate's `i`-th operation is available. + pub(crate) fn find_mul_gate(&mut self, const_0: F) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_mul + .get(&const_0) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + MulExtensionGate::new_from_config(&self.config), + vec![const_0], + ); + (gate, 0) + }); + + // Update `free_arithmetic` with new values. 
+ if i < MulExtensionGate::::num_ops(&self.config) - 1 { + self.batched_gates.free_mul.insert(const_0, (gate, i + 1)); + } else { + self.batched_gates.free_mul.remove(&const_0); + } + + (gate, i) + } + + /// Finds the last available random access gate with the given `vec_size` or add one if there aren't any. + /// Returns `(g,i)` such that there is a random access gate with the given `vec_size` at index + /// `g` and the gate's `i`-th random access is available. + pub(crate) fn find_random_access_gate(&mut self, bits: usize) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_random_access + .get(&bits) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + RandomAccessGate::new_from_config(&self.config, bits), + vec![], + ); + (gate, 0) + }); + + // Update `free_random_access` with new values. + if i + 1 < RandomAccessGate::::new_from_config(&self.config, bits).num_copies { + self.batched_gates + .free_random_access + .insert(bits, (gate, i + 1)); + } else { + self.batched_gates.free_random_access.remove(&bits); + } + + (gate, i) + } + + pub(crate) fn find_switch_gate( + &mut self, + chunk_size: usize, + ) -> (SwitchGate, usize, usize) { + if self.batched_gates.current_switch_gates.len() < chunk_size { + self.batched_gates.current_switch_gates.extend(vec![ + None; + chunk_size + - self + .batched_gates + .current_switch_gates + .len() + ]); + } + + let (gate, gate_index, next_copy) = + match self.batched_gates.current_switch_gates[chunk_size - 1].clone() { + None => { + let gate = SwitchGate::::new_from_config(&self.config, chunk_size); + let gate_index = self.add_gate(gate.clone(), vec![]); + (gate, gate_index, 0) + } + Some((gate, idx, next_copy)) => (gate, idx, next_copy), + }; + + let num_copies = gate.num_copies; + + if next_copy == num_copies - 1 { + self.batched_gates.current_switch_gates[chunk_size - 1] = None; + } else { + self.batched_gates.current_switch_gates[chunk_size - 1] = + Some((gate.clone(), gate_index, next_copy + 1)); 
+ } + + (gate, gate_index, next_copy) + } + + pub(crate) fn find_u32_arithmetic_gate(&mut self) -> (usize, usize) { + let (gate_index, copy) = match self.batched_gates.current_u32_arithmetic_gate { + None => { + let gate = U32ArithmeticGate::new_from_config(&self.config); + let gate_index = self.add_gate(gate, vec![]); + (gate_index, 0) + } + Some((gate_index, copy)) => (gate_index, copy), + }; + + if copy == U32ArithmeticGate::::num_ops(&self.config) - 1 { + self.batched_gates.current_u32_arithmetic_gate = None; + } else { + self.batched_gates.current_u32_arithmetic_gate = Some((gate_index, copy + 1)); + } + + (gate_index, copy) + } + + pub(crate) fn find_u32_subtraction_gate(&mut self) -> (usize, usize) { + let (gate_index, copy) = match self.batched_gates.current_u32_subtraction_gate { + None => { + let gate = U32SubtractionGate::new_from_config(&self.config); + let gate_index = self.add_gate(gate, vec![]); + (gate_index, 0) + } + Some((gate_index, copy)) => (gate_index, copy), + }; + + if copy == U32SubtractionGate::::num_ops(&self.config) - 1 { + self.batched_gates.current_u32_subtraction_gate = None; + } else { + self.batched_gates.current_u32_subtraction_gate = Some((gate_index, copy + 1)); + } + + (gate_index, copy) + } + + /// Returns the gate index and copy index of a free `ConstantGate` slot, potentially adding a + /// new `ConstantGate` if needed. + fn constant_gate_instance(&mut self) -> (usize, usize) { + if self.batched_gates.free_constant.is_none() { + let num_consts = self.config.constant_gate_size; + // We will fill this `ConstantGate` with zero constants initially. + // These will be overwritten by `constant` as the gate instances are filled. 
+ let gate = self.add_gate(ConstantGate { num_consts }, vec![F::ZERO; num_consts]); + self.batched_gates.free_constant = Some((gate, 0)); + } + + let (gate, instance) = self.batched_gates.free_constant.unwrap(); + if instance + 1 < self.config.constant_gate_size { + self.batched_gates.free_constant = Some((gate, instance + 1)); + } else { + self.batched_gates.free_constant = None; + } + (gate, instance) + } + + /// Fill the remaining unused arithmetic operations with zeros, so that all + /// `ArithmeticGate` are run. + fn fill_base_arithmetic_gates(&mut self) { + let zero = self.zero(); + for ((c0, c1), (_gate, i)) in self.batched_gates.free_base_arithmetic.clone() { + for _ in i..ArithmeticGate::num_ops(&self.config) { + // If we directly wire in zero, an optimization will skip doing anything and return + // zero. So we pass in a virtual target and connect it to zero afterward. + let dummy = self.add_virtual_target(); + self.arithmetic(c0, c1, dummy, dummy, dummy); + self.connect(dummy, zero); + } + } + assert!(self.batched_gates.free_base_arithmetic.is_empty()); + } + + /// Fill the remaining unused arithmetic operations with zeros, so that all + /// `ArithmeticExtensionGenerator`s are run. + fn fill_arithmetic_gates(&mut self) { + let zero = self.zero_extension(); + for ((c0, c1), (_gate, i)) in self.batched_gates.free_arithmetic.clone() { + for _ in i..ArithmeticExtensionGate::::num_ops(&self.config) { + // If we directly wire in zero, an optimization will skip doing anything and return + // zero. So we pass in a virtual target and connect it to zero afterward. + let dummy = self.add_virtual_extension_target(); + self.arithmetic_extension(c0, c1, dummy, dummy, dummy); + self.connect_extension(dummy, zero); + } + } + assert!(self.batched_gates.free_arithmetic.is_empty()); + } + + /// Fill the remaining unused arithmetic operations with zeros, so that all + /// `ArithmeticExtensionGenerator`s are run. 
+ fn fill_mul_gates(&mut self) { + let zero = self.zero_extension(); + for (c0, (_gate, i)) in self.batched_gates.free_mul.clone() { + for _ in i..MulExtensionGate::::num_ops(&self.config) { + // If we directly wire in zero, an optimization will skip doing anything and return + // zero. So we pass in a virtual target and connect it to zero afterward. + let dummy = self.add_virtual_extension_target(); + self.arithmetic_extension(c0, F::ZERO, dummy, dummy, zero); + self.connect_extension(dummy, zero); + } + } + assert!(self.batched_gates.free_mul.is_empty()); + } + + /// Fill the remaining unused random access operations with zeros, so that all + /// `RandomAccessGenerator`s are run. + fn fill_random_access_gates(&mut self) { + let zero = self.zero(); + for (bits, (_, i)) in self.batched_gates.free_random_access.clone() { + let max_copies = + RandomAccessGate::::new_from_config(&self.config, bits).num_copies; + for _ in i..max_copies { + self.random_access(zero, zero, vec![zero; 1 << bits]); + } + } + } + + /// Fill the remaining unused switch gates with dummy values, so that all + /// `SwitchGenerator`s are run. + fn fill_switch_gates(&mut self) { + let zero = self.zero(); + + for chunk_size in 1..=self.batched_gates.current_switch_gates.len() { + if let Some((gate, gate_index, mut copy)) = + self.batched_gates.current_switch_gates[chunk_size - 1].clone() + { + while copy < gate.num_copies { + for element in 0..chunk_size { + let wire_first_input = + Target::wire(gate_index, gate.wire_first_input(copy, element)); + let wire_second_input = + Target::wire(gate_index, gate.wire_second_input(copy, element)); + let wire_switch_bool = + Target::wire(gate_index, gate.wire_switch_bool(copy)); + self.connect(zero, wire_first_input); + self.connect(zero, wire_second_input); + self.connect(zero, wire_switch_bool); + } + copy += 1; + } + } + } + } + + /// Fill the remaining unused U32 arithmetic operations with zeros, so that all + /// `U32ArithmeticGenerator`s are run. 
+ fn fill_u32_arithmetic_gates(&mut self) { + let zero = self.zero_u32(); + if let Some((_gate_index, copy)) = self.batched_gates.current_u32_arithmetic_gate { + for _ in copy..U32ArithmeticGate::::num_ops(&self.config) { + let dummy = self.add_virtual_u32_target(); + self.mul_add_u32(dummy, dummy, dummy); + self.connect_u32(dummy, zero); + } + } + } + + /// Fill the remaining unused U32 subtraction operations with zeros, so that all + /// `U32SubtractionGenerator`s are run. + fn fill_u32_subtraction_gates(&mut self) { + let zero = self.zero_u32(); + if let Some((_gate_index, copy)) = self.batched_gates.current_u32_subtraction_gate { + for _i in copy..U32SubtractionGate::::num_ops(&self.config) { + let dummy = self.add_virtual_u32_target(); + self.sub_u32(dummy, dummy, dummy); + self.connect_u32(dummy, zero); + } + } + } + + fn fill_batched_gates(&mut self) { + self.fill_arithmetic_gates(); + self.fill_base_arithmetic_gates(); + self.fill_mul_gates(); + self.fill_random_access_gates(); + self.fill_switch_gates(); + self.fill_u32_arithmetic_gates(); + self.fill_u32_subtraction_gates(); + } +} diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index e8ff8797..ea90fb1b 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -16,6 +16,7 @@ use crate::iop::target::Target; use crate::iop::witness::PartialWitness; use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::proof::ProofWithPublicInputs; +use crate::plonk::proof::{CompressedProofWithPublicInputs, ProofWithPublicInputs}; use crate::plonk::prover::prove; use crate::plonk::verifier::verify; use crate::util::marking::MarkedTargets; @@ -26,6 +27,9 @@ pub struct CircuitConfig { pub num_wires: usize, pub num_routed_wires: usize, pub constant_gate_size: usize, + /// Whether to use a dedicated gate for base field arithmetic, rather than using a single gate + /// for both base field and extension field arithmetic. 
+ pub use_base_arithmetic_gate: bool, pub security_bits: usize, pub rate_bits: usize, /// The number of challenge points to generate, for IOPs that have soundness errors of (roughly) @@ -45,30 +49,35 @@ impl Default for CircuitConfig { } impl CircuitConfig { + pub fn rate(&self) -> f64 { + 1.0 / ((1 << self.rate_bits) as f64) + } + pub fn num_advice_wires(&self) -> usize { self.num_wires - self.num_routed_wires } /// A typical recursion config, without zero-knowledge, targeting ~100 bit security. - pub(crate) fn standard_recursion_config() -> Self { + pub fn standard_recursion_config() -> Self { Self { - num_wires: 143, - num_routed_wires: 25, - constant_gate_size: 6, + num_wires: 135, + num_routed_wires: 80, + constant_gate_size: 5, + use_base_arithmetic_gate: true, security_bits: 100, rate_bits: 3, num_challenges: 2, zero_knowledge: false, - cap_height: 3, + cap_height: 4, fri_config: FriConfig { proof_of_work_bits: 16, - reduction_strategy: FriReductionStrategy::ConstantArityBits(3, 5), + reduction_strategy: FriReductionStrategy::ConstantArityBits(4, 5), num_query_rounds: 28, }, } } - pub(crate) fn standard_recursion_zk_config() -> Self { + pub fn standard_recursion_zk_config() -> Self { CircuitConfig { zero_knowledge: true, ..Self::standard_recursion_config() @@ -96,6 +105,13 @@ impl, C: GenericConfig, const D: usize> CircuitData) -> Result<()> { verify(proof_with_pis, &self.verifier_only, &self.common) } + + pub fn verify_compressed( + &self, + compressed_proof_with_pis: CompressedProofWithPublicInputs, + ) -> Result<()> { + compressed_proof_with_pis.verify(&self.verifier_only, &self.common) + } } /// Circuit data required by the prover. 
This may be thought of as a proving key, although it @@ -132,6 +148,13 @@ impl, C: GenericConfig, const D: usize> VerifierCircu pub fn verify(&self, proof_with_pis: ProofWithPublicInputs) -> Result<()> { verify(proof_with_pis, &self.verifier_only, &self.common) } + + pub fn verify_compressed( + &self, + compressed_proof_with_pis: CompressedProofWithPublicInputs, + ) -> Result<()> { + compressed_proof_with_pis.verify(&self.verifier_only, &self.common) + } } /// Circuit data required by the prover, but not the verifier. @@ -194,8 +217,8 @@ pub struct CommonCircuitData, C: GenericConfig, const /// The `{k_i}` valued used in `S_ID_i` in Plonk's permutation argument. pub(crate) k_is: Vec, - /// The number of partial products needed to compute the `Z` polynomials and the number - /// of partial products needed to compute the final product. + /// The number of partial products needed to compute the `Z` polynomials and + /// the number of original elements consumed in `partial_products()`. pub(crate) num_partial_products: (usize, usize), /// A digest of the "circuit" (i.e. the instance, minus public inputs), which can be used to @@ -228,11 +251,6 @@ impl, C: GenericConfig, const D: usize> CommonCircuit self.quotient_degree_factor * self.degree() } - pub fn total_constraints(&self) -> usize { - // 2 constraints for each Z check. - self.config.num_challenges * 2 + self.num_gate_constraints - } - /// Range of the constants polynomials in the `constants_sigmas_commitment`. 
pub fn constants_range(&self) -> Range { 0..self.num_constants diff --git a/src/plonk/get_challenges.rs b/src/plonk/get_challenges.rs index 65d67c9d..94c14765 100644 --- a/src/plonk/get_challenges.rs +++ b/src/plonk/get_challenges.rs @@ -11,7 +11,7 @@ use crate::plonk::proof::{ CompressedProof, CompressedProofWithPublicInputs, FriInferredElements, OpeningSet, Proof, ProofChallenges, ProofWithPublicInputs, }; -use crate::polynomial::polynomial::PolynomialCoeffs; +use crate::polynomial::PolynomialCoeffs; use crate::util::reverse_bits; fn get_challenges, C: GenericConfig, const D: usize>( diff --git a/src/plonk/permutation_argument.rs b/src/plonk/permutation_argument.rs index ee9474d7..ca3977ce 100644 --- a/src/plonk/permutation_argument.rs +++ b/src/plonk/permutation_argument.rs @@ -5,7 +5,7 @@ use rayon::prelude::*; use crate::field::field_types::Field; use crate::iop::target::Target; use crate::iop::wire::Wire; -use crate::polynomial::polynomial::PolynomialValues; +use crate::polynomial::PolynomialValues; /// Disjoint Set Forest data-structure following https://en.wikipedia.org/wiki/Disjoint-set_data_structure. pub struct Forest { @@ -45,15 +45,23 @@ impl Forest { } /// Path compression method, see https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Finding_set_representatives. - pub fn find(&mut self, x_index: usize) -> usize { - let x_parent = self.parents[x_index]; - if x_parent != x_index { - let root_index = self.find(x_parent); - self.parents[x_index] = root_index; - root_index - } else { - x_index + pub fn find(&mut self, mut x_index: usize) -> usize { + // Note: We avoid recursion here since the chains can be long, causing stack overflows. + + // First, find the representative of the set containing `x_index`. + let mut representative = x_index; + while self.parents[representative] != representative { + representative = self.parents[representative]; } + + // Then, update each node in this chain to point directly to the representative. 
+ while self.parents[x_index] != x_index { + let old_parent = self.parents[x_index]; + self.parents[x_index] = representative; + x_index = old_parent; + } + + representative } /// Merge two sets. diff --git a/src/plonk/plonk_common.rs b/src/plonk/plonk_common.rs index 6b84886d..5be13740 100644 --- a/src/plonk/plonk_common.rs +++ b/src/plonk/plonk_common.rs @@ -42,17 +42,6 @@ impl PlonkPolynomials { index: 3, blinding: true, }; - - #[cfg(test)] - pub fn polynomials(i: usize) -> PolynomialsIndexBlinding { - match i { - 0 => Self::CONSTANTS_SIGMAS, - 1 => Self::WIRES, - 2 => Self::ZS_PARTIAL_PRODUCTS, - 3 => Self::QUOTIENT, - _ => panic!("There are only 4 sets of polynomials in Plonk."), - } - } } /// Evaluate the polynomial which vanishes on any multiplicative subgroup of a given order `n`. diff --git a/src/plonk/proof.rs b/src/plonk/proof.rs index 72c13f50..db6a1e2e 100644 --- a/src/plonk/proof.rs +++ b/src/plonk/proof.rs @@ -164,12 +164,12 @@ impl, C: GenericConfig, const D: usize> ) -> anyhow::Result> { let challenges = self.get_challenges(common_data)?; let fri_inferred_elements = self.get_inferred_elements(&challenges, common_data); - let compressed_proof = + let decompressed_proof = self.proof .decompress(&challenges, fri_inferred_elements, common_data); Ok(ProofWithPublicInputs { public_inputs: self.public_inputs, - proof: compressed_proof, + proof: decompressed_proof, }) } @@ -180,13 +180,13 @@ impl, C: GenericConfig, const D: usize> ) -> anyhow::Result<()> { let challenges = self.get_challenges(common_data)?; let fri_inferred_elements = self.get_inferred_elements(&challenges, common_data); - let compressed_proof = + let decompressed_proof = self.proof .decompress(&challenges, fri_inferred_elements, common_data); verify_with_challenges( ProofWithPublicInputs { public_inputs: self.public_inputs, - proof: compressed_proof, + proof: decompressed_proof, }, challenges, verifier_data, @@ -312,6 +312,7 @@ mod tests { use crate::field::field_types::Field; use 
crate::fri::reduction_strategies::FriReductionStrategy; + use crate::gates::noop::NoopGate; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; @@ -340,6 +341,9 @@ mod tests { let zt = builder.constant(z); let comp_zt = builder.mul(xt, yt); builder.connect(zt, comp_zt); + for _ in 0..100 { + builder.add_gate(NoopGate, vec![]); + } let data = builder.build::(); let proof = data.prove(pw)?; verify(proof.clone(), &data.verifier_only, &data.common)?; @@ -350,6 +354,6 @@ mod tests { assert_eq!(proof, decompressed_compressed_proof); verify(proof, &data.verifier_only, &data.common)?; - compressed_proof.verify(&data.verifier_only, &data.common) + data.verify_compressed(compressed_proof) } } diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index f2a85214..8f7cf3cd 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -1,3 +1,5 @@ +use std::mem::swap; + use anyhow::Result; use rayon::prelude::*; @@ -13,9 +15,9 @@ use crate::plonk::plonk_common::ZeroPolyOnCoset; use crate::plonk::proof::{Proof, ProofWithPublicInputs}; use crate::plonk::vanishing_poly::eval_vanishing_poly_base_batch; use crate::plonk::vars::EvaluationVarsBase; -use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; +use crate::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::timed; -use crate::util::partial_products::partial_products; +use crate::util::partial_products::{partial_products_and_z_gx, quotient_chunk_products}; use crate::util::timing::TimingTree; use crate::util::{log2_ceil, transpose}; @@ -89,28 +91,22 @@ pub(crate) fn prove, C: GenericConfig, const D: usize common_data.quotient_degree_factor < common_data.config.num_routed_wires, "When the number of routed wires is smaller that the degree, we should change the logic to avoid computing partial products." 
); - let mut partial_products = timed!( + let mut partial_products_and_zs = timed!( timing, "compute partial products", all_wires_permutation_partial_products(&witness, &betas, &gammas, prover_data, common_data) ); - let plonk_z_vecs = timed!( - timing, - "compute Z's", - compute_zs(&partial_products, common_data) - ); + // Z is expected at the front of our batch; see `zs_range` and `partial_products_range`. + let plonk_z_vecs = partial_products_and_zs + .iter_mut() + .map(|partial_products_and_z| partial_products_and_z.pop().unwrap()) + .collect(); + let zs_partial_products = [plonk_z_vecs, partial_products_and_zs.concat()].concat(); - // The first polynomial in `partial_products` represent the final product used in the - // computation of `Z`. It isn't needed anymore so we discard it. - partial_products.iter_mut().for_each(|part| { - part.remove(0); - }); - - let zs_partial_products = [plonk_z_vecs, partial_products.concat()].concat(); - let zs_partial_products_commitment = timed!( + let partial_products_and_zs_commitment = timed!( timing, - "commit to Z's", + "commit to partial products and Z's", PolynomialBatchCommitment::from_values( zs_partial_products, config.rate_bits, @@ -121,7 +117,7 @@ pub(crate) fn prove, C: GenericConfig, const D: usize ) ); - challenger.observe_cap(&zs_partial_products_commitment.merkle_tree.cap); + challenger.observe_cap(&partial_products_and_zs_commitment.merkle_tree.cap); let alphas = challenger.get_n_challenges(num_challenges); @@ -133,7 +129,7 @@ pub(crate) fn prove, C: GenericConfig, const D: usize prover_data, &public_inputs_hash, &wires_commitment, - &zs_partial_products_commitment, + &partial_products_and_zs_commitment, &betas, &gammas, &alphas, @@ -148,7 +144,6 @@ pub(crate) fn prove, C: GenericConfig, const D: usize .into_par_iter() .flat_map(|mut quotient_poly| { quotient_poly.trim(); - // TODO: Return Result instead of panicking. 
quotient_poly.pad(quotient_degree).expect( "Quotient has failed, the vanishing polynomial is not divisible by `Z_H", ); @@ -182,7 +177,7 @@ pub(crate) fn prove, C: GenericConfig, const D: usize &[ &prover_data.constants_sigmas_commitment, &wires_commitment, - &zs_partial_products_commitment, + &partial_products_and_zs_commitment, "ient_polys_commitment, ], zeta, @@ -194,7 +189,7 @@ pub(crate) fn prove, C: GenericConfig, const D: usize let proof = Proof { wires_cap: wires_commitment.merkle_tree.cap, - plonk_zs_partial_products_cap: zs_partial_products_commitment.merkle_tree.cap, + plonk_zs_partial_products_cap: partial_products_and_zs_commitment.merkle_tree.cap, quotient_polys_cap: quotient_polys_commitment.merkle_tree.cap, openings, opening_proof, @@ -219,7 +214,7 @@ fn all_wires_permutation_partial_products< ) -> Vec>> { (0..common_data.config.num_challenges) .map(|i| { - wires_permutation_partial_products( + wires_permutation_partial_products_and_zs( witness, betas[i], gammas[i], @@ -233,7 +228,7 @@ fn all_wires_permutation_partial_products< /// Compute the partial products used in the `Z` polynomial. /// Returns the polynomials interpolating `partial_products(f / g)` /// where `f, g` are the products in the definition of `Z`: `Z(g^i) = f / g`. -fn wires_permutation_partial_products< +fn wires_permutation_partial_products_and_zs< F: Extendable, C: GenericConfig, const D: usize, @@ -247,7 +242,8 @@ fn wires_permutation_partial_products< let degree = common_data.quotient_degree_factor; let subgroup = &prover_data.subgroup; let k_is = &common_data.k_is; - let values = subgroup + let (num_prods, _final_num_prod) = common_data.num_partial_products; + let all_quotient_chunk_products = subgroup .par_iter() .enumerate() .map(|(i, &x)| { @@ -271,49 +267,26 @@ fn wires_permutation_partial_products< .map(|(num, den_inv)| num * den_inv) .collect::>(); - let quotient_partials = partial_products("ient_values, degree); - - // This is the final product for the quotient. 
- let quotient = quotient_partials - [common_data.num_partial_products.0 - common_data.num_partial_products.1..] - .iter() - .copied() - .product(); - - // We add the quotient at the beginning of the vector to reuse them later in the computation of `Z`. - [vec![quotient], quotient_partials].concat() + quotient_chunk_products("ient_values, degree) }) .collect::>(); - transpose(&values) + let mut z_x = F::ONE; + let mut all_partial_products_and_zs = Vec::new(); + for quotient_chunk_products in all_quotient_chunk_products { + let mut partial_products_and_z_gx = + partial_products_and_z_gx(z_x, "ient_chunk_products); + // The last term is Z(gx), but we replace it with Z(x), otherwise Z would end up shifted. + swap(&mut z_x, &mut partial_products_and_z_gx[num_prods]); + all_partial_products_and_zs.push(partial_products_and_z_gx); + } + + transpose(&all_partial_products_and_zs) .into_par_iter() .map(PolynomialValues::new) .collect() } -fn compute_zs, C: GenericConfig, const D: usize>( - partial_products: &[Vec>], - common_data: &CommonCircuitData, -) -> Vec> { - (0..common_data.config.num_challenges) - .map(|i| compute_z(&partial_products[i], common_data)) - .collect() -} - -/// Compute the `Z` polynomial by reusing the computations done in `wires_permutation_partial_products`. 
-fn compute_z, C: GenericConfig, const D: usize>( - partial_products: &[PolynomialValues], - common_data: &CommonCircuitData, -) -> PolynomialValues { - let mut plonk_z_points = vec![F::ONE]; - for i in 1..common_data.degree() { - let quotient = partial_products[0].values[i - 1]; - let last = *plonk_z_points.last().unwrap(); - plonk_z_points.push(last * quotient); - } - plonk_z_points.into() -} - const BATCH_SIZE: usize = 32; fn compute_quotient_polys<'a, F: Extendable, C: GenericConfig, const D: usize>( diff --git a/src/plonk/recursive_verifier.rs b/src/plonk/recursive_verifier.rs index ab43d256..64c2b3fd 100644 --- a/src/plonk/recursive_verifier.rs +++ b/src/plonk/recursive_verifier.rs @@ -133,6 +133,7 @@ mod tests { use crate::fri::reduction_strategies::FriReductionStrategy; use crate::fri::FriConfig; use crate::gadgets::polynomial::PolynomialCoeffsExtTarget; + use crate::gates::noop::NoopGate; use crate::hash::merkle_proofs::MerkleProofTarget; use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_data::VerifierOnlyCircuitData; @@ -369,9 +370,8 @@ mod tests { type F = >::F; let config = CircuitConfig::standard_recursion_config(); - let (proof, vd, cd) = dummy_proof::(&config, 8_000)?; - let (proof, _vd, cd) = - recursive_proof::(proof, vd, cd, &config, &config, true, true)?; + let (proof, vd, cd) = dummy_proof::(&config, 4_000)?; + let (proof, _vd, cd) = recursive_proof::(proof, vd, cd, &config, &config, None, true, true)?; test_serialization(&proof, &cd)?; Ok(()) @@ -388,11 +388,14 @@ mod tests { let config = CircuitConfig::standard_recursion_config(); - let (proof, vd, cd) = dummy_proof::(&config, 8_000)?; + // Start with a degree 2^14 proof, then shrink it to 2^13, then to 2^12. 
+ let (proof, vd, cd) = dummy_proof::(&config, 16_000)?; + assert_eq!(cd.degree_bits, 14); let (proof, vd, cd) = - recursive_proof::(proof, vd, cd, &config, &config, false, false)?; - let (proof, _vd, cd) = - recursive_proof::(proof, vd, cd, &config, &config, true, true)?; + recursive_proof::(proof, vd, cd, &config, &config, Some(13), false, false)?; + assert_eq!(cd.degree_bits, 13); + let (proof, _vd, cd) = recursive_proof::(proof, vd, cd, &config, &config, None, true, true)?; + assert_eq!(cd.degree_bits, 12); test_serialization(&proof, &cd)?; @@ -412,29 +415,29 @@ mod tests { let standard_config = CircuitConfig::standard_recursion_config(); - // A dummy proof with degree 2^13. - let (proof, vd, cd) = dummy_proof::(&standard_config, 8_000)?; - assert_eq!(cd.degree_bits, 13); + // An initial dummy proof. + let (proof, vd, cd) = dummy_proof::(&standard_config, 4_000)?; + assert_eq!(cd.degree_bits, 12); - // A standard recursive proof with degree 2^13. + // A standard recursive proof. let (proof, vd, cd) = recursive_proof( proof, vd, cd, &standard_config, &standard_config, + None, false, false, )?; - assert_eq!(cd.degree_bits, 13); + assert_eq!(cd.degree_bits, 12); - // A high-rate recursive proof with degree 2^13, designed to be verifiable with 2^12 - // gates and 48 routed wires. + // A high-rate recursive proof, designed to be verifiable with fewer routed wires. let high_rate_config = CircuitConfig { - rate_bits: 5, + rate_bits: 7, fri_config: FriConfig { - proof_of_work_bits: 20, - num_query_rounds: 16, + proof_of_work_bits: 16, + num_query_rounds: 12, ..standard_config.fri_config.clone() }, ..standard_config @@ -445,54 +448,35 @@ mod tests { cd, &standard_config, &high_rate_config, - true, - true, - )?; - assert_eq!(cd.degree_bits, 13); - - // A higher-rate recursive proof with degree 2^12, designed to be verifiable with 2^12 - // gates and 28 routed wires. 
- let higher_rate_more_routing_config = CircuitConfig { - rate_bits: 7, - num_routed_wires: 48, - fri_config: FriConfig { - proof_of_work_bits: 23, - num_query_rounds: 11, - ..standard_config.fri_config.clone() - }, - ..high_rate_config.clone() - }; - let (proof, vd, cd) = recursive_proof::( - proof, - vd, - cd, - &high_rate_config, - &higher_rate_more_routing_config, + None, true, true, )?; assert_eq!(cd.degree_bits, 12); - // A final proof of degree 2^12, optimized for size. + // A final proof, optimized for size. let final_config = CircuitConfig { cap_height: 0, - num_routed_wires: 32, + rate_bits: 8, + num_routed_wires: 37, fri_config: FriConfig { + proof_of_work_bits: 20, reduction_strategy: FriReductionStrategy::MinSize(None), - ..higher_rate_more_routing_config.fri_config.clone() + num_query_rounds: 10, }, - ..higher_rate_more_routing_config + ..high_rate_config }; let (proof, _vd, cd) = recursive_proof::( proof, vd, cd, - &higher_rate_more_routing_config, + &high_rate_config, &final_config, + None, true, true, )?; - assert_eq!(cd.degree_bits, 12); + assert_eq!(cd.degree_bits, 12, "final proof too large"); test_serialization(&proof, &cd)?; @@ -509,16 +493,12 @@ mod tests { CommonCircuitData, )> { let mut builder = CircuitBuilder::::new(config.clone()); - let input = builder.add_virtual_target(); - for i in 0..num_dummy_gates { - // Use unique constants to force a new `ArithmeticGate`. 
- let i_f = F::from_canonical_u64(i); - builder.arithmetic(i_f, i_f, input, input, input); + for _ in 0..num_dummy_gates { + builder.add_gate(NoopGate, vec![]); } let data = builder.build::(); - let mut inputs = PartialWitness::new(); - inputs.set_target(input, F::ZERO); + let inputs = PartialWitness::new(); let proof = data.prove(inputs)?; data.verify(proof.clone())?; @@ -536,6 +516,7 @@ mod tests { inner_cd: CommonCircuitData, inner_config: &CircuitConfig, config: &CircuitConfig, + min_degree_bits: Option, print_gate_counts: bool, print_timing: bool, ) -> Result<( @@ -556,12 +537,22 @@ mod tests { &inner_vd.constants_sigmas_cap, ); - builder.add_recursive_verifier(pt, &inner_config, &inner_data, &inner_cd); + builder.add_recursive_verifier(pt, inner_config, &inner_data, &inner_cd); if print_gate_counts { builder.print_gate_counts(0); } + if let Some(min_degree_bits) = min_degree_bits { + // We don't want to pad all the way up to 2^min_degree_bits, as the builder will add a + // few special gates afterward. So just pad to 2^(min_degree_bits - 1) + 1. Then the + // builder will pad to the next power of two, 2^min_degree_bits. 
+ let min_gates = (1 << (min_degree_bits - 1)) + 1; + for _ in builder.num_gates()..min_gates { + builder.add_gate(NoopGate, vec![]); + } + } + let data = builder.build::(); let mut timing = TimingTree::new("prove", Level::Debug); @@ -582,12 +573,12 @@ mod tests { ) -> Result<()> { let proof_bytes = proof.to_bytes()?; info!("Proof length: {} bytes", proof_bytes.len()); - let proof_from_bytes = ProofWithPublicInputs::from_bytes(proof_bytes, &cd)?; + let proof_from_bytes = ProofWithPublicInputs::from_bytes(proof_bytes, cd)?; assert_eq!(proof, &proof_from_bytes); let now = std::time::Instant::now(); - let compressed_proof = proof.clone().compress(&cd)?; - let decompressed_compressed_proof = compressed_proof.clone().decompress(&cd)?; + let compressed_proof = proof.clone().compress(cd)?; + let decompressed_compressed_proof = compressed_proof.clone().decompress(cd)?; info!("{:.4}s to compress proof", now.elapsed().as_secs_f64()); assert_eq!(proof, &decompressed_compressed_proof); @@ -597,7 +588,7 @@ mod tests { compressed_proof_bytes.len() ); let compressed_proof_from_bytes = - CompressedProofWithPublicInputs::from_bytes(compressed_proof_bytes, &cd)?; + CompressedProofWithPublicInputs::from_bytes(compressed_proof_bytes, cd)?; assert_eq!(compressed_proof, compressed_proof_from_bytes); Ok(()) diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index a68fe5c5..025fd0bd 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -29,7 +29,7 @@ pub(crate) fn eval_vanishing_poly, C: GenericConfig, alphas: &[F], ) -> Vec { let max_degree = common_data.quotient_degree_factor; - let (num_prods, final_num_prod) = common_data.num_partial_products; + let (num_prods, _final_num_prod) = common_data.num_partial_products; let constraint_terms = evaluate_gate_constraints(&common_data.gates, common_data.num_gate_constraints, vars); @@ -38,14 +38,12 @@ pub(crate) fn eval_vanishing_poly, C: GenericConfig, let mut vanishing_z_1_terms = Vec::new(); // The 
terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); - // The Z(x) f'(x) - g'(x) Z(g x) terms. - let mut vanishing_v_shift_terms = Vec::new(); let l1_x = plonk_common::eval_l_1(common_data.degree(), x); for i in 0..common_data.config.num_challenges { let z_x = local_zs[i]; - let z_gz = next_zs[i]; + let z_gx = next_zs[i]; vanishing_z_1_terms.push(l1_x * (z_x - F::Extension::ONE)); let numerator_values = (0..common_data.config.num_routed_wires) @@ -63,37 +61,24 @@ pub(crate) fn eval_vanishing_poly, C: GenericConfig, wire_value + s_sigma.scalar_mul(betas[i]) + gammas[i].into() }) .collect::>(); - let quotient_values = (0..common_data.config.num_routed_wires) - .map(|j| numerator_values[j] / denominator_values[j]) - .collect::>(); // The partial products considered for this iteration of `i`. let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods]; // Check the quotient partial products. - let mut partial_product_check = - check_partial_products("ient_values, current_partial_products, max_degree); - // The first checks are of the form `q - n/d` which is a rational function not a polynomial. - // We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials. - denominator_values - .chunks(max_degree) - .zip(partial_product_check.iter_mut()) - .for_each(|(d, q)| { - *q *= d.iter().copied().product(); - }); - vanishing_partial_products_terms.extend(partial_product_check); - - // The quotient final product is the product of the last `final_num_prod` elements. - let quotient: F::Extension = current_partial_products[num_prods - final_num_prod..] 
- .iter() - .copied() - .product(); - vanishing_v_shift_terms.push(quotient * z_x - z_gz); + let partial_product_checks = check_partial_products( + &numerator_values, + &denominator_values, + current_partial_products, + z_x, + z_gx, + max_degree, + ); + vanishing_partial_products_terms.extend(partial_product_checks); } let vanishing_terms = [ vanishing_z_1_terms, vanishing_partial_products_terms, - vanishing_v_shift_terms, constraint_terms, ] .concat(); @@ -130,7 +115,7 @@ pub(crate) fn eval_vanishing_poly_base_batch< assert_eq!(s_sigmas_batch.len(), n); let max_degree = common_data.quotient_degree_factor; - let (num_prods, final_num_prod) = common_data.num_partial_products; + let (num_prods, _final_num_prod) = common_data.num_partial_products; let num_gate_constraints = common_data.num_gate_constraints; @@ -143,14 +128,11 @@ pub(crate) fn eval_vanishing_poly_base_batch< let mut numerator_values = Vec::with_capacity(num_routed_wires); let mut denominator_values = Vec::with_capacity(num_routed_wires); - let mut quotient_values = Vec::with_capacity(num_routed_wires); // The L_1(x) (Z(x) - 1) vanishing terms. let mut vanishing_z_1_terms = Vec::with_capacity(num_challenges); // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); - // The Z(x) f'(x) - g'(x) Z(g x) terms. 
- let mut vanishing_v_shift_terms = Vec::with_capacity(num_challenges); let mut res_batch: Vec> = Vec::with_capacity(n); for k in 0..n { @@ -168,7 +150,7 @@ pub(crate) fn eval_vanishing_poly_base_batch< let l1_x = z_h_on_coset.eval_l1(index, x); for i in 0..num_challenges { let z_x = local_zs[i]; - let z_gz = next_zs[i]; + let z_gx = next_zs[i]; vanishing_z_1_terms.push(l1_x * z_x.sub_one()); numerator_values.extend((0..num_routed_wires).map(|j| { @@ -182,49 +164,33 @@ pub(crate) fn eval_vanishing_poly_base_batch< let s_sigma = s_sigmas[j]; wire_value + betas[i] * s_sigma + gammas[i] })); - let denominator_inverses = F::batch_multiplicative_inverse(&denominator_values); - quotient_values.extend( - (0..num_routed_wires).map(|j| numerator_values[j] * denominator_inverses[j]), - ); // The partial products considered for this iteration of `i`. let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods]; // Check the numerator partial products. - let mut partial_product_check = - check_partial_products("ient_values, current_partial_products, max_degree); - // The first checks are of the form `q - n/d` which is a rational function not a polynomial. - // We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials. - denominator_values - .chunks(max_degree) - .zip(partial_product_check.iter_mut()) - .for_each(|(d, q)| { - *q *= d.iter().copied().product(); - }); - vanishing_partial_products_terms.extend(partial_product_check); - - // The quotient final product is the product of the last `final_num_prod` elements. - let quotient: F = current_partial_products[num_prods - final_num_prod..] 
- .iter() - .copied() - .product(); - vanishing_v_shift_terms.push(quotient * z_x - z_gz); + let partial_product_checks = check_partial_products( + &numerator_values, + &denominator_values, + current_partial_products, + z_x, + z_gx, + max_degree, + ); + vanishing_partial_products_terms.extend(partial_product_checks); numerator_values.clear(); denominator_values.clear(); - quotient_values.clear(); } let vanishing_terms = vanishing_z_1_terms .iter() .chain(vanishing_partial_products_terms.iter()) - .chain(vanishing_v_shift_terms.iter()) .chain(constraint_terms); let res = plonk_common::reduce_with_powers_multi(vanishing_terms, alphas); res_batch.push(res); vanishing_z_1_terms.clear(); vanishing_partial_products_terms.clear(); - vanishing_v_shift_terms.clear(); } res_batch } @@ -334,7 +300,7 @@ pub(crate) fn eval_vanishing_poly_recursively< alphas: &[Target], ) -> Vec> { let max_degree = common_data.quotient_degree_factor; - let (num_prods, final_num_prod) = common_data.num_partial_products; + let (num_prods, _final_num_prod) = common_data.num_partial_products; let constraint_terms = with_context!( builder, @@ -351,8 +317,6 @@ pub(crate) fn eval_vanishing_poly_recursively< let mut vanishing_z_1_terms = Vec::new(); // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); - // The Z(x) f'(x) - g'(x) Z(g x) terms. - let mut vanishing_v_shift_terms = Vec::new(); let l1_x = eval_l_1_recursively(builder, common_data.degree(), x, x_pow_deg); @@ -365,14 +329,13 @@ pub(crate) fn eval_vanishing_poly_recursively< for i in 0..common_data.config.num_challenges { let z_x = local_zs[i]; - let z_gz = next_zs[i]; + let z_gx = next_zs[i]; // L_1(x) Z(x) = 0. 
vanishing_z_1_terms.push(builder.mul_sub_extension(l1_x, z_x, l1_x)); let mut numerator_values = Vec::new(); let mut denominator_values = Vec::new(); - let mut quotient_values = Vec::new(); for j in 0..common_data.config.num_routed_wires { let wire_value = vars.local_wires[j]; @@ -385,44 +348,28 @@ pub(crate) fn eval_vanishing_poly_recursively< let numerator = builder.mul_add_extension(beta_ext, s_ids[j], wire_value_plus_gamma); let denominator = builder.mul_add_extension(beta_ext, s_sigmas[j], wire_value_plus_gamma); - let quotient = builder.div_extension(numerator, denominator); - numerator_values.push(numerator); denominator_values.push(denominator); - quotient_values.push(quotient); } // The partial products considered for this iteration of `i`. let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods]; // Check the quotient partial products. - let mut partial_product_check = check_partial_products_recursively( + let partial_product_checks = check_partial_products_recursively( builder, - "ient_values, + &numerator_values, + &denominator_values, current_partial_products, + z_x, + z_gx, max_degree, ); - // The first checks are of the form `q - n/d` which is a rational function not a polynomial. - // We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials. - denominator_values - .chunks(max_degree) - .zip(partial_product_check.iter_mut()) - .for_each(|(d, q)| { - let mut v = d.to_vec(); - v.push(*q); - *q = builder.mul_many_extension(&v); - }); - vanishing_partial_products_terms.extend(partial_product_check); - - // The quotient final product is the product of the last `final_num_prod` elements. 
- let quotient = - builder.mul_many_extension(¤t_partial_products[num_prods - final_num_prod..]); - vanishing_v_shift_terms.push(builder.mul_sub_extension(quotient, z_x, z_gz)); + vanishing_partial_products_terms.extend(partial_product_checks); } let vanishing_terms = [ vanishing_z_1_terms, vanishing_partial_products_terms, - vanishing_v_shift_terms, constraint_terms, ] .concat(); diff --git a/src/plonk/vars.rs b/src/plonk/vars.rs index 110aa689..b643b7b7 100644 --- a/src/plonk/vars.rs +++ b/src/plonk/vars.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::ops::Range; use crate::field::extension_field::algebra::ExtensionAlgebra; diff --git a/src/polynomial/division.rs b/src/polynomial/division.rs index 6ac38676..671b7715 100644 --- a/src/polynomial/division.rs +++ b/src/polynomial/division.rs @@ -1,7 +1,6 @@ -use crate::field::fft::{fft, ifft}; use crate::field::field_types::Field; -use crate::polynomial::polynomial::PolynomialCoeffs; -use crate::util::{log2_ceil, log2_strict}; +use crate::polynomial::PolynomialCoeffs; +use crate::util::log2_ceil; impl PolynomialCoeffs { /// Polynomial division. @@ -67,63 +66,6 @@ impl PolynomialCoeffs { } } - /// Takes a polynomial `a` in coefficient form, and divides it by `Z_H = X^n - 1`. - /// - /// This assumes `Z_H | a`, otherwise result is meaningless. - pub(crate) fn divide_by_z_h(&self, n: usize) -> PolynomialCoeffs { - let mut a = self.clone(); - - // TODO: Is this special case needed? - if a.coeffs.iter().all(|p| *p == F::ZERO) { - return a; - } - - let g = F::MULTIPLICATIVE_GROUP_GENERATOR; - let mut g_pow = F::ONE; - // Multiply the i-th coefficient of `a` by `g^i`. Then `new_a(w^j) = old_a(g.w^j)`. - a.coeffs.iter_mut().for_each(|x| { - *x *= g_pow; - g_pow *= g; - }); - - let root = F::primitive_root_of_unity(log2_strict(a.len())); - // Equals to the evaluation of `a` on `{g.w^i}`. - let mut a_eval = fft(&a); - // Compute the denominators `1/(g^n.w^(n*i) - 1)` using batch inversion. 
- let denominator_g = g.exp_u64(n as u64); - let root_n = root.exp_u64(n as u64); - let mut root_pow = F::ONE; - let denominators = (0..a_eval.len()) - .map(|i| { - if i != 0 { - root_pow *= root_n; - } - denominator_g * root_pow - F::ONE - }) - .collect::>(); - let denominators_inv = F::batch_multiplicative_inverse(&denominators); - // Divide every element of `a_eval` by the corresponding denominator. - // Then, `a_eval` is the evaluation of `a/Z_H` on `{g.w^i}`. - a_eval - .values - .iter_mut() - .zip(denominators_inv.iter()) - .for_each(|(x, &d)| { - *x *= d; - }); - // `p` is the interpolating polynomial of `a_eval` on `{w^i}`. - let mut p = ifft(&a_eval); - // We need to scale it by `g^(-i)` to get the interpolating polynomial of `a_eval` on `{g.w^i}`, - // a.k.a `a/Z_H`. - let g_inv = g.inverse(); - let mut g_inv_pow = F::ONE; - p.coeffs.iter_mut().for_each(|x| { - *x *= g_inv_pow; - g_inv_pow *= g_inv; - }); - p - } - /// Let `self=p(X)`, this returns `(p(X)-p(z))/(X-z)` and `p(z)`. 
/// See https://en.wikipedia.org/wiki/Horner%27s_method pub(crate) fn divide_by_linear(&self, z: F) -> (PolynomialCoeffs, F) { @@ -187,35 +129,7 @@ mod tests { use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; - use crate::polynomial::polynomial::PolynomialCoeffs; - - #[test] - fn zero_div_z_h() { - type F = GoldilocksField; - let zero = PolynomialCoeffs::::zero(16); - let quotient = zero.divide_by_z_h(4); - assert_eq!(quotient, zero); - } - - #[test] - fn division_by_z_h() { - type F = GoldilocksField; - let zero = F::ZERO; - let three = F::from_canonical_u64(3); - let four = F::from_canonical_u64(4); - let five = F::from_canonical_u64(5); - let six = F::from_canonical_u64(6); - - // a(x) = Z_4(x) q(x), where - // a(x) = 3 x^7 + 4 x^6 + 5 x^5 + 6 x^4 - 3 x^3 - 4 x^2 - 5 x - 6 - // Z_4(x) = x^4 - 1 - // q(x) = 3 x^3 + 4 x^2 + 5 x + 6 - let a = PolynomialCoeffs::new(vec![-six, -five, -four, -three, six, five, four, three]); - let q = PolynomialCoeffs::new(vec![six, five, four, three, zero, zero, zero, zero]); - - let computed_q = a.divide_by_z_h(4); - assert_eq!(computed_q, q); - } + use crate::polynomial::PolynomialCoeffs; #[test] #[ignore] diff --git a/src/polynomial/mod.rs b/src/polynomial/mod.rs index 2c7f7076..1a7b90fe 100644 --- a/src/polynomial/mod.rs +++ b/src/polynomial/mod.rs @@ -1,2 +1,616 @@ pub(crate) mod division; -pub mod polynomial; + +use std::cmp::max; +use std::iter::Sum; +use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; + +use anyhow::{ensure, Result}; +use serde::{Deserialize, Serialize}; + +use crate::field::extension_field::{Extendable, FieldExtension}; +use crate::field::fft::{fft, fft_with_options, ifft, FftRootTable}; +use crate::field::field_types::Field; +use crate::util::log2_strict; + +/// A polynomial in point-value form. 
+/// +/// The points are implicitly `g^i`, where `g` generates the subgroup whose size equals the number +/// of points. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct PolynomialValues { + pub values: Vec, +} + +impl PolynomialValues { + pub fn new(values: Vec) -> Self { + PolynomialValues { values } + } + + /// The number of values stored. + pub(crate) fn len(&self) -> usize { + self.values.len() + } + + pub fn ifft(&self) -> PolynomialCoeffs { + ifft(self) + } + + /// Returns the polynomial whose evaluation on the coset `shift*H` is `self`. + pub fn coset_ifft(&self, shift: F) -> PolynomialCoeffs { + let mut shifted_coeffs = self.ifft(); + shifted_coeffs + .coeffs + .iter_mut() + .zip(shift.inverse().powers()) + .for_each(|(c, r)| { + *c *= r; + }); + shifted_coeffs + } + + pub fn lde_multiple(polys: Vec, rate_bits: usize) -> Vec { + polys.into_iter().map(|p| p.lde(rate_bits)).collect() + } + + pub fn lde(&self, rate_bits: usize) -> Self { + let coeffs = ifft(self).lde(rate_bits); + fft_with_options(&coeffs, Some(rate_bits), None) + } + + pub fn degree(&self) -> usize { + self.degree_plus_one() + .checked_sub(1) + .expect("deg(0) is undefined") + } + + pub fn degree_plus_one(&self) -> usize { + self.ifft().degree_plus_one() + } +} + +impl From> for PolynomialValues { + fn from(values: Vec) -> Self { + Self::new(values) + } +} + +/// A polynomial in coefficient form. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct PolynomialCoeffs { + pub(crate) coeffs: Vec, +} + +impl PolynomialCoeffs { + pub fn new(coeffs: Vec) -> Self { + PolynomialCoeffs { coeffs } + } + + pub(crate) fn empty() -> Self { + Self::new(Vec::new()) + } + + pub(crate) fn zero(len: usize) -> Self { + Self::new(vec![F::ZERO; len]) + } + + pub(crate) fn is_zero(&self) -> bool { + self.coeffs.iter().all(|x| x.is_zero()) + } + + /// The number of coefficients. This does not filter out any zero coefficients, so it is not + /// necessarily related to the degree. 
+ pub fn len(&self) -> usize { + self.coeffs.len() + } + + pub fn log_len(&self) -> usize { + log2_strict(self.len()) + } + + pub(crate) fn chunks(&self, chunk_size: usize) -> Vec { + self.coeffs + .chunks(chunk_size) + .map(|chunk| PolynomialCoeffs::new(chunk.to_vec())) + .collect() + } + + pub fn eval(&self, x: F) -> F { + self.coeffs + .iter() + .rev() + .fold(F::ZERO, |acc, &c| acc * x + c) + } + + /// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1. + pub fn eval_with_powers(&self, powers: &[F]) -> F { + debug_assert_eq!(self.coeffs.len(), powers.len() + 1); + let acc = self.coeffs[0]; + self.coeffs[1..] + .iter() + .zip(powers) + .fold(acc, |acc, (&x, &c)| acc + c * x) + } + + pub fn eval_base(&self, x: F::BaseField) -> F + where + F: FieldExtension, + { + self.coeffs + .iter() + .rev() + .fold(F::ZERO, |acc, &c| acc.scalar_mul(x) + c) + } + + /// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1. + pub fn eval_base_with_powers(&self, powers: &[F::BaseField]) -> F + where + F: FieldExtension, + { + debug_assert_eq!(self.coeffs.len(), powers.len() + 1); + let acc = self.coeffs[0]; + self.coeffs[1..] + .iter() + .zip(powers) + .fold(acc, |acc, (&x, &c)| acc + x.scalar_mul(c)) + } + + pub fn lde_multiple(polys: Vec<&Self>, rate_bits: usize) -> Vec { + polys.into_iter().map(|p| p.lde(rate_bits)).collect() + } + + pub fn lde(&self, rate_bits: usize) -> Self { + self.padded(self.len() << rate_bits) + } + + pub(crate) fn pad(&mut self, new_len: usize) -> Result<()> { + ensure!( + new_len >= self.len(), + "Trying to pad a polynomial of length {} to a length of {}.", + self.len(), + new_len + ); + self.coeffs.resize(new_len, F::ZERO); + Ok(()) + } + + pub(crate) fn padded(&self, new_len: usize) -> Self { + let mut poly = self.clone(); + poly.pad(new_len).unwrap(); + poly + } + + /// Removes leading zero coefficients. 
+ pub fn trim(&mut self) { + self.coeffs.truncate(self.degree_plus_one()); + } + + /// Removes leading zero coefficients. + pub fn trimmed(&self) -> Self { + let coeffs = self.coeffs[..self.degree_plus_one()].to_vec(); + Self { coeffs } + } + + /// Degree of the polynomial + 1, or 0 for a polynomial with no non-zero coefficients. + pub(crate) fn degree_plus_one(&self) -> usize { + (0usize..self.len()) + .rev() + .find(|&i| self.coeffs[i].is_nonzero()) + .map_or(0, |i| i + 1) + } + + /// Leading coefficient. + pub fn lead(&self) -> F { + self.coeffs + .iter() + .rev() + .find(|x| x.is_nonzero()) + .map_or(F::ZERO, |x| *x) + } + + /// Reverse the order of the coefficients, not taking into account the leading zero coefficients. + pub(crate) fn rev(&self) -> Self { + Self::new(self.trimmed().coeffs.into_iter().rev().collect()) + } + + pub fn fft(&self) -> PolynomialValues { + fft(self) + } + + pub fn fft_with_options( + &self, + zero_factor: Option, + root_table: Option<&FftRootTable>, + ) -> PolynomialValues { + fft_with_options(self, zero_factor, root_table) + } + + /// Returns the evaluation of the polynomial on the coset `shift*H`. + pub fn coset_fft(&self, shift: F) -> PolynomialValues { + self.coset_fft_with_options(shift, None, None) + } + + /// Returns the evaluation of the polynomial on the coset `shift*H`. 
+ pub fn coset_fft_with_options( + &self, + shift: F, + zero_factor: Option, + root_table: Option<&FftRootTable>, + ) -> PolynomialValues { + let modified_poly: Self = shift + .powers() + .zip(&self.coeffs) + .map(|(r, &c)| r * c) + .collect::>() + .into(); + modified_poly.fft_with_options(zero_factor, root_table) + } + + pub fn to_extension(&self) -> PolynomialCoeffs + where + F: Extendable, + { + PolynomialCoeffs::new(self.coeffs.iter().map(|&c| c.into()).collect()) + } + + pub fn mul_extension(&self, rhs: F::Extension) -> PolynomialCoeffs + where + F: Extendable, + { + PolynomialCoeffs::new(self.coeffs.iter().map(|&c| rhs.scalar_mul(c)).collect()) + } +} + +impl PartialEq for PolynomialCoeffs { + fn eq(&self, other: &Self) -> bool { + let max_terms = self.coeffs.len().max(other.coeffs.len()); + for i in 0..max_terms { + let self_i = self.coeffs.get(i).cloned().unwrap_or(F::ZERO); + let other_i = other.coeffs.get(i).cloned().unwrap_or(F::ZERO); + if self_i != other_i { + return false; + } + } + true + } +} + +impl Eq for PolynomialCoeffs {} + +impl From> for PolynomialCoeffs { + fn from(coeffs: Vec) -> Self { + Self::new(coeffs) + } +} + +impl Add for &PolynomialCoeffs { + type Output = PolynomialCoeffs; + + fn add(self, rhs: Self) -> Self::Output { + let len = max(self.len(), rhs.len()); + let a = self.padded(len).coeffs; + let b = rhs.padded(len).coeffs; + let coeffs = a.into_iter().zip(b).map(|(x, y)| x + y).collect(); + PolynomialCoeffs::new(coeffs) + } +} + +impl Sum for PolynomialCoeffs { + fn sum>(iter: I) -> Self { + iter.fold(Self::empty(), |acc, p| &acc + &p) + } +} + +impl Sub for &PolynomialCoeffs { + type Output = PolynomialCoeffs; + + fn sub(self, rhs: Self) -> Self::Output { + let len = max(self.len(), rhs.len()); + let mut coeffs = self.padded(len).coeffs; + for (i, &c) in rhs.coeffs.iter().enumerate() { + coeffs[i] -= c; + } + PolynomialCoeffs::new(coeffs) + } +} + +impl AddAssign for PolynomialCoeffs { + fn add_assign(&mut self, rhs: Self) { + 
let len = max(self.len(), rhs.len()); + self.coeffs.resize(len, F::ZERO); + for (l, r) in self.coeffs.iter_mut().zip(rhs.coeffs) { + *l += r; + } + } +} + +impl AddAssign<&Self> for PolynomialCoeffs { + fn add_assign(&mut self, rhs: &Self) { + let len = max(self.len(), rhs.len()); + self.coeffs.resize(len, F::ZERO); + for (l, &r) in self.coeffs.iter_mut().zip(&rhs.coeffs) { + *l += r; + } + } +} + +impl SubAssign for PolynomialCoeffs { + fn sub_assign(&mut self, rhs: Self) { + let len = max(self.len(), rhs.len()); + self.coeffs.resize(len, F::ZERO); + for (l, r) in self.coeffs.iter_mut().zip(rhs.coeffs) { + *l -= r; + } + } +} + +impl SubAssign<&Self> for PolynomialCoeffs { + fn sub_assign(&mut self, rhs: &Self) { + let len = max(self.len(), rhs.len()); + self.coeffs.resize(len, F::ZERO); + for (l, &r) in self.coeffs.iter_mut().zip(&rhs.coeffs) { + *l -= r; + } + } +} + +impl Mul for &PolynomialCoeffs { + type Output = PolynomialCoeffs; + + fn mul(self, rhs: F) -> Self::Output { + let coeffs = self.coeffs.iter().map(|&x| rhs * x).collect(); + PolynomialCoeffs::new(coeffs) + } +} + +impl MulAssign for PolynomialCoeffs { + fn mul_assign(&mut self, rhs: F) { + self.coeffs.iter_mut().for_each(|x| *x *= rhs); + } +} + +impl Mul for &PolynomialCoeffs { + type Output = PolynomialCoeffs; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn mul(self, rhs: Self) -> Self::Output { + let new_len = (self.len() + rhs.len()).next_power_of_two(); + let a = self.padded(new_len); + let b = rhs.padded(new_len); + let a_evals = a.fft(); + let b_evals = b.fft(); + + let mul_evals: Vec = a_evals + .values + .into_iter() + .zip(b_evals.values) + .map(|(pa, pb)| pa * pb) + .collect(); + ifft(&mul_evals.into()) + } +} + +#[cfg(test)] +mod tests { + use std::time::Instant; + + use rand::{thread_rng, Rng}; + + use super::*; + use crate::field::goldilocks_field::GoldilocksField; + + #[test] + fn test_trimmed() { + type F = GoldilocksField; + + assert_eq!( + PolynomialCoeffs:: { coeffs: 
vec![] }.trimmed(), + PolynomialCoeffs:: { coeffs: vec![] } + ); + assert_eq!( + PolynomialCoeffs:: { + coeffs: vec![F::ZERO] + } + .trimmed(), + PolynomialCoeffs:: { coeffs: vec![] } + ); + assert_eq!( + PolynomialCoeffs:: { + coeffs: vec![F::ONE, F::TWO, F::ZERO, F::ZERO] + } + .trimmed(), + PolynomialCoeffs:: { + coeffs: vec![F::ONE, F::TWO] + } + ); + } + + #[test] + fn test_coset_fft() { + type F = GoldilocksField; + + let k = 8; + let n = 1 << k; + let poly = PolynomialCoeffs::new(F::rand_vec(n)); + let shift = F::rand(); + let coset_evals = poly.coset_fft(shift).values; + + let generator = F::primitive_root_of_unity(k); + let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n) + .into_iter() + .map(|x| poly.eval(x)) + .collect::>(); + assert_eq!(coset_evals, naive_coset_evals); + + let ifft_coeffs = PolynomialValues::new(coset_evals).coset_ifft(shift); + assert_eq!(poly, ifft_coeffs); + } + + #[test] + fn test_coset_ifft() { + type F = GoldilocksField; + + let k = 8; + let n = 1 << k; + let evals = PolynomialValues::new(F::rand_vec(n)); + let shift = F::rand(); + let coeffs = evals.coset_ifft(shift); + + let generator = F::primitive_root_of_unity(k); + let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n) + .into_iter() + .map(|x| coeffs.eval(x)) + .collect::>(); + assert_eq!(evals, naive_coset_evals.into()); + + let fft_evals = coeffs.coset_fft(shift); + assert_eq!(evals, fft_evals); + } + + #[test] + fn test_polynomial_multiplication() { + type F = GoldilocksField; + let mut rng = thread_rng(); + let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000)); + let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); + let b = PolynomialCoeffs::new(F::rand_vec(b_deg)); + let m1 = &a * &b; + let m2 = &a * &b; + for _ in 0..1000 { + let x = F::rand(); + assert_eq!(m1.eval(x), a.eval(x) * b.eval(x)); + assert_eq!(m2.eval(x), a.eval(x) * b.eval(x)); + } + } + + #[test] + fn test_inv_mod_xn() { + type F 
= GoldilocksField; + let mut rng = thread_rng(); + let a_deg = rng.gen_range(1..1_000); + let n = rng.gen_range(1..1_000); + let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); + let b = a.inv_mod_xn(n); + let mut m = &a * &b; + m.coeffs.drain(n..); + m.trim(); + assert_eq!( + m, + PolynomialCoeffs::new(vec![F::ONE]), + "a: {:#?}, b:{:#?}, n:{:#?}, m:{:#?}", + a, + b, + n, + m + ); + } + + #[test] + fn test_polynomial_long_division() { + type F = GoldilocksField; + let mut rng = thread_rng(); + let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000)); + let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); + let b = PolynomialCoeffs::new(F::rand_vec(b_deg)); + let (q, r) = a.div_rem_long_division(&b); + for _ in 0..1000 { + let x = F::rand(); + assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x)); + } + } + + #[test] + fn test_polynomial_division() { + type F = GoldilocksField; + let mut rng = thread_rng(); + let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000)); + let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); + let b = PolynomialCoeffs::new(F::rand_vec(b_deg)); + let (q, r) = a.div_rem(&b); + for _ in 0..1000 { + let x = F::rand(); + assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x)); + } + } + + #[test] + fn test_polynomial_division_by_constant() { + type F = GoldilocksField; + let mut rng = thread_rng(); + let a_deg = rng.gen_range(1..10_000); + let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); + let b = PolynomialCoeffs::from(vec![F::rand()]); + let (q, r) = a.div_rem(&b); + for _ in 0..1000 { + let x = F::rand(); + assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x)); + } + } + + // Test to see which polynomial division method is faster for divisions of the type + // `(X^n - 1)/(X - a) + #[test] + fn test_division_linear() { + type F = GoldilocksField; + let mut rng = thread_rng(); + let l = 14; + let n = 1 << l; + let g = F::primitive_root_of_unity(l); + let xn_minus_one = { + let mut xn_min_one_vec = 
vec![F::ZERO; n + 1]; + xn_min_one_vec[n] = F::ONE; + xn_min_one_vec[0] = F::NEG_ONE; + PolynomialCoeffs::new(xn_min_one_vec) + }; + + let a = g.exp_u64(rng.gen_range(0..(n as u64))); + let denom = PolynomialCoeffs::new(vec![-a, F::ONE]); + let now = Instant::now(); + xn_minus_one.div_rem(&denom); + println!("Division time: {:?}", now.elapsed()); + let now = Instant::now(); + xn_minus_one.div_rem_long_division(&denom); + println!("Division time: {:?}", now.elapsed()); + } + + #[test] + fn eq() { + type F = GoldilocksField; + assert_eq!( + PolynomialCoeffs::::new(vec![]), + PolynomialCoeffs::new(vec![]) + ); + assert_eq!( + PolynomialCoeffs::::new(vec![F::ZERO]), + PolynomialCoeffs::new(vec![F::ZERO]) + ); + assert_eq!( + PolynomialCoeffs::::new(vec![]), + PolynomialCoeffs::new(vec![F::ZERO]) + ); + assert_eq!( + PolynomialCoeffs::::new(vec![F::ZERO]), + PolynomialCoeffs::new(vec![]) + ); + assert_eq!( + PolynomialCoeffs::::new(vec![F::ZERO]), + PolynomialCoeffs::new(vec![F::ZERO, F::ZERO]) + ); + assert_eq!( + PolynomialCoeffs::::new(vec![F::ONE]), + PolynomialCoeffs::new(vec![F::ONE, F::ZERO]) + ); + assert_ne!( + PolynomialCoeffs::::new(vec![]), + PolynomialCoeffs::new(vec![F::ONE]) + ); + assert_ne!( + PolynomialCoeffs::::new(vec![F::ZERO]), + PolynomialCoeffs::new(vec![F::ZERO, F::ONE]) + ); + assert_ne!( + PolynomialCoeffs::::new(vec![F::ZERO]), + PolynomialCoeffs::new(vec![F::ONE, F::ZERO]) + ); + } +} diff --git a/src/polynomial/polynomial.rs b/src/polynomial/polynomial.rs deleted file mode 100644 index 107d7a7b..00000000 --- a/src/polynomial/polynomial.rs +++ /dev/null @@ -1,635 +0,0 @@ -use std::cmp::max; -use std::iter::Sum; -use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; - -use anyhow::{ensure, Result}; -use serde::{Deserialize, Serialize}; - -use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::field::fft::{fft, fft_with_options, ifft, FftRootTable}; -use crate::field::field_types::Field; -use 
crate::util::log2_strict; - -/// A polynomial in point-value form. -/// -/// The points are implicitly `g^i`, where `g` generates the subgroup whose size equals the number -/// of points. -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct PolynomialValues { - pub values: Vec, -} - -impl PolynomialValues { - pub fn new(values: Vec) -> Self { - PolynomialValues { values } - } - - pub(crate) fn zero(len: usize) -> Self { - Self::new(vec![F::ZERO; len]) - } - - /// The number of values stored. - pub(crate) fn len(&self) -> usize { - self.values.len() - } - - pub fn ifft(&self) -> PolynomialCoeffs { - ifft(self) - } - - /// Returns the polynomial whose evaluation on the coset `shift*H` is `self`. - pub fn coset_ifft(&self, shift: F) -> PolynomialCoeffs { - let mut shifted_coeffs = self.ifft(); - shifted_coeffs - .coeffs - .iter_mut() - .zip(shift.inverse().powers()) - .for_each(|(c, r)| { - *c *= r; - }); - shifted_coeffs - } - - pub fn lde_multiple(polys: Vec, rate_bits: usize) -> Vec { - polys.into_iter().map(|p| p.lde(rate_bits)).collect() - } - - pub fn lde(&self, rate_bits: usize) -> Self { - let coeffs = ifft(self).lde(rate_bits); - fft_with_options(&coeffs, Some(rate_bits), None) - } - - pub fn degree(&self) -> usize { - self.degree_plus_one() - .checked_sub(1) - .expect("deg(0) is undefined") - } - - pub fn degree_plus_one(&self) -> usize { - self.ifft().degree_plus_one() - } -} - -impl From> for PolynomialValues { - fn from(values: Vec) -> Self { - Self::new(values) - } -} - -/// A polynomial in coefficient form. -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "")] -pub struct PolynomialCoeffs { - pub(crate) coeffs: Vec, -} - -impl PolynomialCoeffs { - pub fn new(coeffs: Vec) -> Self { - PolynomialCoeffs { coeffs } - } - - /// Create a new polynomial with its coefficient list padded to the next power of two. 
- pub(crate) fn new_padded(mut coeffs: Vec) -> Self { - while !coeffs.len().is_power_of_two() { - coeffs.push(F::ZERO); - } - PolynomialCoeffs { coeffs } - } - - pub(crate) fn empty() -> Self { - Self::new(Vec::new()) - } - - pub(crate) fn zero(len: usize) -> Self { - Self::new(vec![F::ZERO; len]) - } - - pub(crate) fn one() -> Self { - Self::new(vec![F::ONE]) - } - - pub(crate) fn is_zero(&self) -> bool { - self.coeffs.iter().all(|x| x.is_zero()) - } - - /// The number of coefficients. This does not filter out any zero coefficients, so it is not - /// necessarily related to the degree. - pub fn len(&self) -> usize { - self.coeffs.len() - } - - pub fn log_len(&self) -> usize { - log2_strict(self.len()) - } - - pub(crate) fn chunks(&self, chunk_size: usize) -> Vec { - self.coeffs - .chunks(chunk_size) - .map(|chunk| PolynomialCoeffs::new(chunk.to_vec())) - .collect() - } - - pub fn eval(&self, x: F) -> F { - self.coeffs - .iter() - .rev() - .fold(F::ZERO, |acc, &c| acc * x + c) - } - - pub fn eval_base(&self, x: F::BaseField) -> F - where - F: FieldExtension, - { - self.coeffs - .iter() - .rev() - .fold(F::ZERO, |acc, &c| acc.scalar_mul(x) + c) - } - - pub fn lde_multiple(polys: Vec<&Self>, rate_bits: usize) -> Vec { - polys.into_iter().map(|p| p.lde(rate_bits)).collect() - } - - pub fn lde(&self, rate_bits: usize) -> Self { - self.padded(self.len() << rate_bits) - } - - pub(crate) fn pad(&mut self, new_len: usize) -> Result<()> { - ensure!( - new_len >= self.len(), - "Trying to pad a polynomial of length {} to a length of {}.", - self.len(), - new_len - ); - self.coeffs.resize(new_len, F::ZERO); - Ok(()) - } - - pub(crate) fn padded(&self, new_len: usize) -> Self { - let mut poly = self.clone(); - poly.pad(new_len).unwrap(); - poly - } - - /// Removes leading zero coefficients. - pub fn trim(&mut self) { - self.coeffs.truncate(self.degree_plus_one()); - } - - /// Removes leading zero coefficients. 
- pub fn trimmed(&self) -> Self { - let coeffs = self.coeffs[..self.degree_plus_one()].to_vec(); - Self { coeffs } - } - - /// Degree of the polynomial + 1, or 0 for a polynomial with no non-zero coefficients. - pub(crate) fn degree_plus_one(&self) -> usize { - (0usize..self.len()) - .rev() - .find(|&i| self.coeffs[i].is_nonzero()) - .map_or(0, |i| i + 1) - } - - /// Leading coefficient. - pub fn lead(&self) -> F { - self.coeffs - .iter() - .rev() - .find(|x| x.is_nonzero()) - .map_or(F::ZERO, |x| *x) - } - - /// Reverse the order of the coefficients, not taking into account the leading zero coefficients. - pub(crate) fn rev(&self) -> Self { - Self::new(self.trimmed().coeffs.into_iter().rev().collect()) - } - - pub fn fft(&self) -> PolynomialValues { - fft(self) - } - - pub fn fft_with_options( - &self, - zero_factor: Option, - root_table: Option<&FftRootTable>, - ) -> PolynomialValues { - fft_with_options(self, zero_factor, root_table) - } - - /// Returns the evaluation of the polynomial on the coset `shift*H`. - pub fn coset_fft(&self, shift: F) -> PolynomialValues { - self.coset_fft_with_options(shift, None, None) - } - - /// Returns the evaluation of the polynomial on the coset `shift*H`. 
- pub fn coset_fft_with_options( - &self, - shift: F, - zero_factor: Option, - root_table: Option<&FftRootTable>, - ) -> PolynomialValues { - let modified_poly: Self = shift - .powers() - .zip(&self.coeffs) - .map(|(r, &c)| r * c) - .collect::>() - .into(); - modified_poly.fft_with_options(zero_factor, root_table) - } - - pub fn to_extension(&self) -> PolynomialCoeffs - where - F: Extendable, - { - PolynomialCoeffs::new(self.coeffs.iter().map(|&c| c.into()).collect()) - } - - pub fn mul_extension(&self, rhs: F::Extension) -> PolynomialCoeffs - where - F: Extendable, - { - PolynomialCoeffs::new(self.coeffs.iter().map(|&c| rhs.scalar_mul(c)).collect()) - } -} - -impl PartialEq for PolynomialCoeffs { - fn eq(&self, other: &Self) -> bool { - let max_terms = self.coeffs.len().max(other.coeffs.len()); - for i in 0..max_terms { - let self_i = self.coeffs.get(i).cloned().unwrap_or(F::ZERO); - let other_i = other.coeffs.get(i).cloned().unwrap_or(F::ZERO); - if self_i != other_i { - return false; - } - } - true - } -} - -impl Eq for PolynomialCoeffs {} - -impl From> for PolynomialCoeffs { - fn from(coeffs: Vec) -> Self { - Self::new(coeffs) - } -} - -impl Add for &PolynomialCoeffs { - type Output = PolynomialCoeffs; - - fn add(self, rhs: Self) -> Self::Output { - let len = max(self.len(), rhs.len()); - let a = self.padded(len).coeffs; - let b = rhs.padded(len).coeffs; - let coeffs = a.into_iter().zip(b).map(|(x, y)| x + y).collect(); - PolynomialCoeffs::new(coeffs) - } -} - -impl Sum for PolynomialCoeffs { - fn sum>(iter: I) -> Self { - iter.fold(Self::empty(), |acc, p| &acc + &p) - } -} - -impl Sub for &PolynomialCoeffs { - type Output = PolynomialCoeffs; - - fn sub(self, rhs: Self) -> Self::Output { - let len = max(self.len(), rhs.len()); - let mut coeffs = self.padded(len).coeffs; - for (i, &c) in rhs.coeffs.iter().enumerate() { - coeffs[i] -= c; - } - PolynomialCoeffs::new(coeffs) - } -} - -impl AddAssign for PolynomialCoeffs { - fn add_assign(&mut self, rhs: Self) { - 
let len = max(self.len(), rhs.len()); - self.coeffs.resize(len, F::ZERO); - for (l, r) in self.coeffs.iter_mut().zip(rhs.coeffs) { - *l += r; - } - } -} - -impl AddAssign<&Self> for PolynomialCoeffs { - fn add_assign(&mut self, rhs: &Self) { - let len = max(self.len(), rhs.len()); - self.coeffs.resize(len, F::ZERO); - for (l, &r) in self.coeffs.iter_mut().zip(&rhs.coeffs) { - *l += r; - } - } -} - -impl SubAssign for PolynomialCoeffs { - fn sub_assign(&mut self, rhs: Self) { - let len = max(self.len(), rhs.len()); - self.coeffs.resize(len, F::ZERO); - for (l, r) in self.coeffs.iter_mut().zip(rhs.coeffs) { - *l -= r; - } - } -} - -impl SubAssign<&Self> for PolynomialCoeffs { - fn sub_assign(&mut self, rhs: &Self) { - let len = max(self.len(), rhs.len()); - self.coeffs.resize(len, F::ZERO); - for (l, &r) in self.coeffs.iter_mut().zip(&rhs.coeffs) { - *l -= r; - } - } -} - -impl Mul for &PolynomialCoeffs { - type Output = PolynomialCoeffs; - - fn mul(self, rhs: F) -> Self::Output { - let coeffs = self.coeffs.iter().map(|&x| rhs * x).collect(); - PolynomialCoeffs::new(coeffs) - } -} - -impl MulAssign for PolynomialCoeffs { - fn mul_assign(&mut self, rhs: F) { - self.coeffs.iter_mut().for_each(|x| *x *= rhs); - } -} - -impl Mul for &PolynomialCoeffs { - type Output = PolynomialCoeffs; - - #[allow(clippy::suspicious_arithmetic_impl)] - fn mul(self, rhs: Self) -> Self::Output { - let new_len = (self.len() + rhs.len()).next_power_of_two(); - let a = self.padded(new_len); - let b = rhs.padded(new_len); - let a_evals = a.fft(); - let b_evals = b.fft(); - - let mul_evals: Vec = a_evals - .values - .into_iter() - .zip(b_evals.values) - .map(|(pa, pb)| pa * pb) - .collect(); - ifft(&mul_evals.into()) - } -} - -#[cfg(test)] -mod tests { - use std::time::Instant; - - use rand::{thread_rng, Rng}; - - use super::*; - use crate::field::goldilocks_field::GoldilocksField; - - #[test] - fn test_trimmed() { - type F = GoldilocksField; - - assert_eq!( - PolynomialCoeffs:: { coeffs: 
vec![] }.trimmed(), - PolynomialCoeffs:: { coeffs: vec![] } - ); - assert_eq!( - PolynomialCoeffs:: { - coeffs: vec![F::ZERO] - } - .trimmed(), - PolynomialCoeffs:: { coeffs: vec![] } - ); - assert_eq!( - PolynomialCoeffs:: { - coeffs: vec![F::ONE, F::TWO, F::ZERO, F::ZERO] - } - .trimmed(), - PolynomialCoeffs:: { - coeffs: vec![F::ONE, F::TWO] - } - ); - } - - #[test] - fn test_coset_fft() { - type F = GoldilocksField; - - let k = 8; - let n = 1 << k; - let poly = PolynomialCoeffs::new(F::rand_vec(n)); - let shift = F::rand(); - let coset_evals = poly.coset_fft(shift).values; - - let generator = F::primitive_root_of_unity(k); - let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n) - .into_iter() - .map(|x| poly.eval(x)) - .collect::>(); - assert_eq!(coset_evals, naive_coset_evals); - - let ifft_coeffs = PolynomialValues::new(coset_evals).coset_ifft(shift); - assert_eq!(poly, ifft_coeffs.into()); - } - - #[test] - fn test_coset_ifft() { - type F = GoldilocksField; - - let k = 8; - let n = 1 << k; - let evals = PolynomialValues::new(F::rand_vec(n)); - let shift = F::rand(); - let coeffs = evals.coset_ifft(shift); - - let generator = F::primitive_root_of_unity(k); - let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n) - .into_iter() - .map(|x| coeffs.eval(x)) - .collect::>(); - assert_eq!(evals, naive_coset_evals.into()); - - let fft_evals = coeffs.coset_fft(shift); - assert_eq!(evals, fft_evals); - } - - #[test] - fn test_polynomial_multiplication() { - type F = GoldilocksField; - let mut rng = thread_rng(); - let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000)); - let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); - let b = PolynomialCoeffs::new(F::rand_vec(b_deg)); - let m1 = &a * &b; - let m2 = &a * &b; - for _ in 0..1000 { - let x = F::rand(); - assert_eq!(m1.eval(x), a.eval(x) * b.eval(x)); - assert_eq!(m2.eval(x), a.eval(x) * b.eval(x)); - } - } - - #[test] - fn test_inv_mod_xn() { - 
type F = GoldilocksField; - let mut rng = thread_rng(); - let a_deg = rng.gen_range(1..1_000); - let n = rng.gen_range(1..1_000); - let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); - let b = a.inv_mod_xn(n); - let mut m = &a * &b; - m.coeffs.drain(n..); - m.trim(); - assert_eq!( - m, - PolynomialCoeffs::new(vec![F::ONE]), - "a: {:#?}, b:{:#?}, n:{:#?}, m:{:#?}", - a, - b, - n, - m - ); - } - - #[test] - fn test_polynomial_long_division() { - type F = GoldilocksField; - let mut rng = thread_rng(); - let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000)); - let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); - let b = PolynomialCoeffs::new(F::rand_vec(b_deg)); - let (q, r) = a.div_rem_long_division(&b); - for _ in 0..1000 { - let x = F::rand(); - assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x)); - } - } - - #[test] - fn test_polynomial_division() { - type F = GoldilocksField; - let mut rng = thread_rng(); - let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000)); - let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); - let b = PolynomialCoeffs::new(F::rand_vec(b_deg)); - let (q, r) = a.div_rem(&b); - for _ in 0..1000 { - let x = F::rand(); - assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x)); - } - } - - #[test] - fn test_polynomial_division_by_constant() { - type F = GoldilocksField; - let mut rng = thread_rng(); - let a_deg = rng.gen_range(1..10_000); - let a = PolynomialCoeffs::new(F::rand_vec(a_deg)); - let b = PolynomialCoeffs::from(vec![F::rand()]); - let (q, r) = a.div_rem(&b); - for _ in 0..1000 { - let x = F::rand(); - assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x)); - } - } - - #[test] - fn test_division_by_z_h() { - type F = GoldilocksField; - let mut rng = thread_rng(); - let a_deg = rng.gen_range(1..10_000); - let n = rng.gen_range(1..a_deg); - let mut a = PolynomialCoeffs::new(F::rand_vec(a_deg)); - a.trim(); - let z_h = { - let mut z_h_vec = vec![F::ZERO; n + 1]; - z_h_vec[n] = F::ONE; - 
z_h_vec[0] = F::NEG_ONE; - PolynomialCoeffs::new(z_h_vec) - }; - let m = &a * &z_h; - let now = Instant::now(); - let mut a_test = m.divide_by_z_h(n); - a_test.trim(); - println!("Division time: {:?}", now.elapsed()); - assert_eq!(a, a_test); - } - - #[test] - fn divide_zero_poly_by_z_h() { - let zero_poly = PolynomialCoeffs::::empty(); - zero_poly.divide_by_z_h(16); - } - - // Test to see which polynomial division method is faster for divisions of the type - // `(X^n - 1)/(X - a) - #[test] - fn test_division_linear() { - type F = GoldilocksField; - let mut rng = thread_rng(); - let l = 14; - let n = 1 << l; - let g = F::primitive_root_of_unity(l); - let xn_minus_one = { - let mut xn_min_one_vec = vec![F::ZERO; n + 1]; - xn_min_one_vec[n] = F::ONE; - xn_min_one_vec[0] = F::NEG_ONE; - PolynomialCoeffs::new(xn_min_one_vec) - }; - - let a = g.exp_u64(rng.gen_range(0..(n as u64))); - let denom = PolynomialCoeffs::new(vec![-a, F::ONE]); - let now = Instant::now(); - xn_minus_one.div_rem(&denom); - println!("Division time: {:?}", now.elapsed()); - let now = Instant::now(); - xn_minus_one.div_rem_long_division(&denom); - println!("Division time: {:?}", now.elapsed()); - } - - #[test] - fn eq() { - type F = GoldilocksField; - assert_eq!( - PolynomialCoeffs::::new(vec![]), - PolynomialCoeffs::new(vec![]) - ); - assert_eq!( - PolynomialCoeffs::::new(vec![F::ZERO]), - PolynomialCoeffs::new(vec![F::ZERO]) - ); - assert_eq!( - PolynomialCoeffs::::new(vec![]), - PolynomialCoeffs::new(vec![F::ZERO]) - ); - assert_eq!( - PolynomialCoeffs::::new(vec![F::ZERO]), - PolynomialCoeffs::new(vec![]) - ); - assert_eq!( - PolynomialCoeffs::::new(vec![F::ZERO]), - PolynomialCoeffs::new(vec![F::ZERO, F::ZERO]) - ); - assert_eq!( - PolynomialCoeffs::::new(vec![F::ONE]), - PolynomialCoeffs::new(vec![F::ONE, F::ZERO]) - ); - assert_ne!( - PolynomialCoeffs::::new(vec![]), - PolynomialCoeffs::new(vec![F::ONE]) - ); - assert_ne!( - PolynomialCoeffs::::new(vec![F::ZERO]), - 
PolynomialCoeffs::new(vec![F::ZERO, F::ONE]) - ); - assert_ne!( - PolynomialCoeffs::::new(vec![F::ZERO]), - PolynomialCoeffs::new(vec![F::ONE, F::ZERO]) - ); - } -} diff --git a/src/util/mod.rs b/src/util/mod.rs index 586033be..3f7c5dd1 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1,7 +1,8 @@ -use core::hint::unreachable_unchecked; +use std::arch::asm; +use std::hint::unreachable_unchecked; use crate::field::field_types::Field; -use crate::polynomial::polynomial::PolynomialValues; +use crate::polynomial::PolynomialValues; pub(crate) mod bimap; pub(crate) mod context_tree; diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index 432fb5c2..b5e805e9 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -1,123 +1,148 @@ -use std::iter::Product; -use std::ops::Sub; +use std::iter; + +use itertools::Itertools; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; +use crate::field::field_types::{Field, RichField}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::ceil_div_usize; +pub(crate) fn quotient_chunk_products( + quotient_values: &[F], + max_degree: usize, +) -> Vec { + debug_assert!(max_degree > 1); + assert!(!quotient_values.is_empty()); + let chunk_size = max_degree; + quotient_values + .chunks(chunk_size) + .map(|chunk| chunk.iter().copied().product()) + .collect() +} + /// Compute partial products of the original vector `v` such that all products consist of `max_degree` /// or less elements. This is done until we've computed the product `P` of all elements in the vector. 
-pub fn partial_products(v: &[T], max_degree: usize) -> Vec { +pub(crate) fn partial_products_and_z_gx(z_x: F, quotient_chunk_products: &[F]) -> Vec { + assert!(!quotient_chunk_products.is_empty()); let mut res = Vec::new(); - let mut remainder = v.to_vec(); - while remainder.len() > max_degree { - let new_partials = remainder - .chunks(max_degree) - // TODO: can filter out chunks of length 1. - .map(|chunk| chunk.iter().copied().product()) - .collect::>(); - res.extend_from_slice(&new_partials); - remainder = new_partials; + let mut acc = z_x; + for "ient_chunk_product in quotient_chunk_products { + acc *= quotient_chunk_product; + res.push(acc); } - res } /// Returns a tuple `(a,b)`, where `a` is the length of the output of `partial_products()` on a -/// vector of length `n`, and `b` is the number of elements needed to compute the final product. -pub fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { +/// vector of length `n`, and `b` is the number of original elements consumed in `partial_products()`. +pub(crate) fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { debug_assert!(max_degree > 1); - let mut res = 0; - let mut remainder = n; - while remainder > max_degree { - let new_partials_len = ceil_div_usize(remainder, max_degree); - res += new_partials_len; - remainder = new_partials_len; - } - - (res, remainder) + let chunk_size = max_degree; + // We'll split the product into `ceil_div_usize(n, chunk_size)` chunks, but the last chunk will + // be associated with Z(gx) itself. Thus we subtract one to get the chunks associated with + // partial products. + let num_chunks = ceil_div_usize(n, chunk_size) - 1; + (num_chunks, num_chunks * chunk_size) } -/// Checks that the partial products of `v` are coherent with those in `partials` by only computing -/// products of size `max_degree` or less. 
-pub fn check_partial_products>( - v: &[T], - mut partials: &[T], +/// Checks the relationship between each pair of partial product accumulators. In particular, this +/// sequence of accumulators starts with `Z(x)`, then contains each partial product polynomials +/// `p_i(x)`, and finally `Z(g x)`. See the partial products section of the Plonky2 paper. +pub(crate) fn check_partial_products( + numerators: &[F], + denominators: &[F], + partials: &[F], + z_x: F, + z_gx: F, max_degree: usize, -) -> Vec { - let mut res = Vec::new(); - let mut remainder = v; - while remainder.len() > max_degree { - let products = remainder - .chunks(max_degree) - .map(|chunk| chunk.iter().copied().product::()); - let products_len = products.len(); - res.extend(products.zip(partials).map(|(a, &b)| a - b)); - (remainder, partials) = partials.split_at(products_len); - } - - res +) -> Vec { + debug_assert!(max_degree > 1); + let product_accs = iter::once(&z_x) + .chain(partials.iter()) + .chain(iter::once(&z_gx)); + let chunk_size = max_degree; + numerators + .chunks(chunk_size) + .zip_eq(denominators.chunks(chunk_size)) + .zip_eq(product_accs.tuple_windows()) + .map(|((nume_chunk, deno_chunk), (&prev_acc, &next_acc))| { + let num_chunk_product = nume_chunk.iter().copied().product(); + let den_chunk_product = deno_chunk.iter().copied().product(); + // Assert that next_acc * deno_product = prev_acc * nume_product. + prev_acc * num_chunk_product - next_acc * den_chunk_product + }) + .collect() } -pub fn check_partial_products_recursively, const D: usize>( +/// Checks the relationship between each pair of partial product accumulators. In particular, this +/// sequence of accumulators starts with `Z(x)`, then contains each partial product polynomials +/// `p_i(x)`, and finally `Z(g x)`. See the partial products section of the Plonky2 paper. 
+pub(crate) fn check_partial_products_recursively, const D: usize>( builder: &mut CircuitBuilder, - v: &[ExtensionTarget], + numerators: &[ExtensionTarget], + denominators: &[ExtensionTarget], partials: &[ExtensionTarget], + z_x: ExtensionTarget, + z_gx: ExtensionTarget, max_degree: usize, ) -> Vec> { - let mut res = Vec::new(); - let mut remainder = v.to_vec(); - let mut partials = partials.to_vec(); - while remainder.len() > max_degree { - let products = remainder - .chunks(max_degree) - .map(|chunk| builder.mul_many_extension(chunk)) - .collect::>(); - res.extend( - products - .iter() - .zip(&partials) - .map(|(&a, &b)| builder.sub_extension(a, b)), - ); - remainder = partials.drain(..products.len()).collect(); - } - - res + debug_assert!(max_degree > 1); + let product_accs = iter::once(&z_x) + .chain(partials.iter()) + .chain(iter::once(&z_gx)); + let chunk_size = max_degree; + numerators + .chunks(chunk_size) + .zip_eq(denominators.chunks(chunk_size)) + .zip_eq(product_accs.tuple_windows()) + .map(|((nume_chunk, deno_chunk), (&prev_acc, &next_acc))| { + let nume_product = builder.mul_many_extension(nume_chunk); + let deno_product = builder.mul_many_extension(deno_chunk); + let next_acc_deno = builder.mul_extension(next_acc, deno_product); + // Assert that next_acc * deno_product = prev_acc * nume_product. 
+ builder.mul_sub_extension(prev_acc, nume_product, next_acc_deno) + }) + .collect() } #[cfg(test)] mod tests { - use num::Zero; - use super::*; + use crate::field::goldilocks_field::GoldilocksField; #[test] fn test_partial_products() { - let v = vec![1, 2, 3, 4, 5, 6]; - let p = partial_products(&v, 2); - assert_eq!(p, vec![2, 12, 30, 24, 30]); - let nums = num_partial_products(v.len(), 2); - assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &p, 2) - .iter() - .all(|x| x.is_zero())); - assert_eq!( - v.into_iter().product::(), - p[p.len() - nums.1..].iter().copied().product(), - ); + type F = GoldilocksField; + let denominators = vec![F::ONE; 6]; + let z_x = F::ONE; + let v = field_vec(&[1, 2, 3, 4, 5, 6]); + let z_gx = F::from_canonical_u64(720); + let quotient_chunks_prods = quotient_chunk_products(&v, 2); + assert_eq!(quotient_chunks_prods, field_vec(&[2, 12, 30])); + let pps_and_z_gx = partial_products_and_z_gx(z_x, "ient_chunks_prods); + let pps = &pps_and_z_gx[..pps_and_z_gx.len() - 1]; + assert_eq!(pps_and_z_gx, field_vec(&[2, 24, 720])); - let v = vec![1, 2, 3, 4, 5, 6]; - let p = partial_products(&v, 3); - assert_eq!(p, vec![6, 120]); - let nums = num_partial_products(v.len(), 3); - assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &p, 3) + let nums = num_partial_products(v.len(), 2); + assert_eq!(pps.len(), nums.0); + assert!(check_partial_products(&v, &denominators, pps, z_x, z_gx, 2) .iter() .all(|x| x.is_zero())); - assert_eq!( - v.into_iter().product::(), - p[p.len() - nums.1..].iter().copied().product(), - ); + + let quotient_chunks_prods = quotient_chunk_products(&v, 3); + assert_eq!(quotient_chunks_prods, field_vec(&[6, 120])); + let pps_and_z_gx = partial_products_and_z_gx(z_x, "ient_chunks_prods); + let pps = &pps_and_z_gx[..pps_and_z_gx.len() - 1]; + assert_eq!(pps_and_z_gx, field_vec(&[6, 720])); + let nums = num_partial_products(v.len(), 3); + assert_eq!(pps.len(), nums.0); + assert!(check_partial_products(&v, 
&denominators, pps, z_x, z_gx, 3) + .iter() + .all(|x| x.is_zero())); + } + + fn field_vec(xs: &[usize]) -> Vec { + xs.iter().map(|&x| F::from_canonical_usize(x)).collect() } } diff --git a/src/util/reducing.rs b/src/util/reducing.rs index 28acc1b0..2b8e80e7 100644 --- a/src/util/reducing.rs +++ b/src/util/reducing.rs @@ -2,11 +2,14 @@ use std::borrow::Borrow; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; +use crate::field::field_types::{Field, RichField}; +use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::field::field_types::Field; use crate::gates::reducing::ReducingGate; +use crate::gates::reducing_extension::ReducingExtensionGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::polynomial::polynomial::PolynomialCoeffs; +use crate::polynomial::PolynomialCoeffs; /// When verifying the composition polynomial in FRI we have to compute sums of the form /// `(sum_0^k a^i * x_i)/d_0 + (sum_k^r a^i * y_i)/d_1` @@ -93,7 +96,7 @@ impl ReducingFactorTarget { Self { base, count: 0 } } - /// Reduces a length `n` vector of `Target`s using `n/21` `ReducingGate`s (with 33 routed wires and 126 wires). + /// Reduces a vector of `Target`s using `ReducingGate`s. pub fn reduce_base( &mut self, terms: &[Target], @@ -102,11 +105,22 @@ impl ReducingFactorTarget { where F: Extendable, { + let l = terms.len(); + + // For small reductions, use an arithmetic gate. 
+ if l <= ArithmeticExtensionGate::::new_from_config(&builder.config).num_ops + 1 { + let terms_ext = terms + .iter() + .map(|&t| builder.convert_to_ext(t)) + .collect::>(); + return self.reduce_arithmetic(&terms_ext, builder); + } + let max_coeffs_len = ReducingGate::::max_coeffs_len( builder.config.num_wires, builder.config.num_routed_wires, ); - self.count += terms.len() as u64; + self.count += l as u64; let zero = builder.zero(); let zero_ext = builder.zero_extension(); let mut acc = zero_ext; @@ -137,6 +151,7 @@ impl ReducingFactorTarget { acc } + /// Reduces a vector of `ExtensionTarget`s using `ReducingExtensionGate`s. pub fn reduce( &mut self, terms: &[ExtensionTarget], // Could probably work with a `DoubleEndedIterator` too. @@ -146,18 +161,74 @@ impl ReducingFactorTarget { F: Extendable, { let l = terms.len(); - self.count += l as u64; - let mut terms_vec = terms.to_vec(); - let mut acc = builder.zero_extension(); - terms_vec.reverse(); - - for x in terms_vec { - acc = builder.mul_add_extension(self.base, acc, x); + // For small reductions, use an arithmetic gate. 
+ if l <= ArithmeticExtensionGate::::new_from_config(&builder.config).num_ops + 1 { + return self.reduce_arithmetic(terms, builder); } + + let max_coeffs_len = ReducingExtensionGate::::max_coeffs_len( + builder.config.num_wires, + builder.config.num_routed_wires, + ); + self.count += l as u64; + let zero_ext = builder.zero_extension(); + let mut acc = zero_ext; + let mut reversed_terms = terms.to_vec(); + while reversed_terms.len() % max_coeffs_len != 0 { + reversed_terms.push(zero_ext); + } + reversed_terms.reverse(); + for chunk in reversed_terms.chunks_exact(max_coeffs_len) { + let gate = ReducingExtensionGate::new(max_coeffs_len); + let gate_index = builder.add_gate(gate.clone(), Vec::new()); + + builder.connect_extension( + self.base, + ExtensionTarget::from_range(gate_index, ReducingExtensionGate::::wires_alpha()), + ); + builder.connect_extension( + acc, + ExtensionTarget::from_range( + gate_index, + ReducingExtensionGate::::wires_old_acc(), + ), + ); + for (i, &t) in chunk.iter().enumerate() { + builder.connect_extension( + t, + ExtensionTarget::from_range( + gate_index, + ReducingExtensionGate::::wires_coeff(i), + ), + ); + } + + acc = + ExtensionTarget::from_range(gate_index, ReducingExtensionGate::::wires_output()); + } + acc } + /// Reduces a vector of `ExtensionTarget`s using `ArithmeticGate`s. 
+ fn reduce_arithmetic( + &mut self, + terms: &[ExtensionTarget], + builder: &mut CircuitBuilder, + ) -> ExtensionTarget + where + F: RichField + Extendable, + { + self.count += terms.len() as u64; + terms + .iter() + .rev() + .fold(builder.zero_extension(), |acc, &et| { + builder.mul_add_extension(self.base, acc, et) + }) + } + pub fn shift( &mut self, x: ExtensionTarget, @@ -261,4 +332,9 @@ mod tests { fn test_reduce_gadget_base_100() -> Result<()> { test_reduce_gadget_base(100) } + + #[test] + fn test_reduce_gadget_100() -> Result<()> { + test_reduce_gadget(100) + } } diff --git a/src/util/serialization.rs b/src/util/serialization.rs index 053b8c68..06f2a3e5 100644 --- a/src/util/serialization.rs +++ b/src/util/serialization.rs @@ -1,8 +1,6 @@ use std::collections::HashMap; -use std::convert::TryInto; use std::io::Cursor; use std::io::{Read, Result, Write}; -use std::iter::FromIterator; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::{PrimeField, RichField}; @@ -17,7 +15,7 @@ use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::proof::{ CompressedProof, CompressedProofWithPublicInputs, OpeningSet, Proof, ProofWithPublicInputs, }; -use crate::polynomial::polynomial::PolynomialCoeffs; +use crate::polynomial::PolynomialCoeffs; #[derive(Debug)] pub struct Buffer(Cursor>); diff --git a/src/util/timing.rs b/src/util/timing.rs index cd9ea731..4250d688 100644 --- a/src/util/timing.rs +++ b/src/util/timing.rs @@ -92,7 +92,7 @@ impl TimingTree { fn duration(&self) -> Duration { self.exit_time - .unwrap_or(Instant::now()) + .unwrap_or_else(Instant::now) .duration_since(self.enter_time) }